66import torch
77from torch import nn , optim
88from torch .utils .data import DataLoader
9- import argparse , time , os , sys , yaml
9+ import argparse , time , os , sys , yaml , logging
1010
def add_args_parser():
    """Build the parent parser that holds options shared by all entry points.

    Returns:
        An ``argparse.ArgumentParser`` with ``add_help=False`` so it can be
        passed via ``parents=[...]`` without a duplicate ``-h`` option.
    """
    shared = argparse.ArgumentParser(add_help=False)
    shared.add_argument('--config', type=str)
    return shared
16+
def get_logger(expr_name):
    """Create (or return the already-configured) training logger.

    Writes INFO-level records both to ``logs/<expr_name>.log`` (truncated on
    each fresh run via ``mode='w'``) and to the console.

    Args:
        expr_name: Experiment name used to build the log-file path.

    Returns:
        The process-wide ``logging.Logger`` named ``'train'``.
    """
    logger = logging.getLogger('train')
    logger.setLevel(logging.INFO)

    # getLogger returns a singleton per name: without this guard, a second
    # call would stack duplicate handlers and every record would be emitted
    # once per extra handler.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(message)s"
    )

    # FileHandler does not create missing directories; ensure logs/ exists
    # so the first run doesn't die with FileNotFoundError.
    os.makedirs("logs", exist_ok=True)
    file_h = logging.FileHandler(f"logs/{expr_name}.log", mode='w')
    file_h.setLevel(logging.INFO)
    file_h.setFormatter(formatter)
    logger.addHandler(file_h)

    console_h = logging.StreamHandler()
    console_h.setLevel(logging.INFO)
    # Original left console output unformatted; apply the same formatter so
    # file and console records look identical.
    console_h.setFormatter(formatter)
    logger.addHandler(console_h)

    return logger
1635
1736def main (cfg ):
18- print (f"=====================[{ cfg ['expr' ]} ]=====================" )
37+ logger = get_logger (cfg ['expr' ])
38+ logger .info (f"=====================[{ cfg ['expr' ]} ]=====================" )
1939
2040 # Device Setting
2141 device = None
2242 if cfg ['device' ] != 'cpu' and torch .cuda .is_available ():
2343 device = cfg ['device' ]
2444 else :
2545 device = 'cpu'
26- print (f"device: { device } " )
46+ logger . info (f"device: { device } " )
2747
2848 # Hyperparameter Settings
2949 hp_cfg = cfg ['hyperparameters' ]
@@ -35,11 +55,11 @@ def main(cfg):
3555 shuffle = True ,
3656 batch_size = hp_cfg ['batch_size' ],
3757 drop_last = True )
38- print (f"Load Dataset { data_cfg ['dataset' ]} " )
58+ logger . info (f"Load Dataset { data_cfg ['dataset' ]} " )
3959
4060 # Load Model
4161 model_cfg = cfg ['model' ]
42- print (model_cfg ['name' ])
62+ logger . info (model_cfg ['name' ])
4363 model = load_model (model_cfg ).to (device )
4464 if cfg ['parallel' ] == True :
4565 model = nn .DataParallel (model )
@@ -73,24 +93,24 @@ def main(cfg):
7393 min_loss = 1e4
7494
7595 for current_epoch in range (1 , hp_cfg ['epochs' ]+ 1 ):
76- print ("=======================================================" )
77- print (f"Epoch: [{ current_epoch :03d} /{ hp_cfg ['epochs' ]:03d} ]\n " )
96+ logger . info ("=======================================================" )
97+ logger . info (f"Epoch: [{ current_epoch :03d} /{ hp_cfg ['epochs' ]:03d} ]\n " )
7898
7999 # Training One Epoch
80100 start_time = int (time .time ())
81- train_loss = train_one_epoch (model , train_dl , loss_fn , optimizer , scheduler , device )
101+ train_loss = train_one_epoch (model , train_dl , loss_fn , optimizer , scheduler , device , logger )
82102 elapsed_time = int (time .time () - start_time )
83- print (f"Train Time: { elapsed_time // 60 :02d} m { elapsed_time % 60 :02d} s\n " )
103+ logger . info (f"Train Time: { elapsed_time // 60 :02d} m { elapsed_time % 60 :02d} s" )
84104
85105 if train_loss < min_loss :
86106 min_loss = train_loss
87- save_model_ckpt (model , save_cfg ['name' ], current_epoch , save_cfg ['weights_path' ])
107+ save_model_ckpt (model , save_cfg ['name' ], current_epoch , save_cfg ['weights_path' ], logger )
88108
89109 total_train_loss .append (train_loss )
90- save_loss_ckpt (save_cfg ['name' ], total_train_loss , save_cfg ['loss_path' ])
110+ save_loss_ckpt (save_cfg ['name' ], total_train_loss , save_cfg ['loss_path' ], logger )
91111
92112 total_elapsed_time = int (time .time ()) - total_start_time
93- print (f"<Total Train Time: { total_elapsed_time // 60 :02d} m { total_elapsed_time % 60 :02d} s>" )
113+ logger . info (f"<Total Train Time: { total_elapsed_time // 60 :02d} m { total_elapsed_time % 60 :02d} s>" )
94114
95115if __name__ == '__main__' :
96116 parser = argparse .ArgumentParser ('Training' , parents = [add_args_parser ()])
0 commit comments