|
12 | 12 | import sys |
13 | 13 | import glob |
14 | 14 | import yaml |
| 15 | +import pytest |
15 | 16 |
|
16 | 17 | sys.path.append(".") |
17 | 18 |
|
18 | | -import pytest |
19 | | - |
| 19 | +from tests.ut._common import gen_dummpy_data, update_config_for_CI |
20 | 20 | from mindocr.models.backbones.mindcv_models.download import DownLoad |
21 | 21 |
|
22 | 22 |
|
23 | 23 | @pytest.mark.parametrize("task", ["det", "rec"]) |
24 | 24 | @pytest.mark.parametrize("val_while_train", [False, True]) |
25 | 25 | def test_train_eval(task, val_while_train): |
| 26 | + |
26 | 27 | # prepare dummy images |
27 | | - data_dir = "data/Canidae" |
28 | | - dataset_url = ( |
29 | | - "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip" |
30 | | - ) |
31 | | - if not os.path.exists(data_dir): |
32 | | - DownLoad().download_and_extract_archive(dataset_url, "./") |
33 | | - |
34 | | - # prepare dummy labels |
35 | | - for split in ['train', 'val']: |
36 | | - label_path = f'tests/st/dummy_labels/{task}_{split}_gt.txt' |
37 | | - image_dir = f'{data_dir}/{split}/dogs' |
38 | | - new_label_path = f'data/Canidae/{split}/{task}_gt.txt' |
39 | | - img_paths = glob.glob(os.path.join(image_dir, '*.JPEG')) |
40 | | - #print(len(img_paths)) |
41 | | - with open(new_label_path, 'w') as f_w: |
42 | | - with open(label_path, 'r') as f_r: |
43 | | - i = 0 |
44 | | - for line in f_r: |
45 | | - _, label = line.strip().split('\t') |
46 | | - #print(i) |
47 | | - img_name = os.path.basename(img_paths[i]) |
48 | | - new_img_label = img_name + '\t' + label |
49 | | - f_w.write(new_img_label + '\n') |
50 | | - i += 1 |
51 | | - print(f'Dummpy annotation file is generated in {new_label_path}') |
52 | | - |
53 | | - # modify ocr predefined yaml for minimum test |
| 28 | + data_dir = gen_dummpy_data(task) |
| 29 | + |
| 30 | + # modify ocr predefined yaml for minimum test |
54 | 31 | if task == 'det': |
55 | 32 | config_fp = 'configs/det/dbnet/db_r50_icdar15.yaml' |
56 | 33 | elif task=='rec': |
57 | 34 | #config_fp = 'configs/rec/vgg7_bilstm_ctc.yaml' # TODO: change on lmdb datasset |
58 | 35 | config_fp = 'configs/rec/crnn/crnn_icdar15.yaml' |
59 | 36 |
|
60 | | - with open(config_fp) as fp: |
61 | | - config = yaml.safe_load(fp) |
62 | | - config['system']['distribute'] = False |
63 | | - config['system']['val_while_train'] = val_while_train |
64 | | - #if 'common' in config: |
65 | | - # config['batch_size'] = 8 |
66 | | - config['train']['dataset_sink_mode'] = False |
67 | | - |
68 | | - config['train']['dataset']['dataset_root'] = 'data/Canidae/' |
69 | | - config['train']['dataset']['data_dir'] = 'train/dogs' |
70 | | - config['train']['dataset']['label_file'] = f'train/{task}_gt.txt' |
71 | | - config['train']['dataset']['sample_ratio'] = 0.1 # TODO: 120 training samples in total, don't be larger than batchsize after sampling |
72 | | - config['train']['loader']['num_workers'] = 1 # github server only support 2 workers at most |
73 | | - #if config['train']['loader']['batch_size'] > 120: |
74 | | - config['train']['loader']['batch_size'] = 2 # to save memory |
75 | | - config['train']['loader']['max_rowsize'] = 16 # to save memory |
76 | | - config['train']['loader']['prefetch_size'] = 2 # to save memory |
77 | | - if 'common' in config: |
78 | | - config['common']['batch_size'] = 2 |
79 | | - if 'batch_size' in config['loss']: |
80 | | - config['loss']['batch_size'] = 2 |
81 | | - |
82 | | - config['eval']['dataset']['dataset_root'] = 'data/Canidae/' |
83 | | - config['eval']['dataset']['data_dir'] = 'val/dogs' |
84 | | - config['eval']['dataset']['label_file'] = f'val/{task}_gt.txt' |
85 | | - config['eval']['dataset']['sample_ratio'] = 0.1 |
86 | | - config['eval']['loader']['num_workers'] = 1 # github server only support 2 workers at most |
87 | | - config['eval']['loader']['batch_size'] = 1 |
88 | | - config['eval']['loader']['max_rowsize'] = 16 # to save memory |
89 | | - config['eval']['loader']['prefetch_size'] = 2 # to save memory |
90 | | - |
91 | | - config['eval']['ckpt_load_path'] = os.path.join(config['train']['ckpt_save_dir'], 'best.ckpt') |
92 | | - |
93 | | - config['scheduler']['num_epochs'] = 2 |
94 | | - config['scheduler']['warmup_epochs'] = 1 |
95 | | - config['scheduler']['decay_epochs'] = 1 |
96 | | - |
97 | | - dummpy_config_fp =os.path.join('tests/st', os.path.basename(config_fp.replace('.yaml', '_dummpy.yaml'))) |
98 | | - with open(dummpy_config_fp, 'w') as f: |
99 | | - args_text = yaml.safe_dump(config, default_flow_style=False, sort_keys=False) |
100 | | - f.write(args_text) |
101 | | - print('Genearted yaml: ') |
102 | | - print(args_text) |
103 | | - |
| 37 | + dummpy_config_fp = update_config_for_CI(config_fp, task) |
104 | 38 |
|
105 | 39 | #dummpy_config_fp = 'tests/st/rec_crnn_test.yaml' |
106 | 40 | # ---------------- test running train.py using the toy data --------- |
@@ -129,5 +63,5 @@ def test_train_eval(task, val_while_train): |
129 | 63 |
|
130 | 64 |
|
131 | 65 | if __name__ == '__main__': |
132 | | - test_train_eval('det', True) |
133 | | - #test_train_eval('rec', True) |
| 66 | + #test_train_eval('det', True) |
| 67 | + test_train_eval('rec', True) |
0 commit comments