Skip to content

Commit 8e0b7ba

Browse files
authored
optim performance and cls/rec batchsize for inference (#480)
1 parent 05018ed commit 8e0b7ba

26 files changed

+355
-298
lines changed

deploy/py_infer/src/data_process/utils/constants.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

deploy/py_infer/src/infer/infer_base.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import argparse
2+
import gc
23
from abc import ABCMeta, abstractmethod
34

5+
from ..core import Model
6+
47

58
class InferBase(metaclass=ABCMeta):
69
"""
@@ -10,10 +13,47 @@ class InferBase(metaclass=ABCMeta):
1013
def __init__(self, args: argparse.Namespace, **kwargs):
1114
super().__init__()
1215
self.args = args
16+
1317
self.model = None
18+
self._bs_list = []
19+
self._hw_list = []
20+
21+
def init(self, *, preprocess=True, model=True, postprocess=True):
22+
if preprocess or model:
23+
self._init_model()
24+
25+
if preprocess:
26+
self._init_preprocess()
27+
28+
if postprocess:
29+
self._init_postprocess()
30+
31+
if not model:
32+
self.free_model()
33+
34+
if model:
35+
if isinstance(self.model, dict):
36+
for _model in self.model.values():
37+
_model.warmup()
38+
elif isinstance(self.model, Model):
39+
self.model.warmup()
40+
else:
41+
pass
42+
43+
@abstractmethod
44+
def _init_preprocess(self):
45+
pass
1446

1547
@abstractmethod
16-
def init(self, **kwargs):
48+
def _init_model(self):
49+
pass
50+
51+
@abstractmethod
52+
def _init_postprocess(self):
53+
pass
54+
55+
@abstractmethod
56+
def get_params(self):
1757
pass
1858

1959
@abstractmethod
@@ -34,12 +74,13 @@ def postprocess(self, *args, **kwargs):
3474

3575
def free_model(self):
3676
if hasattr(self, "model") and self.model:
37-
if isinstance(self.model, (tuple, list)):
38-
for model in self.model:
39-
del model
40-
else:
41-
del self.model
42-
self.model = None
77+
if isinstance(self.model, dict):
78+
self.model.clear()
79+
80+
del self.model
81+
gc.collect()
82+
83+
self.model = None
4384

4485
def __del__(self):
4586
self.free_model()

deploy/py_infer/src/infer/infer_cls.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
class TextClassifier(InferBase):
1212
def __init__(self, args):
1313
super(TextClassifier, self).__init__(args)
14-
self._bs_list = []
1514

16-
def init(self, warmup=False):
15+
def _init_preprocess(self):
16+
preprocess_ops = build_preprocess(self.args.cls_config_path)
17+
self.preprocess_ops = functools.partial(preprocess_ops, target_size=self._hw_list[0])
18+
19+
def _init_model(self):
1720
self.model = Model(
1821
backend=self.args.backend, model_path=self.args.cls_model_path, device_id=self.args.device_id
1922
)
@@ -28,12 +31,13 @@ def init(self, warmup=False):
2831
batchsize, _, model_height, model_width = shape_info
2932
self._bs_list = [batchsize]
3033

31-
preprocess_ops = build_preprocess(self.args.cls_config_path)
32-
self.preprocess_ops = functools.partial(preprocess_ops, target_size=(model_height, model_width))
34+
self._hw_list = [(model_height, model_width)]
35+
36+
def _init_postprocess(self):
3337
self.postprocess_ops = build_postprocess(self.args.cls_config_path)
3438

35-
if warmup:
36-
self.model.warmup()
39+
def get_params(self):
40+
return {"cls_batch_num": self._bs_list}
3741

3842
def __call__(self, image: Union[np.ndarray, List[np.ndarray]]):
3943
images = [image] if isinstance(image, np.ndarray) else image

deploy/py_infer/src/infer/infer_det.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
class TextDetector(InferBase):
1111
def __init__(self, args):
1212
super(TextDetector, self).__init__(args)
13-
self._hw_list = []
1413

15-
def init(self, warmup=False):
14+
def _init_preprocess(self):
15+
self.preprocess_ops = build_preprocess(self.args.det_config_path)
16+
17+
def _init_model(self):
1618
self.model = Model(
1719
backend=self.args.backend,
1820
model_path=self.args.det_model_path,
@@ -34,11 +36,13 @@ def init(self, warmup=False):
3436
raise ValueError("Input batch size must be 1 for detection model.")
3537

3638
self._hw_list = hw_list
37-
self.preprocess_ops = build_preprocess(self.args.det_config_path)
39+
self._bs_list = [batchsize]
40+
41+
def _init_postprocess(self):
3842
self.postprocess_ops = build_postprocess(self.args.det_config_path, rescale_fields=["polys"])
3943

40-
if warmup:
41-
self.model.warmup()
44+
def get_params(self):
45+
return {"det_batch_num": self._bs_list}
4246

4347
def __call__(self, image: np.ndarray):
4448
data = self.preprocess(image)
@@ -54,9 +58,7 @@ def preprocess(self, image: np.ndarray) -> Dict:
5458
def model_infer(self, data: Dict) -> List[np.ndarray]:
5559
return self.model.infer([data["image"]]) # model infer for single input
5660

57-
def postprocess(self, pred, shape_list: np.ndarray) -> np.ndarray:
61+
def postprocess(self, pred, shape_list: np.ndarray) -> List[np.ndarray]:
5862
polys = self.postprocess_ops(tuple(pred), shape_list)["polys"][0] # {'polys': [img0_polys, ...], ...}
59-
polys = np.array(polys)
60-
# polys.shape may be (0,), (polys_num, points_num, 2), (1, polys_num, points_num, 2)
61-
polys_shape = (-1, *polys.shape[-2:]) if polys.size != 0 else (0, 0, 2)
62-
return polys.reshape(*polys_shape) # (polys_num, points_num, 2), because bs=1
63+
polys = [np.array(x) for x in polys]
64+
return polys # [poly(points_num, 2), ...], bs=1

deploy/py_infer/src/infer/infer_rec.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import math
22
import os
3-
from collections import defaultdict
43
from typing import Dict, List, Tuple, Union
54

65
import numpy as np
@@ -14,11 +13,7 @@
1413
class TextRecognizer(InferBase):
1514
def __init__(self, args):
1615
super(TextRecognizer, self).__init__(args)
17-
18-
self._hw_list = []
19-
self._bs_list = []
20-
21-
self.model = defaultdict()
16+
self.model: Dict[int, Model] = {}
2217
self.shape_type = None
2318

2419
def __get_shape_for_single_model(self, filename):
@@ -65,7 +60,10 @@ def __get_resized_hw(self, image_list):
6560

6661
return max_h, max_w
6762

68-
def init(self, warmup=False):
63+
def _init_preprocess(self):
64+
self.preprocess_ops = build_preprocess(self.args.rec_config_path)
65+
66+
def _init_model(self):
6967
model_path = self.args.rec_model_path
7068

7169
if os.path.isfile(model_path):
@@ -91,13 +89,12 @@ def init(self, warmup=False):
9189

9290
self._bs_list.sort()
9391

94-
self.preprocess_ops = build_preprocess(self.args.rec_config_path)
92+
def _init_postprocess(self):
9593
params = {"character_dict_path": self.args.character_dict_path}
9694
self.postprocess_ops = build_postprocess(self.args.rec_config_path, **params)
9795

98-
if warmup:
99-
for model in self.model.values():
100-
model.warmup()
96+
def get_params(self):
97+
return {"rec_batch_num": self._bs_list}
10198

10299
def __call__(self, image: Union[np.ndarray, List[np.ndarray]]):
103100
images = [image] if isinstance(image, np.ndarray) else image
@@ -129,10 +126,10 @@ def preprocess(self, image: List[np.ndarray]) -> Tuple[List[int], List[Dict]]:
129126
return split_bs, split_data
130127

131128
def model_infer(self, data: Dict) -> List[np.ndarray]:
132-
input = data["image"]
133-
bs, *_ = input.shape
129+
input_data = data["image"]
130+
bs, *_ = input_data.shape
134131
n = bs if bs in self._bs_list else -1
135-
return self.model[n].infer([input])
132+
return self.model[n].infer([input_data])
136133

137134
def postprocess(self, pred, batch=None):
138135
pred = gear_utils.get_batch_from_padding(pred, batch)

deploy/py_infer/src/infer_args.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,6 @@ def update_task_info(args):
122122
task_order = (det, cls, rec)
123123
if task_order in task_map:
124124
setattr(args, "task_type", task_map[task_order])
125-
setattr(args, "save_vis_det_save_dir", bool(args.vis_det_save_dir))
126-
setattr(args, "save_vis_pipeline_save_dir", bool(args.vis_pipeline_save_dir))
127-
setattr(args, "save_crop_res_dir", bool(args.crop_save_dir))
128125
else:
129126
unsupported_task_map = {
130127
(False, False, False): "empty",
@@ -163,9 +160,6 @@ def check_and_update_args(args):
163160
if args.rec_model_path and not args.rec_model_name_or_config:
164161
raise ValueError("rec_model_name_or_config can't be emtpy when set rec_model_path for recognition.")
165162

166-
if args.parallel_num < 1 or args.parallel_num > 4:
167-
raise ValueError(f"parallel_num must between [1,4], current: {args.parallel_num}.")
168-
169163
if args.crop_save_dir and args.task_type not in (TaskType.DET_REC, TaskType.DET_CLS_REC):
170164
raise ValueError("det_model_path and rec_model_path can't be empty when set crop_save_dir.")
171165

@@ -219,11 +213,11 @@ def check_and_update_args(args):
219213
def init_save_dir(args):
220214
if args.res_save_dir:
221215
save_path_init(args.res_save_dir)
222-
if args.save_crop_res_dir:
216+
if args.crop_save_dir:
223217
save_path_init(args.crop_save_dir)
224-
if args.save_vis_pipeline_save_dir:
218+
if args.vis_pipeline_save_dir:
225219
save_path_init(args.vis_pipeline_save_dir)
226-
if args.save_vis_det_save_dir:
220+
if args.vis_det_save_dir:
227221
save_path_init(args.vis_det_save_dir)
228222
if args.save_log_dir:
229223
save_path_init(args.save_log_dir, exist_ok=True)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .message_data import ProfilingData, StartSign, StopSign
1+
from .message_data import ProfilingData, StopSign
22
from .module_data import ModuleConnectDesc, ModuleDesc, ModuleInitArgs
33
from .process_data import ProcessData, StopData

deploy/py_infer/src/parallel/datatype/message_data.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
from dataclasses import dataclass
22

33

4-
@dataclass
5-
class StartSign:
6-
start: bool = True
7-
8-
94
@dataclass
105
class StopSign:
116
stop: bool = True

deploy/py_infer/src/parallel/datatype/process_data.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,21 @@
66

77
@dataclass
88
class ProcessData:
9-
# infer_test info
10-
sub_image_total: int = 0
11-
image_total: int = 0
12-
infer_result: list = field(default_factory=lambda: [])
9+
# skip each compute node
1310
skip: bool = False
11+
# prediction results of each image
12+
infer_result: list = field(default_factory=lambda: [])
1413

1514
# image basic info
16-
image_path: str = ""
17-
image_name: str = ""
18-
image_id: int = ""
19-
frame: np.ndarray = None
15+
image_path: List[str] = field(default_factory=lambda: [])
16+
frame: List[np.ndarray] = field(default_factory=lambda: [])
2017

18+
# sub image of detection box, for det (+ cls) + rec
19+
sub_image_total: int = 0 # len(sub_image_list_0) + len(sub_image_list_1) + ...
2120
sub_image_list: list = field(default_factory=lambda: [])
22-
sub_image_size: int = 0
21+
sub_image_size: int = 0 # len of sub_image_list
22+
23+
# data for preprocess -> infer -> postprocess
2324
data: Union[np.ndarray, List[np.ndarray], Dict] = None
2425

2526

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .module_base import ModuleBase
22
from .module_manager import ModuleManager
3+
from .pipeline_manager import ParallelPipelineManager

0 commit comments

Comments
 (0)