Commit 027e50b

visionlan gpu updates (#459)
1 parent 567c710 commit 027e50b

25 files changed (+2110, -44 lines)

README.md

Lines changed: 6 additions & 2 deletions
```diff
@@ -168,6 +168,7 @@ For more illustration and usage, please refer to the model training section in [
 - [x] [CRNN-Seq2Seq/RARE](configs/rec/rare/README.md) (CVPR'2016)
 - [x] [SVTR](configs/rec/svtr/README.md) (IJCAI'2022)
 - [x] [MASTER](configs/rec/master/README.md) (PR'2019)
+- [x] [VISIONLAN](configs/rec/visionlan/README.md) (ICCV'2021)
 - [ ] [ABINet](https://arxiv.org/abs/2103.06495) (CVPR'2021) [coming soon]

 </details>
@@ -212,10 +213,13 @@ We will include more datasets for training and evaluation. This list will be con
 ## Notes

 ### What is New
+- 2023/07/05
+1. Add new trained models
+- [VISIONLAN](configs/rec/visionlan) for text recognition
 - 2023/06/29
 1. Add new trained models
-- [FCENet](configs/det/facenet) for text detection
-- [MASTER](configs/det/facenet) for text recognition
+- [FCENet](configs/det/fcenet) for text detection
+- [MASTER](configs/rec/master) for text recognition
 - 2023/06/07
 1. Add new trained models
 - [PSENet](configs/det/psenet) for text detection
```

README_CN.md

Lines changed: 15 additions & 11 deletions
```diff
@@ -149,20 +149,21 @@ python tools/eval.py \
 <details open markdown>
 <summary>Text Detection</summary>

-- [x] [DBNet](configs/det/dbnet/README.md) (AAAI'2020)
-- [x] [DBNet++](configs/det/dbnet/README.md) (TPAMI'2022)
-- [x] [PSENet](configs/det/psenet/README.md) (CVPR'2019)
-- [x] [EAST](configs/det/east/README.md) (CVPR'2017)
-- [x] [FCENet](configs/det/fcenet/README.md) (CVPR'2021)
+- [x] [DBNet](configs/det/dbnet/README_CN.md) (AAAI'2020)
+- [x] [DBNet++](configs/det/dbnet/README_CN.md) (TPAMI'2022)
+- [x] [PSENet](configs/det/psenet/README_CN.md) (CVPR'2019)
+- [x] [EAST](configs/det/east/README_CN.md) (CVPR'2017)
+- [x] [FCENet](configs/det/fcenet/README_CN.md) (CVPR'2021)
 </details>

 <details open markdown>
 <summary>Text Recognition</summary>

-- [x] [CRNN](configs/rec/crnn/README.md) (TPAMI'2016)
-- [x] [CRNN-Seq2Seq/RARE](configs/rec/rare/README.md) (CVPR'2016)
-- [x] [SVTR](configs/rec/svtr/README.md) (IJCAI'2022)
-- [x] [MASTER](configs/rec/master/README.md) (PR'2019)
+- [x] [CRNN](configs/rec/crnn/README_CN.md) (TPAMI'2016)
+- [x] [CRNN-Seq2Seq/RARE](configs/rec/rare/README_CN.md) (CVPR'2016)
+- [x] [SVTR](configs/rec/svtr/README_CN.md) (IJCAI'2022)
+- [x] [MASTER](configs/rec/master/README_CN.md) (PR'2019)
+- [x] [VISIONLAN](configs/rec/visionlan/README_CN.md) (ICCV'2021)
 - [ ] [ABINet](https://arxiv.org/abs/2103.06495) (CVPR'2021) [coming soon]
 </details>
@@ -207,10 +208,13 @@ MindOCR provides a [data format conversion tool](tools/dataset_converters) to support
 ## Important Information

 ### Changelog
+- 2023/07/05
+1. Add new models
+- [VISIONLAN](configs/rec/visionlan) for text recognition
 - 2023/06/29
 1. Add 2 new SoTA models
-- [FCENet](configs/det/facenet) for text detection
-- [MASTER](configs/det/facenet) for text recognition
+- [FCENet](configs/det/fcenet) for text detection
+- [MASTER](configs/rec/master) for text recognition
 - 2023/06/07
 1. Add new models
 - [PSENet](configs/det/psenet) for text detection
```

configs/rec/visionlan/README.md

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
English | [中文](README_CN.md)

# VisionLAN

<!--- Guideline: use url linked to abstract in ArXiv instead of PDF for fast loading. -->

> VisionLAN: [From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network](https://arxiv.org/abs/2108.09661)

## 1. Introduction

### 1.1 VisionLAN

Visual Language Modeling Network (VisionLAN) [<a href="#5-references">1</a>] is a text recognition model that learns visual and linguistic information simultaneously via **character-wise occluded feature maps** in the training stage. It does not require an extra language model to extract linguistic information, since the visual and linguistic information are learned as a union.
<!--- Guideline: If an architecture table/figure is available in the paper, put one here and cite for intuitive illustration. -->
<p align="center">
  <img src="https://raw.githubusercontent.com/wtomin/mindocr-asset/main/images/visionlan_architecture.PNG" width=450 />
</p>
<p align="center">
  <em> Figure 1. The architecture of VisionLAN [<a href="#5-references">1</a>] </em>
</p>

As shown above, the training pipeline of VisionLAN consists of three modules:

- The backbone extracts visual feature maps from the input image;

- The Masked Language-aware Module (MLM) takes the visual feature maps and a randomly selected character index as inputs, and generates a position-aware character mask map to create character-wise occluded feature maps;

- Finally, the Visual Reasoning Module (VRM) takes the occluded feature maps as inputs and makes predictions under complete word-level supervision.

In the test stage, the MLM is not used; only the backbone and the VRM are used for prediction.
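For intuition, the following is a minimal Python sketch of this data flow; the module names and signatures are illustrative assumptions for exposition, not the actual MindOCR implementation.

```python
# Minimal sketch of the VisionLAN data flow described above.
# Module names and signatures are illustrative assumptions only.

def train_forward(backbone, mlm, vrm, image, char_index):
    features = backbone(image)           # visual feature maps
    mask = mlm(features, char_index)     # position-aware character mask map
    occluded = features * (1.0 - mask)   # character-wise occluded feature maps
    return vrm(occluded)                 # prediction under word-level supervision

def test_forward(backbone, vrm, image):
    # The MLM is skipped at test time; only backbone + VRM are used.
    return vrm(backbone(image))
```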
## 2. Results
<!--- Guideline:
Table Format:
- Model: model name in lower case with _ separator.
- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode.
- Top-1 and Top-5: Keep 2 digits after the decimal point.
- Params (M): # of model parameters in millions (10^6). Keep 2 digits after the decimal point.
- Recipe: Training recipe/configuration linked to a yaml config file. Use absolute url path.
- Download: url of the pretrained model weights. Use absolute url path.
-->

### 2.1 Accuracy

According to our experiments, the evaluation results on the ten public benchmark datasets are as follows:
<div align="center">

| **Model** | **Context** | **Backbone** | **Train Dataset** | **Model Params** | **Avg Accuracy** | **Train Time** | **FPS** | **Recipe** | **Download** |
| :-------: | :---------: | :----------: | :---------------: | :--------------: | :--------------: | :------------: | :-----: | :--------: | :----------: |
| visionlan | D910x4-MS2.0-G | resnet45 | MJ+ST | 42.2M | 90.61% | 7718s/epoch | 1,840 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/visionlan/visionlan_resnet45_LF_1.yaml) | [ckpt files](https://download.mindspore.cn/toolkits/mindocr/visionlan/visionlan_resnet45_ckpts-7d6e9c04.tar.gz) |

</div>

<details open markdown>
<div align="center">
<summary>Detailed accuracy results for ten benchmark datasets</summary>

| **Model** | **Context** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **Average** |
| :-------: | :---------: | :----------: | :----------: | :----------: | :-----------: | :-----------: | :-----------: | :-------------: | :-----: | :------: | :--------: | :---------: |
| visionlan | D910x4-MS2.0-G | 96.16% | 95.16% | 95.92% | 94.19% | 84.04% | 77.46% | 95.53% | 92.27% | 85.74% | 89.58% | 90.61% |

</div>

</details>
**Notes:**

- Context: Training context denoted as `{device}x{pieces}-{MS version}-{MS mode}`. The MindSpore mode can be either `G` (graph mode) or `F` (pynative mode). For example, `D910x4-MS2.0-G` denotes training on 4 pieces of Ascend 910 NPUs using graph mode based on MindSpore version 2.0.0.
- Train datasets: MJ+ST stands for the combination of the two synthetic datasets, SynthText (800k) and MJSynth.
- To reproduce the result in other contexts, please ensure the global batch size is the same.
- The models are trained from scratch without any pre-training. For more dataset details of training and evaluation, please refer to the [3.2 Dataset preparation](#32-dataset-preparation) section.

## 3. Quick Start

### 3.1 Installation

Please refer to the [installation instructions](https://github.com/mindspore-lab/mindocr#installation) in MindOCR.
### 3.2 Dataset preparation

* Training sets

The authors of VisionLAN used two synthetic text datasets for training: SynthText (800k) and MJSynth. Please follow the instructions of the [original VisionLAN repository](https://github.com/wangyuxin87/VisionLAN) to download the training sets.

After downloading `SynthText.zip` and `MJSynth.zip`, please unzip them and place them under `./datasets/train`. The training sets contain 14,200,701 samples in total. More details are as follows:

> [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/): 25 GB, 6,976,115 samples<br>
[MJSynth](http://www.robots.ox.ac.uk/~vgg/data/text/): 21 GB, 7,224,586 samples

* Validation sets

The authors of VisionLAN used six real text datasets for evaluation: IIIT5K Words (IIIT5K_3000), ICDAR 2013 (IC13_857), Street View Text (SVT), ICDAR 2015 (IC15), Street View Text-Perspective (SVTP), and CUTE80 (CUTE). We used the combination of the six benchmarks as the validation set. Please follow the instructions of the [original VisionLAN repository](https://github.com/wangyuxin87/VisionLAN) to download the validation sets.

After downloading `evaluation.zip`, please unzip it and place its contents under `./datasets`. Under `./datasets/evaluation`, there are seven folders:

> [IIIT5K](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html): 50 MB, 3000 samples<br>
[IC13](http://rrc.cvc.uab.es/?ch=2): 72 MB, 857 samples<br>
[SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset): 2.4 MB, 647 samples<br>
[IC15](http://rrc.cvc.uab.es/?ch=4): 21 MB, 1811 samples<br>
[SVTP](http://openaccess.thecvf.com/content_iccv_2013/papers/Phan_Recognizing_Text_with_2013_ICCV_paper.pdf): 1.8 MB, 645 samples<br>
[CUTE](http://cs-chan.com/downloads_CUTE80_dataset.html): 8.8 MB, 288 samples<br>
Sumof6benchmarks: 155 MB, 7248 samples

During training, we only used the data under `./datasets/evaluation/Sumof6benchmarks` as the validation set. Users may optionally delete the other folders under `./datasets/evaluation`.
* Test sets

We choose ten benchmarks as the test sets to evaluate the model's performance. Users can download the test sets from [here](https://www.dropbox.com/sh/i39abvnefllx2si/AAAbAYRvxzRp3cIE5HzqUw3ra?dl=0) (ref: [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here)). Only `evaluation.zip` is required for testing.

After downloading `evaluation.zip`, please unzip it and rename the folder from `evaluation` to `test`. Please place this folder under `./datasets/`.

The test sets contain 12,067 samples in total. The detailed information is as follows:

> [CUTE80](http://cs-chan.com/downloads_CUTE80_dataset.html): 8.8 MB, 288 samples<br>
[IC03_860](http://www.iapr-tc11.org/mediawiki/index.php/ICDAR_2003_Robust_Reading_Competitions): 36 MB, 860 samples<br>
[IC03_867](http://www.iapr-tc11.org/mediawiki/index.php/ICDAR_2003_Robust_Reading_Competitions): 4.9 MB, 867 samples<br>
[IC13_857](http://rrc.cvc.uab.es/?ch=2): 72 MB, 857 samples<br>
[IC13_1015](http://rrc.cvc.uab.es/?ch=2): 77 MB, 1015 samples<br>
[IC15_1811](http://rrc.cvc.uab.es/?ch=4): 21 MB, 1811 samples<br>
[IC15_2077](http://rrc.cvc.uab.es/?ch=4): 25 MB, 2077 samples<br>
[IIIT5k_3000](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html): 50 MB, 3000 samples<br>
[SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset): 2.4 MB, 647 samples<br>
[SVTP](http://openaccess.thecvf.com/content_iccv_2013/papers/Phan_Recognizing_Text_with_2013_ICCV_paper.pdf): 1.8 MB, 645 samples

After preparation, the file structure should look like this:
``` text
datasets
├── test
│   ├── CUTE80
│   ├── IC03_860
│   ├── IC03_867
│   ├── IC13_857
│   ├── IC13_1015
│   ├── IC15_1811
│   ├── IC15_2077
│   ├── IIIT5k_3000
│   ├── SVT
│   └── SVTP
├── evaluation
│   ├── Sumof6benchmarks
│   └── ...
└── train
    ├── MJSynth
    └── SynthText
```
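As an optional sanity check after preparation, the sketch below counts the samples in each dataset folder. It assumes the datasets follow the deep-text-recognition-benchmark LMDB layout, which stores the sample count under a `num-samples` key; the paths are the ones used above.

```python
# Sanity-check sample counts of the prepared LMDB datasets (assumes the
# deep-text-recognition-benchmark layout with a 'num-samples' key).
import os
import lmdb  # pip install lmdb

def count_samples(lmdb_dir: str) -> int:
    env = lmdb.open(lmdb_dir, readonly=True, lock=False,
                    readahead=False, meminit=False)
    with env.begin(write=False) as txn:
        return int(txn.get("num-samples".encode()))

for split in ("train", "evaluation", "test"):
    root = os.path.join("./datasets", split)
    for name in sorted(os.listdir(root)):
        path = os.path.join(root, name)
        if os.path.isdir(path):
            print(f"{split}/{name}: {count_samples(path)} samples")
```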
### 3.3 Update yaml config file

If the datasets are placed under `./datasets`, there is no need to change `train.dataset.dataset_root` in the yaml configuration files `configs/rec/visionlan/visionlan_resnet45_L*.yaml`.

Otherwise, change the following fields accordingly:

```yaml
...
train:
  dataset_sink_mode: False
  dataset:
    type: LMDBDataset
    dataset_root: dir/to/dataset          <--- Update
    data_dir: train                       <--- Update
...
eval:
  dataset_sink_mode: False
  dataset:
    type: LMDBDataset
    dataset_root: dir/to/dataset          <--- Update
    data_dir: evaluation/Sumof6benchmarks <--- Update
...
```

> Optionally, change `train.loader.num_workers` according to the number of CPU cores.

Apart from the dataset setting, please also check the following important args: `system.distribute`, `system.val_while_train`, `common.batch_size`. Explanations of these important args:
```yaml
system:
  distribute: True        # `True` for distributed training, `False` for standalone training
  amp_level: 'O0'
  seed: 42
  val_while_train: True   # Validate while training
common:
  ...
  batch_size: &batch_size 192   # Batch size for training
...
loader:
  shuffle: False
  batch_size: 64          # Batch size for validation/evaluation
...
```

**Notes:**
- As the global batch size (batch_size x num_devices) is important for reproducing the results, please adjust `batch_size` accordingly to keep the global batch size unchanged for a different number of GPUs/NPUs, or adjust the learning rate linearly to the new global batch size, as in the sketch below.
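Linear scaling is a common heuristic, not a value prescribed by this recipe; the baseline numbers below are illustrative assumptions, and the real base learning rate lives in the yaml config.

```python
# Linear learning-rate scaling for a changed global batch size.
# The baseline values here are illustrative assumptions; take the real
# base learning rate from the yaml recipe.
base_lr = 1e-4              # assumed base learning rate (illustrative)
base_global_bs = 192 * 4    # batch_size x num_devices in the reference context
new_global_bs = 192 * 8     # e.g. scaling from 4 to 8 devices
new_lr = base_lr * new_global_bs / base_global_bs
print(f"scaled learning rate: {new_lr:.2e}")
```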
### 3.4 Training

The training procedure includes a Language-free (LF) stage and a Language-aware (LA) stage, with three training steps in total:

```text
LF_1: train the backbone and VRM, without training the MLM
LF_2: train the MLM and finetune the backbone and VRM
LA: use the masks generated by the MLM to occlude feature maps; train the backbone, MLM, and VRM
```

We used distributed training for the three steps. For standalone training, please refer to the [recognition tutorial](../../../docs/en/tutorials/training_recognition_custom_dataset.md#model-training-and-evaluation).
```shell
mpirun --allow-run-as-root -n 4 python tools/train.py --config configs/rec/visionlan/visionlan_resnet45_LF_1.yaml
mpirun --allow-run-as-root -n 4 python tools/train.py --config configs/rec/visionlan/visionlan_resnet45_LF_2.yaml
mpirun --allow-run-as-root -n 4 python tools/train.py --config configs/rec/visionlan/visionlan_resnet45_LA.yaml
```

The training results (including checkpoints, per-epoch performance, and curves) will be saved in the directory specified by the `ckpt_save_dir` arg in the yaml config file. The default directory is `./tmp_visionlan`.
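Once training finishes, a short sketch like the one below can list the produced checkpoints; it only assumes the default `./tmp_visionlan/<step>/<name>.ckpt` layout used in the evaluation examples later in this document.

```python
# List the checkpoints saved by the three training steps (assumes the
# default ckpt_save_dir layout ./tmp_visionlan/<step>/<name>.ckpt).
from pathlib import Path

for step in ("LF_1", "LF_2", "LA"):
    ckpts = sorted(Path("./tmp_visionlan", step).glob("*.ckpt"))
    print(step, "->", [p.name for p in ckpts] or "no checkpoints yet")
```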
### 3.5 Test

After all three training steps are complete, change `system.distribute` to `False` in `configs/rec/visionlan/visionlan_resnet45_LA.yaml` before testing.

To evaluate the model's accuracy, users can choose from two options:

- Option 1: Repeat the evaluation step for each individual dataset: CUTE80, IC03_860, IC03_867, IC13_857, IC13_1015, IC15_1811, IC15_2077, IIIT5k_3000, SVT, SVTP. Then take the average score. An example evaluation script for the CUTE80 dataset is shown below, and a convenience loop over all ten datasets follows it.

```shell
model_name="e8"
yaml_file="configs/rec/visionlan/visionlan_resnet45_LA.yaml"
training_step="LA"

python tools/eval.py --config $yaml_file --opt eval.dataset.data_dir=test/CUTE80 eval.ckpt_load_path="./tmp_visionlan/${training_step}/${model_name}.ckpt"
```
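The per-dataset runs can be scripted. The loop below is a minimal sketch that simply invokes `tools/eval.py` once per benchmark, reusing the checkpoint path from the example above; the checkpoint name `e8` is just the example's placeholder.

```python
# Run tools/eval.py once per benchmark dataset (a convenience sketch;
# the checkpoint name "e8" is the placeholder from the example above).
import subprocess

DATASETS = ["CUTE80", "IC03_860", "IC03_867", "IC13_857", "IC13_1015",
            "IC15_1811", "IC15_2077", "IIIT5k_3000", "SVT", "SVTP"]

for name in DATASETS:
    subprocess.run(
        ["python", "tools/eval.py",
         "--config", "configs/rec/visionlan/visionlan_resnet45_LA.yaml",
         "--opt", f"eval.dataset.data_dir=test/{name}",
         "eval.ckpt_load_path=./tmp_visionlan/LA/e8.ckpt"],
        check=True,
    )
```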
- Option 2: Given that all the benchmark dataset folders are under the same directory, e.g. `test/`, use the script `tools/benchmarking/multi_dataset_eval.py`. An example evaluation script:

```shell
model_name="e8"
yaml_file="configs/rec/visionlan/visionlan_resnet45_LA.yaml"
training_step="LA"

python tools/benchmarking/multi_dataset_eval.py --config $yaml_file --opt eval.dataset.data_dir="test" eval.ckpt_load_path="./tmp_visionlan/${training_step}/${model_name}.ckpt"
```
## 4. Inference

Coming soon...

## 5. References
<!--- Guideline: Citation format GB/T 7714 is suggested. -->

[1] Yuxin Wang, Hongtao Xie, Shancheng Fang, Jing Wang, Shenggao Zhu, Yongdong Zhang. From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network. ICCV 2021: 14174-14183.
