@@ -38,11 +38,11 @@ def load_all_files(self):
3838 print ('Incorrect threshold value! It should be in [0, 1]. Please check and retry ~' )
3939 return 0
4040 pre_files = [x for x in os .listdir (prediction_path ) if (x .endswith ('txt' ) or x .endswith ('xml' ))]
41- print ("Num of prediction files: " , len (pre_files ))
41+ print ("Num of prediction files: " , len (pre_files ))
4242 gt_files = os .listdir (gt_path )
43- print ("Num of ground truth files: " , len (gt_files ))
43+ print ("Num of ground truth files: " , len (gt_files ))
4444 if len (pre_files ) != len (gt_files ):
45- print ("groundtruths ' size does not match predictions' size, please check ~ " )
45+ print ("ground truths ' size does not match predictions' size, please check ~ " )
4646 return 0
4747 elif len (pre_files ) < 1 :
4848 print ('No files! Please check~' )
@@ -226,19 +226,19 @@ def computeAp(self, label):
226226 plt .ylabel ('recall' )
227227 plt .draw () # 显示绘图
228228 # plt.pause(5) # 显示5秒
229- plt .savefig ("class_{}_roc.jpg " .format (label )) # 保存图象
229+ plt .savefig ("class_{}_roc.png " .format (label )) # 保存图象
230230 plt .close ()
231231
232232 if self .pr :
233- # 画roc曲线图
233+ # 画pr曲线图
234234 plt .figure ('Draw_pr' )
235235 plt .plot (rec , prec ) # plot绘制折线图
236236 plt .grid (True )
237237 plt .xlabel ('recall' )
238238 plt .ylabel ('precision' )
239239 plt .draw () # 显示绘图
240240 # plt.pause(5) # 显示5秒
241- plt .savefig ("class_{}_pr.jpg " .format (label )) # 保存图象
241+ plt .savefig ("class_{}_pr.png " .format (label )) # 保存图象
242242 plt .close ()
243243
244244 fppi = 0
@@ -276,11 +276,12 @@ def run(self):
276276 prediction_path , gt_path , predictions , groundtruths , file_format = self .load_all_files ()
277277 aps = 0
278278
279- # temp
280- class_map_temp = {1 : 'Person' , 2 : 'Vehicle' , 3 : 'Dryer' }
279+ # modify as you need
280+ # list your label names as below ['class 1', 'class 2'......]
281+ class_names = ['face' ]
281282
282- for label in range ( 1 , self . cls ) :
283- semantic_label = class_map_temp [ label ]
283+ for label in class_names :
284+ semantic_label = label
284285 print ('Processing label: {}' .format (semantic_label ))
285286 self .get_tp_fp (gt_path , prediction_path , groundtruths , predictions , semantic_label , file_format )
286287 precision , recall , fppi , fppw , ap = self .computeAp (semantic_label )
@@ -294,8 +295,9 @@ def run(self):
294295 if self .FPPIW :
295296 print ('FPPW: ' , fppw , 'FPPI' , fppi )
296297 aps += ap
298+
297299 mAp = aps / (self .cls - 1 )
298- print ("mAp: " , mAp )
300+ print ("mAp: " , mAp )
299301
300302 return 0
301303
0 commit comments