Skip to content

Commit b5ed5a5

Browse files
committed
[Add] logging & logs directory
1 parent 036b089 commit b5ed5a5

File tree

5 files changed

+69
-25
lines changed

5 files changed

+69
-25
lines changed

logs/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.log

test.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
import argparse
4-
import time, sys, os, yaml
4+
import time, sys, os, yaml, logging
55

66
from utils import evaluate
77
from models import load_model
@@ -13,16 +13,36 @@ def add_args_parser():
1313

1414
return parser
1515

16+
def get_logger(expr_name):
    """Return the 'test' logger, writing to ``logs/<expr_name>.log`` and the console.

    The file handler uses the "time | level | message" format and truncates
    the log file on every call (mode ``'w'``). The console handler keeps the
    default format, so the terminal shows the bare message.

    Args:
        expr_name: Experiment name; used as the log file's base name.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger('test')
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so a
    # second call would stack new handlers on top of the old ones and every
    # message would be emitted twice. Drop (and close) the stale handlers.
    for stale_h in list(logger.handlers):
        logger.removeHandler(stale_h)
        stale_h.close()

    formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(message)s"
    )

    # FileHandler does not create missing parent directories.
    os.makedirs("logs", exist_ok=True)
    file_h = logging.FileHandler(f"logs/{expr_name}.log", mode='w')
    file_h.setLevel(logging.INFO)
    file_h.setFormatter(formatter)
    logger.addHandler(file_h)

    console_h = logging.StreamHandler()
    console_h.setLevel(logging.INFO)
    logger.addHandler(console_h)

    return logger
34+
1635
def main(cfg):
17-
print(f"=====================[{cfg['expr']}]=====================")
36+
logger = get_logger(cfg['expr'])
37+
logger.info(f"=====================[{cfg['expr']}]=====================")
1838

1939
# Device Setting
2040
device = None
2141
if cfg['device'] != 'cpu' and torch.cuda.is_available():
2242
device = cfg['device']
2343
else:
2444
device = 'cpu'
25-
print(f"device: {device}")
45+
logger.info(f"device: {device}")
2646

2747
# Hyperparameter Settings
2848
hp_cfg = cfg['hyperparameters']
@@ -32,7 +52,7 @@ def main(cfg):
3252
test_ds = load_dataset(data_cfg)
3353
test_dl = torch.utils.data.DataLoader(test_ds,
3454
batch_size=hp_cfg['batch_size'])
35-
print(f"Load Dataset {data_cfg['dataset']}")
55+
logger.info(f"Load Dataset {data_cfg['dataset']}")
3656

3757
# Load Model
3858
save_cfg = cfg['save']
@@ -45,10 +65,10 @@ def main(cfg):
4565
start_time = int(time.time())
4666
result = evaluate(model, test_dl, device)
4767
test_time = int(time.time() - start_time)
48-
print(f"Test Time: {test_time//60:02d}m {test_time%60:02d}s")
68+
logger.info(f"Test Time: {test_time//60:02d}m {test_time%60:02d}s")
4969

5070
for key, value in result.items():
51-
print(f"{key}: {value:.4f}")
71+
logger.info(f"{key}: {value:.4f}")
5272

5373
if __name__ == '__main__':
5474
parser = argparse.ArgumentParser('Test', parents=[add_args_parser()])

train.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,44 @@
66
import torch
77
from torch import nn, optim
88
from torch.utils.data import DataLoader
9-
import argparse, time, os, sys, yaml
9+
import argparse, time, os, sys, yaml, logging
1010

1111
def add_args_parser():
1212
parser = argparse.ArgumentParser(add_help=False)
1313
parser.add_argument('--config', type=str)
1414

1515
return parser
16+
17+
def get_logger(expr_name):
    """Return the 'train' logger, writing to ``logs/<expr_name>.log`` and the console.

    The file handler uses the "time | level | message" format and truncates
    the log file on every call (mode ``'w'``). The console handler keeps the
    default format, so the terminal shows the bare message.

    Args:
        expr_name: Experiment name; used as the log file's base name.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger('train')
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so a
    # second call would stack new handlers on top of the old ones and every
    # message would be emitted twice. Drop (and close) the stale handlers.
    for stale_h in list(logger.handlers):
        logger.removeHandler(stale_h)
        stale_h.close()

    formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(message)s"
    )

    # FileHandler does not create missing parent directories.
    os.makedirs("logs", exist_ok=True)
    file_h = logging.FileHandler(f"logs/{expr_name}.log", mode='w')
    file_h.setLevel(logging.INFO)
    file_h.setFormatter(formatter)
    logger.addHandler(file_h)

    console_h = logging.StreamHandler()
    console_h.setLevel(logging.INFO)
    logger.addHandler(console_h)

    return logger
1635

1736
def main(cfg):
18-
print(f"=====================[{cfg['expr']}]=====================")
37+
logger = get_logger(cfg['expr'])
38+
logger.info(f"=====================[{cfg['expr']}]=====================")
1939

2040
# Device Setting
2141
device = None
2242
if cfg['device'] != 'cpu' and torch.cuda.is_available():
2343
device = cfg['device']
2444
else:
2545
device = 'cpu'
26-
print(f"device: {device}")
46+
logger.info(f"device: {device}")
2747

2848
# Hyperparameter Settings
2949
hp_cfg = cfg['hyperparameters']
@@ -35,11 +55,11 @@ def main(cfg):
3555
shuffle=True,
3656
batch_size=hp_cfg['batch_size'],
3757
drop_last=True)
38-
print(f"Load Dataset {data_cfg['dataset']}")
58+
logger.info(f"Load Dataset {data_cfg['dataset']}")
3959

4060
# Load Model
4161
model_cfg = cfg['model']
42-
print(model_cfg['name'])
62+
logger.info(model_cfg['name'])
4363
model = load_model(model_cfg).to(device)
4464
if cfg['parallel'] == True:
4565
model = nn.DataParallel(model)
@@ -73,24 +93,24 @@ def main(cfg):
7393
min_loss = 1e4
7494

7595
for current_epoch in range(1, hp_cfg['epochs']+1):
76-
print("=======================================================")
77-
print(f"Epoch: [{current_epoch:03d}/{hp_cfg['epochs']:03d}]\n")
96+
logger.info("=======================================================")
97+
logger.info(f"Epoch: [{current_epoch:03d}/{hp_cfg['epochs']:03d}]\n")
7898

7999
# Training One Epoch
80100
start_time = int(time.time())
81-
train_loss = train_one_epoch(model, train_dl, loss_fn, optimizer, scheduler, device)
101+
train_loss = train_one_epoch(model, train_dl, loss_fn, optimizer, scheduler, device, logger)
82102
elapsed_time = int(time.time() - start_time)
83-
print(f"Train Time: {elapsed_time//60:02d}m {elapsed_time%60:02d}s\n")
103+
logger.info(f"Train Time: {elapsed_time//60:02d}m {elapsed_time%60:02d}s")
84104

85105
if train_loss < min_loss:
86106
min_loss = train_loss
87-
save_model_ckpt(model, save_cfg['name'], current_epoch, save_cfg['weights_path'])
107+
save_model_ckpt(model, save_cfg['name'], current_epoch, save_cfg['weights_path'], logger)
88108

89109
total_train_loss.append(train_loss)
90-
save_loss_ckpt(save_cfg['name'], total_train_loss, save_cfg['loss_path'])
110+
save_loss_ckpt(save_cfg['name'], total_train_loss, save_cfg['loss_path'], logger)
91111

92112
total_elapsed_time = int(time.time()) - total_start_time
93-
print(f"<Total Train Time: {total_elapsed_time//60:02d}m {total_elapsed_time%60:02d}s>")
113+
logger.info(f"<Total Train Time: {total_elapsed_time//60:02d}m {total_elapsed_time%60:02d}s>")
94114

95115
if __name__ == '__main__':
96116
parser = argparse.ArgumentParser('Training', parents=[add_args_parser()])

utils/engine.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from .metrics import get_metrics
77

8-
def train_one_epoch(model, dataloader, loss_fn, optimizer, scheduler, device):
8+
def train_one_epoch(model, dataloader, loss_fn, optimizer, scheduler, device, logger):
99
model.train()
1010
total_loss = []
1111

@@ -23,10 +23,12 @@ def train_one_epoch(model, dataloader, loss_fn, optimizer, scheduler, device):
2323
loss.backward()
2424
optimizer.step()
2525

26+
# Only stream (not log), because logging doesn't support the carriage return.
2627
print(f"\rTraining: {100*batch_idx/len(dataloader):.2f}%, Loss: {sum(total_loss)/len(total_loss):.6f}, LR: {scheduler.get_last_lr()[0]:.6f}", end="")
2728
print()
2829

2930
scheduler.step(sum(total_loss)/len(total_loss))
31+
logger.info(f"Loss: {sum(total_loss)/len(total_loss):.6f}, LR: {scheduler.get_last_lr()[0]:.6f}")
3032

3133
return sum(total_loss)/len(total_loss)
3234

@@ -48,6 +50,7 @@ def evaluate(model, dataloader, device):
4850
total_outputs.extend(out.tolist())
4951
total_targets.extend(target.tolist())
5052

53+
# Only stream (not log), because logging doesn't support the carriage return.
5154
print(f"\rEvaluate: {100*batch_idx/len(dataloader):.2f}%", end="")
5255
print()
5356

utils/save_ckpt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import os
44

5-
def save_model_ckpt(model, model_name, current_epoch, save_dir):
5+
def save_model_ckpt(model, model_name, current_epoch, save_dir, logger):
66
ckpt = {}
77
ckpt['model'] = model.state_dict()
88
ckpt['epochs'] = current_epoch
@@ -11,13 +11,13 @@ def save_model_ckpt(model, model_name, current_epoch, save_dir):
1111

1212
try:
1313
torch.save(ckpt, os.path.join(save_dir, save_name))
14-
print(f"Save Model @epoch: {current_epoch}")
14+
logger.info(f"Save Model @epoch: {current_epoch}")
1515
except:
16-
print(f"Can\'t Save Model @epoch: {current_epoch}")
16+
logger.info(f"Can\'t Save Model @epoch: {current_epoch}")
1717

18-
def save_loss_ckpt(model_name, train_loss, save_dir):
18+
def save_loss_ckpt(model_name, train_loss, save_dir, logger):
    """Save the training-loss history as ``train_loss_<model_name>.npy``.

    Saving is best-effort: a failure is logged and swallowed rather than
    raised to the caller.

    Args:
        model_name: Name embedded in the output file name.
        train_loss: Sequence of loss values; converted with ``np.array``.
        save_dir: Directory the ``.npy`` file is written to.
        logger: ``logging.Logger`` used to report success or failure.
    """
    loss_path = os.path.join(save_dir, f'train_loss_{model_name}.npy')
    try:
        np.save(loss_path, np.array(train_loss))
    except Exception:
        # Not a bare ``except:`` so KeyboardInterrupt/SystemExit still
        # propagate; logger.exception records the traceback at ERROR level.
        logger.exception('Can\'t Save Train Loss')
    else:
        logger.info('Save Train Loss')

0 commit comments

Comments
 (0)