Skip to content

Commit bb3a9a7

Browse files
author
wangzhe
committed
update: label map
1 parent 6225b49 commit bb3a9a7

File tree

11 files changed

+53
-35
lines changed

11 files changed

+53
-35
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
*.pyc
2+
__pycache__/
3+
__pycache__/*
4+
.idea/
5+
.DS_Store
6+
results/*

evaluation.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ def load_all_files(self):
3838
print('Incorrect threshold value! It should be in [0, 1]. Please check and retry ~')
3939
return 0
4040
pre_files = [x for x in os.listdir(prediction_path) if (x.endswith('txt') or x.endswith('xml'))]
41-
print ("Num of prediction files: ", len(pre_files))
41+
print("Num of prediction files: ", len(pre_files))
4242
gt_files = os.listdir(gt_path)
43-
print ("Num of ground truth files: ", len(gt_files))
43+
print("Num of ground truth files: ", len(gt_files))
4444
if len(pre_files) != len(gt_files):
45-
print("groundtruths' size does not match predictions' size please check ~ ")
45+
print("ground truths' size does not match predictions' size, please check ~ ")
4646
return 0
4747
elif len(pre_files) < 1:
4848
print('No files! Please check~')
@@ -226,19 +226,19 @@ def computeAp(self, label):
226226
plt.ylabel('recall')
227227
plt.draw() # 显示绘图
228228
# plt.pause(5) # 显示5秒
229-
plt.savefig("class_{}_roc.jpg".format(label)) # 保存图象
229+
plt.savefig("class_{}_roc.png".format(label)) # 保存图象
230230
plt.close()
231231

232232
if self.pr:
233-
# 画roc曲线图
233+
# 画pr曲线图
234234
plt.figure('Draw_pr')
235235
plt.plot(rec, prec) # plot绘制折线图
236236
plt.grid(True)
237237
plt.xlabel('recall')
238238
plt.ylabel('precision')
239239
plt.draw() # 显示绘图
240240
# plt.pause(5) # 显示5秒
241-
plt.savefig("class_{}_pr.jpg".format(label)) # 保存图象
241+
plt.savefig("class_{}_pr.png".format(label)) # 保存图象
242242
plt.close()
243243

244244
fppi = 0
@@ -276,11 +276,12 @@ def run(self):
276276
prediction_path, gt_path, predictions, groundtruths, file_format = self.load_all_files()
277277
aps = 0
278278

279-
# temp
280-
class_map_temp = {1: 'Person', 2: 'Vehicle', 3: 'Dryer'}
279+
# modify as you need
280+
# list your label names as below ['class 1', 'class 2'......]
281+
class_names = ['face']
281282

282-
for label in range(1, self.cls):
283-
semantic_label = class_map_temp[label]
283+
for label in class_names:
284+
semantic_label = label
284285
print('Processing label: {}'.format(semantic_label))
285286
self.get_tp_fp(gt_path, prediction_path, groundtruths, predictions, semantic_label, file_format)
286287
precision, recall, fppi, fppw, ap = self.computeAp(semantic_label)
@@ -294,8 +295,9 @@ def run(self):
294295
if self.FPPIW:
295296
print('FPPW: ', fppw, 'FPPI', fppi)
296297
aps += ap
298+
297299
mAp = aps / (self.cls - 1)
298-
print ("mAp: ", mAp)
300+
print("mAp: ", mAp)
299301

300302
return 0
301303

io_file.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,24 @@
33

44

55
# modify 'class_map' as you need
6-
class_map = {'Person': 'Person', 'Vehicle': 'Vehicle', 'Dryer': 'Dryer'}
6+
# {id or label name in gt label: class name}
7+
class_map = {'face': 'face'}
78

89

10+
# parse pascal voc style label file
911
def parse_xml(xml_path):
10-
dom = xml.dom.minidom.parse(xml_path)
12+
gts = []
13+
try:
14+
dom = xml.dom.minidom.parse(xml_path)
15+
print('{} parse failed! Use empty label instead \n'.format(xml_path))
16+
except:
17+
return gts
1118
root = dom.documentElement
1219
objects = root.getElementsByTagName('object')
13-
gts = []
1420
for index, obj in enumerate(objects):
15-
name = obj.getElementsByTagName('name')[0].firstChild.data.decode('utf8')
21+
name = obj.getElementsByTagName('name')[0].firstChild.data.strip("\ufeff")
22+
if name not in class_map:
23+
continue
1624
label = class_map[name]
1725
bndbox = obj.getElementsByTagName('bndbox')[0]
1826
x1 = int(bndbox.getElementsByTagName('xmin')[0].firstChild.data)
@@ -21,4 +29,4 @@ def parse_xml(xml_path):
2129
y2 = int(bndbox.getElementsByTagName('ymax')[0].firstChild.data)
2230
gt_one = [label, x1, y1, x2, y2]
2331
gts.append(gt_one)
24-
return gts
32+
return gts

results/class_1_pr.jpg

-35.1 KB
Binary file not shown.

results/class_1_roc.jpg

-35.8 KB
Binary file not shown.

results/class_2_pr.jpg

-34.5 KB
Binary file not shown.

results/class_2_roc.jpg

-36.2 KB
Binary file not shown.

results/class_3_pr.jpg

-42.2 KB
Binary file not shown.

results/class_3_roc.jpg

-42.2 KB
Binary file not shown.

sample/.DS_Store

-6 KB
Binary file not shown.

0 commit comments

Comments
 (0)