Skip to content

Commit c4ad503

Browse files
authored
Merge branch 'main' into docs
2 parents ba54691 + e119858 commit c4ad503

File tree

5 files changed

+391
-1
lines changed

5 files changed

+391
-1
lines changed

README_CN.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ MX ([MindX](https://www.hiascend.com/zh/software/mindx-sdk)的缩写) 是一个
7979

8080
MindOCR集成了MX推理引擎,支持文本检测识别任务,请参考[mx_infer](docs/cn/inference_cn.md).
8181

82+
8283
#### 使用Lite推理
8384

8485
敬请期待
@@ -108,7 +109,7 @@ MindOCR集成了MX推理引擎,支持文本检测识别任务,请参考[mx_i
108109

109110
模型训练的配置及性能结果请见[configs](./configs).
110111

111-
基于MX Engine推理的模型性能结果请见[mx inference performance](docs/cn/inference_models_cn.md)
112+
基于MX引擎的推理性能结果及支持模型列表,请见[mx inference performance](docs/cn/inference_models_cn.md)
112113

113114
## 重要信息
114115

deploy/eval_utils/eval_script.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
import argparse
2+
import codecs
3+
import json
4+
import logging
5+
import os
6+
7+
import numpy as np
8+
from joblib import Parallel, delayed
9+
from shapely.geometry import Polygon
10+
from tqdm import tqdm
11+
12+
"""
13+
Evaluate the accuracy of detection and Recognition results compared to samples
14+
15+
params:
16+
--gt_path: path to the test dataset label file.
17+
--pred_path: path to store running inference results.
18+
--parallel_num: parallel number, default is 32.
19+
20+
for example:
21+
python eval_script.py --gt_path=/xx/xx/icdar2019_lsvt/labels --pred_path=/xx/xx/pipeline_results.txt
22+
"""
23+
24+
25+
def transform_pred_to_dir(file_path):
26+
with open(file_path, encoding='utf-8') as file:
27+
file_path = os.path.join(os.getcwd(), 'temp')
28+
for line in tqdm(file.readlines()):
29+
line = line.strip()
30+
line_list = line.split('\t')
31+
file_name = line_list[0]
32+
res_list = json.loads(line_list[1]) if len(line_list) >= 2 else ''
33+
file_name = file_name.replace('gt', 'infer_img')
34+
file_name = file_name.replace('jpg', 'txt')
35+
36+
if not os.path.exists(file_path):
37+
os.mkdir(file_path)
38+
with open(os.path.join(file_path, file_name), 'w', encoding='utf-8') as new_file:
39+
for res in res_list:
40+
transcription = res.get('transcription', '')
41+
points = res.get('points', [])
42+
if not transcription and not points:
43+
continue
44+
points_str = ','.join(str(x) for x in points) if isinstance(points, list) else ''
45+
new_file.writelines(points_str + ',' + transcription + '\n')
46+
return file_path
47+
48+
49+
def get_image_info_list(file_list, ratio_list=[1.0]):
50+
if isinstance(file_list, str):
51+
file_list = [file_list]
52+
else:
53+
raise NotImplementedError
54+
data_lines = []
55+
for idx, file in enumerate(file_list):
56+
with open(file, "rb") as f:
57+
lines = f.readlines()
58+
if lines and lines[0][0:3] == codecs.BOM_UTF8:
59+
lines[0] = lines[0].replace(codecs.BOM_UTF8, b'')
60+
lines = lines[:int(len(lines) * ratio_list[idx])]
61+
data_lines.extend(lines)
62+
return data_lines
63+
64+
65+
def intersection(g, p):
66+
"""
67+
Intersection.
68+
"""
69+
g = Polygon(g[:8].reshape((4, 2)))
70+
p = Polygon(p[:8].reshape((4, 2)))
71+
g = g.buffer(0)
72+
p = p.buffer(0)
73+
if not g.is_valid or not p.is_valid:
74+
return 0
75+
inter = Polygon(g).intersection(Polygon(p)).area
76+
union = g.area + p.area - inter
77+
if union == 0:
78+
return 0
79+
else:
80+
return inter / union
81+
82+
83+
def process_words(items, prediction, thresh=0.5):
84+
"""
85+
:param items: list of word level group truth
86+
:param prediction: item of line level inference result
87+
:param thresh: threshold to decide whether word box belong to inference box
88+
:return: candidate words with covered area for line prediction ordered from left to right
89+
"""
90+
pred = np.array([int(j) for j in prediction[:8]])
91+
pred_poly = Polygon(pred.reshape((4, 2))).buffer(0)
92+
if not pred_poly.is_valid:
93+
return 0
94+
matched_count = 0
95+
for it in items:
96+
gt = np.array([int(i) for i in it[:8]]).reshape((4, 2))
97+
gt_poly = Polygon(gt).buffer(0)
98+
if not gt_poly.is_valid:
99+
return 0
100+
inter = Polygon(gt_poly).intersection(Polygon(pred_poly)).area
101+
ratio = 0
102+
if gt_poly.area:
103+
ratio = inter / gt_poly.area
104+
105+
if ratio > thresh:
106+
# only with valid word label proves the validity of item
107+
word = it[8]
108+
if word and not word.startswith("###"):
109+
word = word.replace(" ", "")
110+
pred_word = prediction[8].replace(" ", "")
111+
if word in pred_word:
112+
matched_count += 1
113+
return matched_count
114+
115+
116+
def process_box_2015(items, pred_poly, thresh=0.8):
117+
valid_count = 0
118+
for k in range(len(items)):
119+
gt = np.array([int(j) for j in items[k][:8]]).reshape((4, 2))
120+
gt_poly = Polygon(gt).buffer(0)
121+
inter = Polygon(gt_poly).intersection(pred_poly).area
122+
ratio = inter / gt_poly.area
123+
if ratio > thresh:
124+
valid_count += 1
125+
126+
return valid_count
127+
128+
129+
def process_box_2019(items, pred_poly, thresh=0.5):
130+
valid_count = 0
131+
for item in items:
132+
gt = np.array([int(j) for j in item[:8]]).reshape((4, 2))
133+
inter = Polygon(gt).intersection(pred_poly).area
134+
union = Polygon(gt).union(pred_poly).area
135+
136+
if union > 0 and inter / union > thresh:
137+
valid_count += 1
138+
return valid_count
139+
140+
141+
def process_files(filepath):
142+
items = []
143+
data_lines = get_image_info_list(filepath)
144+
for data_line in data_lines:
145+
data_line = data_line.decode('utf-8').strip("\n").strip("\r").split(",")
146+
data_line = data_line[:8] + [','.join(data_line[8:])]
147+
items.append(data_line)
148+
return items
149+
150+
151+
def recognition_eval(gt_pth, pred_pth):
152+
gt_items = process_files(gt_pth)
153+
if os.path.exists(pred_pth):
154+
pred_items = process_files(pred_pth)
155+
else:
156+
pred_items = []
157+
158+
correct_num, total_num = 0, 0
159+
for item in gt_items:
160+
if len(item) != 9:
161+
raise ValueError("invalid gt file!")
162+
if item[8] and not item[8].startswith("###"):
163+
total_num += 1
164+
165+
for prediction in pred_items:
166+
if len(prediction) != 9:
167+
raise ValueError("invalid pred file!")
168+
if not prediction:
169+
continue
170+
matched_num = process_words(gt_items, prediction)
171+
correct_num += matched_num
172+
return correct_num, total_num
173+
174+
175+
def detection_eval(box_func, gt_pth, pred_pth):
176+
gt_items = process_files(gt_pth)
177+
if os.path.exists(pred_pth):
178+
pred_items = process_files(pred_pth)
179+
else:
180+
pred_items = []
181+
valid_items = []
182+
matched = 0
183+
for item in gt_items:
184+
if len(item) != 9:
185+
continue
186+
gt = np.array([int(j) for j in item[:8]]).reshape((4, 2))
187+
gt_poly = Polygon(gt)
188+
if not gt_poly.is_valid or not gt_poly.is_simple:
189+
continue
190+
word = item[8]
191+
if word in ["*", "###"]:
192+
continue
193+
valid_items.append(item)
194+
for prediction in pred_items:
195+
pred = np.array([int(i) for i in prediction[:8]])
196+
pred_poly = Polygon(pred.reshape((4, 2))).buffer(0)
197+
if not pred_poly.is_valid or not pred_poly.is_simple:
198+
continue
199+
matched += box_func(valid_items, pred_poly)
200+
return {
201+
"matched": matched,
202+
"gt_num": len(valid_items),
203+
"det_num": len(pred_items)
204+
}
205+
206+
207+
def eval_each_det(gt_file, eval_func, gt, pred, box_func):
208+
gt_pth = os.path.join(gt, gt_file)
209+
pred_pth = os.path.join(pred, "infer_{}".format(gt_file.split('_', 1)[1]))
210+
return eval_func(box_func, gt_pth, pred_pth)
211+
212+
213+
def eval_each_rec(gt_file, gt, pred, eval_func):
214+
gt_pth = os.path.join(gt, gt_file)
215+
pred_pth = os.path.join(pred, "infer_{}".format(gt_file.split('_', 1)[1]))
216+
correct, total = eval_func(gt_pth, pred_pth)
217+
return correct, total
218+
219+
220+
def eval_rec(eval_func, gt, pred, parallel_num):
221+
"""
222+
:param eval_func:
223+
detection_eval:评估检测指标
224+
recognition_eval: 评估识别指标
225+
:param gt: 标签路径
226+
:param pred: 预测路径
227+
:param parallel_num: 并行度
228+
:return: 指标评估结果
229+
"""
230+
gt_list = os.listdir(gt)
231+
res = Parallel(n_jobs=parallel_num, backend="multiprocessing")(delayed(eval_each_rec)(
232+
gt_file, gt, pred, eval_func) for gt_file in tqdm(gt_list))
233+
res = np.array(res)
234+
correct_num = sum(res[:, 0])
235+
total_num = sum(res[:, 1])
236+
acc = correct_num / total_num if total_num else 0
237+
return {
238+
"acc:": acc,
239+
"correct_num:": correct_num,
240+
"total_num:": total_num
241+
}
242+
243+
244+
def eval_det(eval_func, box_func, gt, pred, parallel_num):
245+
"""
246+
:param eval_func:
247+
detection_eval:评估检测指标
248+
recognition_eval: 评估识别指标
249+
:param gt: 标签路径
250+
:param pred: 预测路径
251+
:return: 指标评估结果
252+
"""
253+
gt_list = os.listdir(gt)
254+
res = Parallel(n_jobs=parallel_num, backend="multiprocessing")(delayed(eval_each_det)(
255+
gt_file, eval_func, gt, pred, box_func) for gt_file in tqdm(gt_list))
256+
257+
matched_num = 0
258+
gt_num = 0
259+
det_num = 0
260+
for result in res:
261+
matched_num += result['matched']
262+
gt_num += result['gt_num']
263+
det_num += result['det_num']
264+
265+
precision = 0 if not det_num else float(matched_num) / det_num
266+
recall = 0 if not gt_num else float(matched_num) / gt_num
267+
h_mean = 0 if not precision + recall else 2 * float(precision * recall) / (precision + recall)
268+
return {
269+
"precision:": precision,
270+
"recall:": recall,
271+
"Hmean:": h_mean,
272+
"matched:": matched_num,
273+
"det_num": det_num,
274+
"gt_num": gt_num
275+
}
276+
277+
278+
def parse_args():
279+
parser = argparse.ArgumentParser()
280+
parser.add_argument('--gt_path', required=True, type=str, help="label storage path")
281+
parser.add_argument('--pred_path', required=True, type=str, help="predicted file or folder path")
282+
parser.add_argument('--parallel_num', required=False, type=int, default=32, help="parallelism, default value is 32")
283+
return parser.parse_args()
284+
285+
286+
def custom_islink(path):
287+
"""Remove ending path separators before checking soft links.
288+
289+
e.g. /xxx/ -> /xxx
290+
"""
291+
return os.path.islink(os.path.abspath(path))
292+
293+
294+
def check_directory_ok(pathname: str):
295+
safe_name = os.path.relpath(pathname)
296+
if not os.path.exists(pathname):
297+
raise ValueError(f'input path {safe_name} does not exist!')
298+
if custom_islink(pathname):
299+
raise ValueError(f'Error! {safe_name} cannot be a soft link!')
300+
if not os.path.isdir(pathname):
301+
raise NotADirectoryError(f'Error! Please check if {safe_name} is a dir.')
302+
if not os.access(pathname, mode=os.R_OK):
303+
raise ValueError(f'Error! Please check if {safe_name} is readable.')
304+
if not os.listdir(pathname):
305+
raise ValueError(f'input path {safe_name} should contain at least one file!')
306+
307+
308+
if __name__ == '__main__':
309+
logging.getLogger().setLevel(logging.INFO)
310+
args = parse_args()
311+
gt_path = args.gt_path
312+
pred_path = args.pred_path
313+
parallel_num = args.parallel_num
314+
315+
check_directory_ok(gt_path)
316+
317+
if os.path.isfile(pred_path):
318+
pred_path = transform_pred_to_dir(pred_path)
319+
check_directory_ok(pred_path)
320+
321+
result = eval_det(detection_eval, process_box_2019, gt_path, pred_path, parallel_num)
322+
logging.info(f'det: {result}')
323+
324+
result = eval_rec(recognition_eval, gt_path, pred_path, parallel_num)
325+
logging.info(f'rec: {result}')

deploy/eval_utils/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
shapely>=1.8.2
2+
numpy>=1.22.4
3+
joblib>=1.1.0
4+
tqdm>=4.64.0

0 commit comments

Comments
 (0)