Skip to content

Commit e6aebcf

Browse files
zhtmikehadipash
andauthored
Move save config before training start log (#478)
* Move save config before training start log * Update tools/train.py Co-authored-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> --------- Co-authored-by: Rustam Khadipash <16683750+hadipash@users.noreply.github.com>
1 parent a60dc65 commit e6aebcf

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

tools/train.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ def main(cfg):
195195
start_epoch=start_epoch,
196196
)
197197

198+
# save args used for training
199+
if rank_id in [None, 0]:
200+
with open(os.path.join(cfg.train.ckpt_save_dir, "args.yaml"), "w") as f:
201+
yaml.safe_dump(cfg.to_dict(), stream=f, default_flow_style=False, sort_keys=False)
202+
198203
# log
199204
num_devices = device_num if device_num is not None else 1
200205
global_batch_size = cfg.train.loader.batch_size * num_devices * gradient_accumulation_steps
@@ -229,12 +234,6 @@ def main(cfg):
229234
f"\nStart training... (The first epoch takes longer, please wait...)\n"
230235
)
231236

232-
# save args used for training
233-
if rank_id in [None, 0]:
234-
with open(os.path.join(cfg.train.ckpt_save_dir, "args.yaml"), "w") as f:
235-
args_text = yaml.safe_dump(cfg.to_dict(), default_flow_style=False, sort_keys=False)
236-
f.write(args_text)
237-
238237
# training
239238
model = ms.Model(train_net)
240239
model.train(

0 commit comments

Comments
 (0)