Skip to content

Commit 96fff90

Browse files
safeandnewYHHaoyangLeeSongyuanwei
authored
add ABINet (#385)
* Add ABINet 0616 Add ABINet delete abinet_dataset Add ABINet 0616 delete abinet_dataset Add ABINet 0617 Add ABINet Transformer has been Simplified Add ABINet 0618 Add ABINet 0619 debug Add ABINet 0619 Add ABINet 0619 Add ABINet 0620 Add ABINet 0620 add README.md for ABINet Add ABINet 0620 add evaluation Add ABINet 0620 Add ABINet 0620 add README.md for ABINet * Add ABINet Modify README.md * Add ABINet 0616 Add ABINet delete abinet_dataset Add ABINet 0616 delete abinet_dataset Add ABINet 0617 Add ABINet Transformer has been Simplified Add ABINet 0618 Add ABINet 0619 debug Add ABINet 0619 Add ABINet 0619 Add ABINet 0620 Add ABINet 0620 add README.md for ABINet Add ABINet 0620 add evaluation Add ABINet 0620 Add ABINet 0620 add README.md for ABINet * Add ABINet Modify README.md * Add ABINet 0621 * Add ABINet 0621 * Add ABINet 0622 * Add ABINet 0623 * Add ABINet 0624 * Add ABINet 0624 add README_CN.md * Add ABINet 0626 * Add ABINet 0626 * Add ABINet 0626 * Add ABINet 0626 * Add ABINet 0628 * Add ABINet 0629 * Add ABINet 0629 * Add ABINet 0630 * Add ABINet 0630 * Add ABINet 0630 * Update ABINet README.md * Update ABINet README_CN.md * Update ABINet README.md * Update ABINet README_CN.md * Update README.md * Update README_CN.md * Update rec_abinet_transforms.py --------- Co-authored-by: abinet <zhangyh0123.mail.ustc.edu.cn> Co-authored-by: HaoyangLI <417493727@qq.com> Co-authored-by: Songyuanwei <52945530+Songyuanwei@users.noreply.github.com>
1 parent 041d061 commit 96fff90

File tree

16 files changed

+3397
-0
lines changed

16 files changed

+3397
-0
lines changed

configs/rec/abinet/README.md

Lines changed: 313 additions & 0 deletions
Large diffs are not rendered by default.

configs/rec/abinet/README_CN.md

Lines changed: 327 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
system:
2+
mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
3+
distribute: True
4+
amp_level: 'O0'
5+
seed: 42
6+
log_interval: 100
7+
val_while_train: False
8+
drop_overflow_update: False
9+
10+
common:
11+
character_dict_path: &character_dict_path
12+
num_classes: &num_classes 37
13+
max_text_len: &max_text_len 25
14+
infer_mode: &infer_mode False
15+
use_space_char: &use_space_char False
16+
batch_size: &batch_size 96
17+
18+
model:
19+
type: rec
20+
pretrained : "./tmp_rec/pretrain.ckpt"
21+
transform: null
22+
backbone:
23+
name: abinet_backbone
24+
pretrained: False
25+
batchsize: *batch_size
26+
head:
27+
name: ABINetHead
28+
batchsize: *batch_size
29+
30+
postprocess:
31+
name: ABINetLabelDecode
32+
33+
metric:
34+
name: RecMetric
35+
main_indicator: acc
36+
character_dict_path: *character_dict_path
37+
ignore_space: True
38+
print_flag: False
39+
filter_ood: False
40+
41+
loss:
42+
name: ABINetLoss
43+
44+
45+
scheduler:
46+
scheduler: step_decay
47+
decay_rate: 0.1
48+
decay_epochs: 6
49+
warmup_epochs: 0
50+
lr: 0.0001
51+
num_epochs : 10
52+
53+
54+
optimizer:
55+
opt: adam
56+
57+
58+
train:
59+
clip_grad: True
60+
clip_norm: 20.0
61+
ckpt_save_dir: './tmp_rec'
62+
dataset_sink_mode: False
63+
dataset:
64+
type: LMDBDataset
65+
dataset_root: path/to/data_lmdb_release/
66+
data_dir: train/
67+
# label_files: # not required when using LMDBDataset
68+
sample_ratio: 1.0
69+
shuffle: True
70+
transform_pipeline:
71+
- ABINetTransforms:
72+
- ABINetRecAug:
73+
- NormalizeImage:
74+
is_hwc: False
75+
mean: [0.485, 0.456, 0.406]
76+
std: [0.485, 0.456, 0.406]
77+
# # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
78+
output_columns: ['image','label','length','label_for_mask'] #'img_path']
79+
80+
loader:
81+
shuffle: True # TODO: tbc
82+
batch_size: *batch_size
83+
drop_remainder: True
84+
max_rowsize: 128
85+
num_workers: 20
86+
87+
eval:
88+
ckpt_load_path: ./tmp_rec/best.ckpt
89+
dataset_sink_mode: False
90+
dataset:
91+
type: LMDBDataset
92+
dataset_root: path/to/data_lmdb_release/
93+
data_dir: evaluation/
94+
# label_files: # not required when using LMDBDataset
95+
sample_ratio: 1.0
96+
shuffle: False
97+
transform_pipeline:
98+
- ABINetEvalTransforms:
99+
- ABINetEval:
100+
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
101+
output_columns: ['image','label','length','label_for_mask'] # TODO return text string padding w/ fixed length, and a scalar to indicate the length
102+
net_input_column_index: [0] # input indices for network forward func in output_columns
103+
label_column_index: [1, 2] # input indices marked as label
104+
105+
loader:
106+
shuffle: False # TODO: tbc
107+
batch_size: *batch_size
108+
drop_remainder: False
109+
max_rowsize: 128
110+
num_workers: 8
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
"""
2+
transform for text recognition tasks.
3+
"""
4+
import copy
5+
import logging
6+
import random
7+
import re
8+
import warnings
9+
10+
import cv2
11+
import numpy as np
12+
import PIL
13+
import six
14+
from PIL import Image
15+
16+
import mindspore.dataset as ds
17+
18+
from ...models.utils.abinet_layers import CharsetMapper, onehot
19+
from .svtr_transform import (
20+
CVColorJitter,
21+
CVGaussianNoise,
22+
CVMotionBlur,
23+
CVRandomAffine,
24+
CVRandomPerspective,
25+
CVRandomRotation,
26+
CVRescale,
27+
)
28+
29+
_logger = logging.getLogger(__name__)
30+
__all__ = ["ABINetTransforms", "ABINetRecAug", "ABINetEval", "ABINetEvalTransforms"]
31+
32+
33+
class ABINetTransforms(object):
    """Decode an LMDB sample and build the ABINet training targets.

    The text label is reduced to the characters ``[0-9a-zA-Z]`` and truncated
    to 25 characters, the raw image bytes are decoded into an RGB array, and
    the label is encoded with ``CharsetMapper.get_labels`` followed by
    ``onehot``.

    The transform takes no constructor arguments.
    """

    def __init__(
        self,
    ):
        # ABINet_Transforms
        self.case_sensitive = False
        self.charset = CharsetMapper(max_length=26)

    def __call__(self, data: dict):
        """Transform one sample.

        Args:
            data: mapping with the keys ``"img_lmdb"`` (encoded image bytes)
                and ``"label"`` (text label as ``str``).

        Returns:
            dict with the keys ``"image"`` (RGB ``numpy`` array), ``"label"``
            (output of ``onehot``), ``"length"`` (``len(text) + 1`` as a
            float) and ``"label_for_mask"`` (copy of the encoded label with
            the entry at index ``len(text)`` set to 1).

        Raises:
            Exception: whatever the image decoding raised. The error is
                logged and re-raised; previously it was swallowed and the
                method failed right afterwards with an unrelated
                ``UnboundLocalError`` because ``image`` was never assigned.
        """
        img_lmdb = data["img_lmdb"]
        label = data["label"]
        # Round-trip through UTF-8 bytes; the result is a plain ``str``.
        label = label.encode("utf-8")
        label = str(label, "utf-8")

        # Keep only ASCII letters and digits.
        label = re.sub("[^0-9a-zA-Z]+", "", label)
        if len(label) > 25 or len(label) <= 0:
            string_false2 = f"len(label) > 25 or len(label) <= 0: {label}, {len(label)}"
            _logger.warning(string_false2)
        label = label[:25]

        try:
            buf = six.BytesIO()
            buf.write(img_lmdb)
            buf.seek(0)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                image = PIL.Image.open(buf).convert("RGB")
        except Exception:
            string_false = f"Corrupted image is found: {label}, {len(label)}"
            _logger.warning(string_false)
            # Without a decoded image there is nothing to return; surface the
            # real cause instead of failing later on an unbound variable.
            raise

        # Undersized images are only reported, not rejected.
        if not _check_image(image, pixels=6):
            string_false1 = f"_check_image false: {label}, {len(label)}"
            _logger.warning(string_false1)

        image = np.array(image)

        text = label

        # One more than the number of characters.
        # NOTE(review): presumably the extra slot is the end-of-sequence token - confirm.
        length = float(len(text) + 1)

        label = self.charset.get_labels(text, case_sensitive=self.case_sensitive)
        label_for_mask = copy.deepcopy(label)
        label_for_mask[int(length - 1)] = 1
        label = onehot(label, self.charset.num_classes)
        data_dict = {"image": image, "label": label, "length": length, "label_for_mask": label_for_mask}
        return data_dict
84+
85+
86+
class ABINetRecAug(object):
    """Augment a training image, resize it to 128x32 and convert it with ``ToTensor``."""

    def __init__(self):
        # The augmentations are created in this fixed order because
        # CVGeometry and CVDeterioration draw random numbers while being
        # constructed.
        geometry = CVGeometry(
            degrees=45,
            translate=(0.0, 0.0),
            scale=(0.5, 2.0),
            shear=(45, 15),
            distortion=0.5,
            p=0.5,
        )
        deterioration = CVDeterioration(var=20, degrees=6, factor=4, p=0.25)
        color_jitter = CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)

        self.transforms = ds.transforms.Compose([geometry, deterioration, color_jitter])
        self.toTensor = ds.vision.ToTensor()
        # Target size passed to cv2.resize as (width, height).
        self.w, self.h = 128, 32

    def __call__(self, data):
        """Replace ``data["image"]`` with its augmented, resized tensor form and return ``data``."""
        augmented = self.transforms(data["image"])
        resized = cv2.resize(augmented, (self.w, self.h))
        data["image"] = self.toTensor(resized)
        return data
113+
114+
115+
def _check_image(x, pixels=6):
116+
if x.size[0] <= pixels or x.size[1] <= pixels:
117+
return False
118+
else:
119+
return True
120+
121+
122+
class ABINetEvalTransforms(object):
    """Decode an LMDB sample for ABINet evaluation.

    The text label is reduced to the characters ``[0-9a-zA-Z]`` and truncated
    to 25 characters, and the raw image bytes are decoded into an RGB array.
    Unlike the training transform, the label is returned as plain text.

    The transform takes no constructor arguments.
    """

    def __init__(
        self,
    ):
        # ABINet_Transforms
        self.case_sensitive = False
        self.charset = CharsetMapper(max_length=26)

    def __call__(self, data: dict):
        """Transform one sample.

        Args:
            data: mapping with the keys ``"img_lmdb"`` (encoded image bytes)
                and ``"label"`` (text label as ``str``).

        Returns:
            dict with the keys ``"image"`` (RGB ``numpy`` array), ``"label"``
            (the filtered text) and ``"length"`` (``len(text) + 1`` as a
            float).

        Raises:
            Exception: whatever the image decoding raised. The error is
                logged and re-raised; previously it was swallowed and the
                method failed right afterwards with an unrelated
                ``UnboundLocalError`` because ``image`` was never assigned.
        """
        img_lmdb = data["img_lmdb"]
        label = data["label"]
        # Round-trip through UTF-8 bytes; the result is a plain ``str``.
        label = label.encode("utf-8")
        label = str(label, "utf-8")

        # Keep only ASCII letters and digits.
        label = re.sub("[^0-9a-zA-Z]+", "", label)
        if len(label) > 25 or len(label) <= 0:
            # Message fixed: it used to start with "en(label)".
            string_false2 = f"len(label) > 25 or len(label) <= 0: {label}, {len(label)}"
            _logger.warning(string_false2)
        label = label[:25]

        try:
            buf = six.BytesIO()
            buf.write(img_lmdb)
            buf.seek(0)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                image = PIL.Image.open(buf).convert("RGB")
        except Exception:
            string_false = f"Corrupted image is found: {label}, {len(label)}"
            _logger.warning(string_false)
            # Without a decoded image there is nothing to return; surface the
            # real cause instead of failing later on an unbound variable.
            raise

        # Undersized images are only reported, not rejected.
        if not _check_image(image, pixels=6):
            string_false1 = f"_check_image false: {label}, {len(label)}"
            _logger.warning(string_false1)

        image = np.array(image)

        text = label
        # One more than the number of characters.
        # NOTE(review): presumably the extra slot is the end-of-sequence token - confirm.
        length = float(len(text) + 1)
        data_dict = {"image": image, "label": text, "length": length}
        return data_dict
167+
168+
169+
class ABINetEval(object):
    """Resize an evaluation image to 128x32, convert it with ``ToTensor`` and cast ``length`` to int."""

    def __init__(self):
        self.toTensor = ds.vision.ToTensor()
        # Target size passed to cv2.resize as (width, height).
        self.w, self.h = 128, 32

    def __call__(self, data):
        """Update ``data["image"]`` and ``data["length"]`` in place and return ``data``."""
        resized = cv2.resize(data["image"], (self.w, self.h))
        data["image"] = self.toTensor(resized)
        data["length"] = int(data["length"])
        return data
184+
185+
186+
class CVGeometry(object):
    """Apply one geometric augmentation with probability ``p``.

    The kind of augmentation (rotation, affine or perspective) is drawn once
    when the instance is created and reused for every call.
    """

    def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.0), shear=(45, 15), distortion=0.5, p=0.5):
        self.p = p
        draw = random.random()
        if draw >= 0.66:
            chosen = CVRandomPerspective(distortion=distortion)
        elif draw >= 0.33:
            chosen = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
        else:
            chosen = CVRandomRotation(degrees=degrees)
        self.transforms = chosen

    def __call__(self, img):
        """Return the transformed image as a PIL image, or ``img`` itself when the draw skips it."""
        should_apply = random.random() < self.p
        if not should_apply:
            return img
        as_array = np.array(img)
        return Image.fromarray(self.transforms(as_array))
203+
204+
205+
class CVDeterioration(object):
    """Apply a shuffled chain of noise, motion-blur and rescale steps with probability ``p``.

    A step is included only when its parameter (``var``, ``degrees`` or
    ``factor``) is not ``None``. The order of the steps is shuffled once when
    the instance is created.
    """

    def __init__(self, var, degrees, factor, p=0.5):
        self.p = p
        # (parameter, factory) pairs, evaluated in the listed order.
        candidates = (
            (var, lambda: CVGaussianNoise(var=var)),
            (degrees, lambda: CVMotionBlur(degrees=degrees)),
            (factor, lambda: CVRescale(factor=factor)),
        )
        steps = [build() for value, build in candidates if value is not None]

        random.shuffle(steps)

        self.transforms = ds.transforms.Compose(steps)

    def __call__(self, img):
        """Return the deteriorated image as a PIL image, or ``img`` itself when the draw skips it."""
        should_apply = random.random() < self.p
        if not should_apply:
            return img
        as_array = np.array(img)
        return Image.fromarray(self.transforms(as_array))

mindocr/data/transforms/transforms_factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .det_fce_transforms import *
1111
from .det_transforms import *
1212
from .general_transforms import *
13+
from .rec_abinet_transforms import *
1314
from .rec_transforms import *
1415
from .svtr_transform import *
1516

0 commit comments

Comments
 (0)