
Commit b8092f2

Add RobustScanner rec model (#444)
1 parent e1c02de commit b8092f2

File tree

18 files changed: +2190, -13 lines changed


configs/rec/crnn/README_CN.md

Lines changed: 1 addition & 1 deletion
@@ -287,7 +287,7 @@ eval:
 ```
 
 **Note:**
-- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly adjust the learning rate to the new global batch size.
+- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly scale the learning rate according to the new global batch size.
 
 
 ### 3.2 Model Training
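
The reworded note states the linear scaling rule in words; a minimal sketch of the arithmetic follows, with hypothetical reference values (a recipe tuned for 8 devices at batch_size 64), not numbers taken from any shipped config:

```python
# Linear LR scaling as described in the README note above. All reference
# values here are hypothetical examples, not taken from a MindOCR config.
def scale_lr(base_lr: float, base_global_bs: int, batch_size: int, num_devices: int) -> float:
    """Scale the learning rate linearly with the new global batch size."""
    new_global_bs = batch_size * num_devices
    return base_lr * new_global_bs / base_global_bs


# A recipe tuned for 8 devices x batch_size 64 (global batch 512), reproduced
# on 4 devices: either raise batch_size to 128 to keep the global batch at 512,
# or keep batch_size=64 and scale the LR: 0.001 * 256 / 512 = 0.0005.
print(scale_lr(base_lr=0.001, base_global_bs=512, batch_size=64, num_devices=4))
```

Keeping the global batch size fixed and scaling the learning rate are alternatives; the note recommends doing one or the other, not both.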

configs/rec/master/README_CN.md

Lines changed: 1 addition & 1 deletion
@@ -297,7 +297,7 @@ eval:
 ```
 
 **Note:**
-- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly adjust the learning rate to the new global batch size.
+- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly scale the learning rate according to the new global batch size.
 
 
 ### 3.2 Model Training

configs/rec/rare/README_CN.md

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ eval:
 ```
 
 **Note:**
-- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly adjust the learning rate to the new global batch size.
+- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly scale the learning rate according to the new global batch size.
 
 
 ### 3.2 Model Training

configs/rec/robustscanner/README.md

Lines changed: 389 additions & 0 deletions
Large diffs are not rendered by default.

configs/rec/robustscanner/README_CN.md

Lines changed: 394 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@

system:
  mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
  distribute: True
  amp_level: 'O0'
  seed: 42
  log_interval: 100
  val_while_train: True
  drop_overflow_update: False

common:
  character_dict_path: &character_dict_path mindocr/utils/dict/en_dict90.txt
  max_text_len: &max_text_len 40
  use_space_char: &use_space_char False
  batch_size: &batch_size 64

model:
  type: rec
  transform: null
  backbone:
    name: rec_resnet31
    pretrained: False
  head:
    name: RobustScannerHead
    out_channels: 93 # 90 + unknown + start + padding
    enc_outchannles: 128
    hybrid_dec_rnn_layers: 2
    hybrid_dec_dropout: 0.
    position_dec_rnn_layers: 2
    start_idx: 91
    mask: True
    padding_idx: 92
    encode_value: False
    max_text_len: *max_text_len

postprocess:
  name: SARLabelDecode
  character_dict_path: *character_dict_path
  use_space_char: *use_space_char
  rm_symbol: True

metric:
  name: RecMetric
  main_indicator: acc
  character_dict_path: *character_dict_path
  ignore_space: True
  print_flag: False

loss:
  name: SARLoss
  ignore_index: 92

scheduler:
  scheduler: multi_step_decay
  milestones: [6, 8]
  decay_rate: 0.1
  lr: 0.001
  num_epochs: 10
  warmup_epochs: 0

optimizer:
  opt: adamW
  beta1: 0.9
  beta2: 0.999

loss_scaler:
  type: static
  loss_scale: 512

train:
  ema: True
  ckpt_save_dir: './tmp_rec'
  dataset_sink_mode: False
  dataset:
    type: LMDBDataset
    dataset_root: path/to/data/ # Optional, if set, dataset_root will be used as a prefix for data_dir
    data_dir: training/
    # label_files: # not required when using LMDBDataset
    sample_ratio: 1.0
    shuffle: True
    random_choice_if_none: True # Randomly choose another sample if the result returned from the data transform is None
    transform_pipeline:
      - DecodeImage:
          img_mode: BGR
          to_float32: False
      - SARLabelEncode: # Class handling label
          max_text_len: *max_text_len
          character_dict_path: *character_dict_path
          use_space_char: *use_space_char
          lower: True
      - RobustScannerRecResizeImg:
          image_shape: [3, 48, 48, 160] # h: 48, w: [48, 160]
          width_downsample_ratio: 0.25
          max_text_len: *max_text_len
    output_columns: ['image', 'label', 'valid_width_mask', 'word_positions']
    net_input_column_index: [0, 1, 2, 3] # input indices for network forward func in output_columns
    label_column_index: [1] # input indices marked as label
    # keys_for_loss: 4 # num labels for loss func

  loader:
    shuffle: True # TODO: tbc
    batch_size: *batch_size
    drop_remainder: True
    max_rowsize: 12
    num_workers: 8

eval:
  ckpt_load_path: ./tmp_rec/best.ckpt
  dataset_sink_mode: False
  dataset:
    type: LMDBDataset
    dataset_root: path/to/data/
    data_dir: evaluation/
    # label_files: # not required when using LMDBDataset
    sample_ratio: 1.0
    shuffle: False
    transform_pipeline:
      - DecodeImage:
          img_mode: BGR
          to_float32: False
      - SARLabelEncode: # Class handling label
          max_text_len: *max_text_len
          # character_dict_path: *character_dict_path
          use_space_char: *use_space_char
          is_training: False
          lower: True
      - RobustScannerRecResizeImg:
          image_shape: [3, 48, 48, 160]
          width_downsample_ratio: 0.25
          max_text_len: *max_text_len
    # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debugging/visualization
    output_columns: ['image', 'valid_width_mask', 'word_positions', 'text_padded', 'text_length']
    net_input_column_index: [0, 1, 2] # input indices for network forward func in output_columns
    label_column_index: [3, 4]

  loader:
    shuffle: False # TODO: tbc
    batch_size: 64
    drop_remainder: True
    max_rowsize: 12
    num_workers: 8
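
The head's out_channels, start_idx, and padding_idx above follow from the dictionary size plus the special tokens that SARLabelEncode (added in rec_transforms.py further down) appends. A small sketch of that arithmetic, assuming en_dict90.txt holds exactly 90 characters and use_space_char stays False as in this config:

```python
# How out_channels: 93, start_idx: 91 and padding_idx: 92 line up with
# SARLabelEncode.add_special_char. The placeholder charset stands in for the
# 90 entries of en_dict90.txt (assumption: the dict has exactly 90 characters).
charset = [f"char_{i}" for i in range(90)]

charset += ["<UKN>"]      # unknown_idx = 90
unknown_idx = len(charset) - 1
charset += ["<BOS/EOS>"]  # start_idx = end_idx = 91
start_idx = len(charset) - 1
charset += ["<PAD>"]      # padding_idx = 92, also SARLoss ignore_index in this config
padding_idx = len(charset) - 1

assert (unknown_idx, start_idx, padding_idx) == (90, 91, 92)
assert len(charset) == 93  # matches head.out_channels
```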

mindocr/data/rec_lmdb_dataset.py

Lines changed: 19 additions & 1 deletion
@@ -27,6 +27,7 @@ class LMDBDataset(BaseDataset):
             if None, all data keys will be used for return.
         filter_max_len (bool): Filter the records where the label is longer than the `max_text_len`.
         max_text_len (int): The maximum text length the dataloader expected.
+        random_choice_if_none (bool): Randomly choose another sample if the result returned from the data transform is None.
 
     Returns:
         data (tuple): Depending on the transform pipeline, __get_item__ returns a tuple for the specified data item.
@@ -54,11 +55,13 @@ def __init__(
         output_columns: Optional[List[str]] = None,
         filter_max_len: bool = False,
         max_text_len: Optional[int] = None,
+        random_choice_if_none: bool = False,
         **kwargs: Any,
     ):
         self.data_dir = data_dir
         self.filter_max_len = filter_max_len
         self.max_text_len = max_text_len
+        self.random_choice_if_none = random_choice_if_none
 
         shuffle = shuffle if shuffle is not None else is_train
 
@@ -197,10 +200,25 @@ def __getitem__(self, idx):
         lmdb_idx, file_idx = self.data_idx_order_list[idx]
         sample_info = self.get_lmdb_sample_info(self.lmdb_sets[int(lmdb_idx)]["txn"], int(file_idx))
 
+        if sample_info is None and self.random_choice_if_none:
+            _logger.warning("sample_info is None, randomly choose another data.")
+            random_idx = np.random.randint(self.__len__())
+            return self.__getitem__(random_idx)
+
         data = {"img_lmdb": sample_info[0], "label": sample_info[1]}
 
         # perform transformation on data
-        data = run_transforms(data, transforms=self.transforms)
+        try:
+            data = run_transforms(data, transforms=self.transforms)
+        except Exception as e:
+            if self.random_choice_if_none:
+                _logger.warning("data is None after transforms, randomly choose another data.")
+                random_idx = np.random.randint(self.__len__())
+                return self.__getitem__(random_idx)
+            else:
+                _logger.warning(f"Error occurred during preprocess.\n {e}")
+                raise e
+
         output_tuple = tuple(data[k] for k in self.output_columns)
 
         return output_tuple
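
The random_choice_if_none path added above re-samples whenever a record cannot be processed, for example when SARLabelEncode returns None for an over-long label and run_transforms raises. A self-contained toy sketch of that control flow; the dataset and transform below are stand-ins, not the real LMDBDataset API:

```python
# Toy illustration of the "re-sample on bad record" pattern used above.
import random


def run_transforms(data, transforms):
    for t in transforms:
        data = t(data)
        if data is None:
            raise RuntimeError(f"Empty result is returned from transform `{t}`")
    return data


class ToyDataset:
    def __init__(self, labels, transforms, random_choice_if_none=True):
        self.labels, self.transforms = labels, transforms
        self.random_choice_if_none = random_choice_if_none

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        data = {"label": self.labels[idx]}
        try:
            return run_transforms(data, self.transforms)
        except Exception:
            if self.random_choice_if_none:
                return self.__getitem__(random.randrange(len(self)))  # re-sample another index
            raise


# A "label encoder" stand-in that rejects labels longer than 5 characters.
reject_long = lambda d: None if len(d["label"]) > 5 else d
ds = ToyDataset(["short", "a-very-long-label", "ok"], transforms=[reject_long])
print(ds[1])  # falls back to a random valid sample instead of raising
```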

mindocr/data/transforms/rec_transforms.py

Lines changed: 154 additions & 0 deletions
@@ -19,6 +19,8 @@
     "SVTRRecResizeImg",
     "Rotate90IfVertical",
     "ClsLabelEncode",
+    "SARLabelEncode",
+    "RobustScannerRecResizeImg",
 ]
 _logger = logging.getLogger(__name__)

@@ -665,3 +667,155 @@ def __call__(self, data):
        data["label"] = label

        return data


class SARLabelEncode(object):
    """Convert between text-label and text-index"""

    def __init__(self, max_text_len, character_dict_path=None, use_space_char=False, lower=False, is_training=True):
        self.max_text_len = max_text_len
        self.beg_str = "sos"
        self.end_str = "eos"
        self.lower = lower
        self.is_training = is_training

        if character_dict_path is None:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
            self.lower = True
            if self.is_training:
                _logger.warning("The character_dict_path is None, model can only recognize number and lower letters")
        else:
            self.character_str = []
            with open(character_dict_path, "rb") as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode("utf-8").strip("\n").strip("\r\n")
                    self.character_str.append(line)
            if use_space_char:
                self.character_str.append(" ")
            dict_character = list(self.character_str)
        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character

    def encode(self, text):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]

        output:
            text: concatenated text index for CTCLoss.
                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        if len(text) == 0 or len(text) > self.max_text_len:
            return None
        if self.lower:
            text = text.lower()
        text_list = []
        for char in text:
            if char not in self.dict:
                continue
            text_list.append(self.dict[char])
        if len(text_list) == 0:
            return None
        return text_list

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1

        return dict_character

    def __call__(self, data):
        text = data["label"]
        text_str = text
        text = self.encode(text)
        if text is None:
            return None
        if len(text) >= self.max_text_len - 1:
            return None
        data["text_length"] = np.array(len(text))
        target = [self.start_idx] + text + [self.end_idx]
        padded_text = [self.padding_idx for _ in range(self.max_text_len)]

        padded_text[: len(target)] = target
        data["label"] = np.array(padded_text)
        data["text_padded"] = text_str + " " * (self.max_text_len - len(text_str))

        return data

    def get_ignored_tokens(self):
        return [self.padding_idx]

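A quick usage sketch for SARLabelEncode as added above. The import path is assumed from the file location in this diff; with character_dict_path=None the built-in 36-character charset is used, so the special indices work out to unknown_idx=36, start/end_idx=37, padding_idx=38:

```python
from mindocr.data.transforms.rec_transforms import SARLabelEncode  # assumed import path

encoder = SARLabelEncode(max_text_len=10)  # logs a warning about the default charset
out = encoder({"label": "ab1"})
print(out["text_length"])        # 3
print(out["label"])              # [37 10 11  1 37 38 38 38 38 38] -> <BOS> a b 1 <EOS> <PAD>*5
print(repr(out["text_padded"]))  # 'ab1       ' (original text padded with spaces to max_text_len)
```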
class RobustScannerRecResizeImg(object):
    def __init__(self, image_shape, max_text_len, width_downsample_ratio=0.25, **kwargs):
        self.image_shape = image_shape
        self.width_downsample_ratio = width_downsample_ratio
        self.max_text_len = max_text_len

    def __call__(self, data):
        img = data["image"]
        norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
            img, self.image_shape, self.width_downsample_ratio
        )
        valid_ratio = np.array(valid_ratio, dtype=np.float32)
        width_downsampled = int(self.image_shape[-1] * self.width_downsample_ratio)
        valid_width_mask = np.full([1, width_downsampled], 1)
        valid_width = min(width_downsampled, int(width_downsampled * valid_ratio + 0.5))
        valid_width_mask[:, valid_width:] = 0
        word_positons = np.array(range(0, self.max_text_len)).astype("int64")
        data["image"] = norm_img
        data["resized_shape"] = resize_shape
        data["pad_shape"] = pad_shape
        data["valid_ratio"] = valid_ratio
        data["valid_width_mask"] = valid_width_mask
        data["word_positions"] = word_positons
        return data


def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
    imgC, imgH, imgW_min, imgW_max = image_shape
    h = img.shape[0]
    w = img.shape[1]
    valid_ratio = 1.0
    # make sure new_width is an integral multiple of width_divisor.
    width_divisor = int(1 / width_downsample_ratio)
    # resize
    ratio = w / float(h)
    resize_w = math.ceil(imgH * ratio)
    if resize_w % width_divisor != 0:
        resize_w = round(resize_w / width_divisor) * width_divisor
    if imgW_min is not None:
        resize_w = max(imgW_min, resize_w)
    if imgW_max is not None:
        valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
        resize_w = min(imgW_max, resize_w)
    resized_image = cv2.resize(img, (resize_w, imgH))
    resized_image = resized_image.astype("float32")
    # norm
    if image_shape[0] == 1:
        resized_image = resized_image / 255
        resized_image = resized_image[np.newaxis, :]
    else:
        resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    resize_shape = resized_image.shape
    padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
    padding_im[:, :, 0:resize_w] = resized_image
    pad_shape = padding_im.shape

    return padding_im, resize_shape, pad_shape, valid_ratio
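
A worked example for the resize transform above, assuming the config's image_shape [3, 48, 48, 160] and a hypothetical 48x100 crop (the import path is again assumed from the file location). The width resizes to 100 after rounding to a multiple of width_divisor=4, so valid_ratio = 100/160 = 0.625 and 25 of the 40 downsampled width positions stay valid:

```python
import numpy as np

from mindocr.data.transforms.rec_transforms import RobustScannerRecResizeImg  # assumed import path

resize_op = RobustScannerRecResizeImg(image_shape=[3, 48, 48, 160], max_text_len=40)
dummy = {"image": np.random.randint(0, 255, (48, 100, 3), dtype=np.uint8)}  # hypothetical 48x100 BGR crop
out = resize_op(dummy)

print(out["image"].shape)             # (3, 48, 160) -- padded with -1 up to imgW_max
print(out["valid_ratio"])             # 0.625
print(out["valid_width_mask"].sum())  # 25 of 40 downsampled positions marked valid
print(out["word_positions"][:5])      # [0 1 2 3 4]
```

The mask presumably lets the head (configured with mask: True) ignore the padded right-hand part of the feature map; only the first valid_width positions carry real image content.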

mindocr/data/transforms/transforms_factory.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def run_transforms(data, transforms=None, verbose=False):
             )
 
         if data is None:
-            raise RuntimeError("Empty result is returned from transform `{transform}`")
+            raise RuntimeError(f"Empty result is returned from transform `{transform}`")
     return data
 
 