import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score
import matplotlib.pyplot as plt
import matplotlib.patches as ptch

# Appendix A - working with a single threshold
pred_scores = [0.7, 0.3, 0.5, 0.6, 0.55, 0.9, 0.4, 0.2, 0.4, 0.3]
y_true = ["positive", "negative", "negative", "positive", "positive", "positive", "negative", "positive", "negative", "positive"]

# To convert the scores into class labels, a threshold is used.
# A sample is classified as Positive when its score is greater than or equal to the threshold,
# and as Negative otherwise.
# The next block of code converts the scores into class labels with a threshold of 0.5.

threshold = 0.5

y_pred = ["positive" if score >= threshold else "negative" for score in pred_scores]
print(y_pred)
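
# Side note (a sketch, not part of the original flow): the same thresholding can be done
# in a vectorized way with NumPy, which is convenient for large score arrays.
y_pred_vectorized = np.where(np.array(pred_scores) >= threshold, "positive", "negative")
print(list(y_pred_vectorized))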

# confusion_matrix orders the labels alphabetically (negative, positive), so np.flip
# rearranges the matrix into the [TP, FN; FP, TN] layout described below.
r = np.flip(confusion_matrix(y_true, y_pred))
print("\n# Confusion Matrix (From Left to Right & Top to Bottom:\n True Positive, False Negative,\n False Positive, True Negative)")
print(r)
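
# The same layout can also be obtained without np.flip by passing the labels argument
# (a sketch relying on confusion_matrix ordering rows/columns as given in labels):
r_labeled = confusion_matrix(y_true, y_pred, labels=["positive", "negative"])
print(r_labeled)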

# Remember: the higher the precision, the more confident the model is when it classifies a sample as Positive.
# The higher the recall, the more of the actual positive samples the model correctly classifies as Positive.

precision = precision_score(y_true=y_true, y_pred=y_pred, pos_label="positive")
print("\n# Precision = 4/(4+1)")
print(precision)

recall = recall_score(y_true=y_true, y_pred=y_pred, pos_label="positive")
print("\n# Recall = 4/(4+2)")
print(recall)
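
# Sanity check (a sketch): recompute precision and recall directly from their definitions
# by counting true positives, false positives and false negatives.
tp = sum(1 for t, p in zip(y_true, y_pred) if t == "positive" and p == "positive")
fp = sum(1 for t, p in zip(y_true, y_pred) if t == "negative" and p == "positive")
fn = sum(1 for t, p in zip(y_true, y_pred) if t == "positive" and p == "negative")
print("Manual precision =", tp / (tp + fp))  # 4 / (4 + 1) = 0.8
print("Manual recall    =", tp / (tp + fn))  # 4 / (4 + 2) ~= 0.6667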

# Appendix B - working with multiple thresholds
y_true = ["positive", "negative", "negative", "positive", "positive", "positive", "negative", "positive", "negative", "positive", "positive", "positive", "positive", "negative", "negative", "negative"]

pred_scores = [0.7, 0.3, 0.5, 0.6, 0.55, 0.9, 0.4, 0.2, 0.4, 0.3, 0.7, 0.5, 0.8, 0.2, 0.3, 0.35]

thresholds = np.arange(start=0.2, stop=0.7, step=0.05)

# Due to the importance of both precision and recall, there is a precision-recall curve that shows
# the tradeoff between the precision and recall values for different thresholds.
# This curve helps to select the best threshold to maximize both metrics.

def precision_recall_curve(y_true, pred_scores, thresholds):
    precisions = []
    recalls = []
    f1_scores = []

    for threshold in thresholds:
        y_pred = ["positive" if score >= threshold else "negative" for score in pred_scores]

        precision = precision_score(y_true=y_true, y_pred=y_pred, pos_label="positive")
        recall = recall_score(y_true=y_true, y_pred=y_pred, pos_label="positive")
        f1_score = (2 * precision * recall) / (precision + recall)

        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1_score)

    return precisions, recalls, f1_scores

precisions, recalls, f1_scores = precision_recall_curve(y_true=y_true,
                                                        pred_scores=pred_scores,
                                                        thresholds=thresholds)

print("\nRecall :: Precision :: F1-Score")
for precision, recall, f1 in zip(precisions, recalls, f1_scores):
    print(round(recall, 4), "\t::\t", round(precision, 4), "\t::\t", round(f1, 4))

# np.max() returns the maximum value in the array
# np.argmax() returns the index of that maximum value

print('Best F1-Score: ', np.max(f1_scores))
idx_best_f1 = np.argmax(f1_scores)
print('\nBest threshold: ', thresholds[idx_best_f1])
print('Index of threshold: ', idx_best_f1)

# Uncomment the block below to display the precision-recall plot.

# plt.plot(recalls, precisions, linewidth=4, color="red")
# plt.scatter(recalls[idx_best_f1], precisions[idx_best_f1], zorder=1, linewidth=6)
# plt.xlabel("Recall", fontsize=12, fontweight='bold')
# plt.ylabel("Precision", fontsize=12, fontweight='bold')
# plt.title("Precision-Recall Curve", fontsize=15, fontweight="bold")
# plt.show()
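
# For comparison (a sketch, assuming scikit-learn's built-in implementation is available):
# sklearn.metrics also provides a precision_recall_curve function. It evaluates precision
# and recall at every distinct score instead of a fixed threshold grid, so its output
# will not match the table above exactly. It is imported under an alias to avoid
# shadowing the local precision_recall_curve() defined above.
from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve

sk_precisions, sk_recalls, sk_thresholds = sk_precision_recall_curve(y_true, pred_scores, pos_label="positive")
print("\nsklearn precisions ::", np.round(sk_precisions, 4))
print("sklearn recalls    ::", np.round(sk_recalls, 4))
print("sklearn thresholds ::", sk_thresholds)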

# Appendix C - average precision (AP)
precisions, recalls, f1_scores = precision_recall_curve(y_true=y_true,
                                                        pred_scores=pred_scores,
                                                        thresholds=thresholds)

# Append precision = 1 and recall = 0 so the curve is anchored at the point
# (recall = 0, precision = 1) before the area is computed.
precisions.append(1)
recalls.append(0)

precisions = np.array(precisions)
recalls = np.array(recalls)

print('\nRecall    ::', recalls)
print('Precision ::', precisions)

# AP is the sum of precisions weighted by the drop in recall between consecutive thresholds.
AP = np.sum((recalls[:-1] - recalls[1:]) * precisions[:-1])
print("\nAP --", AP)
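
# Cross-check (a sketch, assuming sklearn.metrics.average_precision_score is available):
# sklearn computes AP from precision/recall evaluated at every distinct score, so its
# value may differ slightly from the manual AP above, which uses the fixed threshold grid.
from sklearn.metrics import average_precision_score

sklearn_AP = average_precision_score(y_true, pred_scores, pos_label="positive")
print("sklearn AP --", sklearn_AP)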

# Appendix D - Intersection over Union

# gt_box   -- ground-truth bounding box, given as [x, y, width, height]
# pred_box -- predicted bounding box, given as [x, y, width, height]
def intersection_over_union(gt_box, pred_box):

    # Top-left corner of the intersection rectangle
    inter_box_top_left = [max(gt_box[0], pred_box[0]), max(gt_box[1], pred_box[1])]

    print("\ninter_box_top_left:", inter_box_top_left)
    print("gt_box:", gt_box)
    print("pred_box:", pred_box)
    # Bottom-right corner of the intersection rectangle
    inter_box_bottom_right = [min(gt_box[0] + gt_box[2], pred_box[0] + pred_box[2]), min(gt_box[1] + gt_box[3], pred_box[1] + pred_box[3])]
    print("inter_box_bottom_right:", inter_box_bottom_right)

    # Clamp to zero so that non-overlapping boxes give an intersection area of 0
    inter_box_w = max(inter_box_bottom_right[0] - inter_box_top_left[0], 0)
    print("inter_box_w:", inter_box_w)
    inter_box_h = max(inter_box_bottom_right[1] - inter_box_top_left[1], 0)
    print("inter_box_h:", inter_box_h)

    intersection = inter_box_w * inter_box_h
    union = gt_box[2] * gt_box[3] + pred_box[2] * pred_box[3] - intersection

    iou = intersection / union

    return iou, intersection, union

gt_box1 = [320, 220, 680, 900]
pred_box1 = [500, 320, 550, 700]

gt_box2 = [645, 130, 310, 320]
pred_box2 = [500, 60, 310, 320]

# intersection_over_union() returns a (iou, intersection, union) tuple
iou1 = intersection_over_union(gt_box1, pred_box1)
print("\nIOU1 ::", iou1)

iou2 = intersection_over_union(gt_box2, pred_box2)
print("\nIOU2 ::", iou2)
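
# Quick sanity check (a sketch): a box compared with itself should give IoU = 1.0,
# since the intersection and the union are then the same area.
box = [100, 100, 200, 300]
iou_self, inter_self, union_self = intersection_over_union(box, box)
print("\nSelf-IoU (expected 1.0):", iou_self)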