-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
126 lines (107 loc) · 4.08 KB
/
utils.py
File metadata and controls
126 lines (107 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import numpy as np
import os
import sys
import yaml
import logging
from subprocess import Popen, PIPE
import torch.autograd
import torch.nn
def print_metrics(metrics, fp=None):
    """Render a metrics dict as tab-separated ``name: value`` pairs.

    Arguments:
        metrics (dict): metric name -> float value (formatted to 4 decimals)
        fp Optional[str]: if given, write the line to this file path
            (overwriting); otherwise print to stdout
    Returns:
        None
    """
    metric_str = ""
    for metric in metrics:
        metric_str += '\t%s: %.4f' % (metric, metrics[metric])
    if fp is None:
        print(metric_str)
    else:
        # BUG FIX: the file was opened in binary mode ('wb'), which raises
        # TypeError when writing a str. Open in text mode instead.
        with open(fp, 'w') as f:
            f.write(metric_str)
def setup_logger(loglevel, logfile=None):
    """Set up the root logger.

    Arguments:
        loglevel (str): the log level name, case-insensitive (INFO|DEBUG|...)
        logfile Optional[str]: if given, also log to this file (truncated)
    Returns:
        None
    Raises:
        ValueError: if ``loglevel`` is not a valid logging level name
    """
    # Normalize so lowercase names ("info") resolve too; getattr returns the
    # numeric level constant, or None/non-int for anything that isn't a level.
    numeric_level = getattr(logging, str(loglevel).upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError("Invalid log level: %s" % loglevel)
    logger = logging.getLogger()
    # No-op if the root logger already has handlers (basicConfig semantics).
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: %(message)s',
        level=numeric_level, stream=sys.stdout)
    if logfile is not None:
        # Mirror the console format into a fresh ('w') file handler.
        fmt = logging.Formatter('%(asctime)s: %(levelname)s: %(message)s')
        logfile_handle = logging.FileHandler(logfile, 'w')
        logfile_handle.setFormatter(fmt)
        logger.addHandler(logfile_handle)
def setup_output_dir(output_dir, config, loglevel):
    """Create the next ``run-N`` directory under ``output_dir`` and set up logging.

    Each invocation scans ``output_dir`` for existing ``run-*`` directories and
    creates the next one (``run-0`` if none exist). A run directory contains:
        run-N:
        |_ best_model/                 (subdirectory for best checkpoints)
        |_ config.yaml                 (the config passed in, with save_dir added)
        |_ githash.log                 (hash of the current git HEAD)
        |_ gitdiff.log                 (current uncommitted git diff)
        |_ logfile.log                 (log output of this run)

    Side effects: mutates ``config`` in place (sets
    ``config['data_params']['save_dir']``), writes files, and configures the
    root logger via setup_logger.

    Arguments:
        loglevel (str): passed through to setup_logger
    Returns:
        (new_dirname, config): the created run directory path and the mutated config
    """
    make_directory(output_dir, recursive=True)
    last_run = -1
    for dirname in os.listdir(output_dir):
        if dirname.startswith('run-'):
            # Find the highest existing run index.
            # NOTE(review): int(...) will crash on names like 'run-old' or
            # 'run-1-backup' — assumes only well-formed 'run-<int>' entries.
            last_run = max(last_run, int(dirname.split('-')[1]))
    new_dirname = os.path.join(output_dir, 'run-%d' % (last_run + 1))
    make_directory(new_dirname)
    best_model_dirname = os.path.join(new_dirname, 'best_model')
    make_directory(best_model_dirname)
    config_file = os.path.join(new_dirname, 'config.yaml')
    # Record where this run saves its artifacts before persisting the config.
    config['data_params']['save_dir'] = new_dirname
    write_to_yaml(config_file, config)
    # Save the git hash
    # NOTE(review): assumes cwd is inside a git repo and `git` is on PATH;
    # stderr is captured but never checked, so failures produce empty logs.
    process = Popen('git log -1 --format="%H"'.split(), stdout=PIPE, stderr=PIPE)
    stdout, stderr = process.communicate()
    stdout = stdout.decode('utf-8').strip('\n').strip('"')
    with open(os.path.join(new_dirname, "githash.log"), "w") as fp:
        fp.write(stdout)
    # Save the git diff
    process = Popen('git diff'.split(), stdout=PIPE, stderr=PIPE)
    stdout, stderr = process.communicate()
    with open(os.path.join(new_dirname, "gitdiff.log"), "w") as fp:
        stdout = stdout.decode('utf-8')
        fp.write(stdout)
    # Set up the logger
    logfile = os.path.join(new_dirname, 'logfile.log')
    setup_logger(loglevel, logfile)
    return new_dirname, config
def read_from_yaml(filepath):
    """Parse the YAML file at ``filepath`` and return the resulting object."""
    with open(filepath, 'r') as handle:
        return yaml.load(handle, Loader=yaml.FullLoader)
def write_to_yaml(filepath, data):
    """Serialize ``data`` to ``filepath`` as block-style YAML (overwrites)."""
    with open(filepath, 'w') as handle:
        yaml.dump(data=data, stream=handle, default_flow_style=False)
def make_directory(dirname, recursive=False):
    """Create ``dirname``; succeed silently if the directory already exists.

    Arguments:
        dirname (str): path of the directory to create
        recursive (bool): if True, also create missing parent directories
    Raises:
        FileNotFoundError: if recursive is False and the parent is missing
        FileExistsError: if the path exists but is not a directory
    """
    # BUG FIX: the `recursive` flag was silently ignored (os.makedirs is
    # always recursive). Honor it: plain mkdir unless recursion is requested.
    if recursive:
        os.makedirs(dirname, exist_ok=True)
    else:
        try:
            os.mkdir(dirname)
        except FileExistsError:
            # Match makedirs(exist_ok=True): tolerate an existing directory,
            # but still raise if the path exists and is not a directory.
            if not os.path.isdir(dirname):
                raise
def disp_params(params, name):
    """Log a named parameter dict: ``name`` followed by one indented
    ``key: value`` line per entry, as a single INFO record."""
    lines = ["{0}".format(name)]
    lines.extend('\t%s: %s' % (key, str(params[key])) for key in params)
    logger = logging.getLogger(__name__)
    logger.info('\n'.join(lines))
def to_cuda(t, gpu):
    """Move tensor ``t`` to device ``cuda:<gpu>``.

    Arguments:
        t: a torch tensor (anything with a ``.to(device)`` method)
        gpu Optional[int]: GPU index; if None, ``t`` is returned unchanged
    Returns:
        ``t`` on the target device, or ``t`` itself when gpu is None
    """
    # FIX: original compared with `!= None` (PEP 8: use `is not None`) and
    # tested gpu twice, leaving `device` unbound on the None path.
    if gpu is None:
        return t
    return t.to("cuda:{}".format(gpu))
def to_numpy(t, gpu):
    """Convert a torch tensor (or Parameter) to a numpy array.

    Arguments:
        t: a torch.Tensor / torch.nn.Parameter (or anything with ``.numpy()``)
        gpu Optional[int]: if not None, the tensor is first moved to CPU
    Returns:
        numpy.ndarray with the tensor's data
    """
    # FIX: the old `.data` access on Variable/Parameter misses plain tensors
    # with requires_grad=True, for which .numpy() raises. detach() covers all
    # torch.Tensor subclasses (Parameter included; Variable is just Tensor now).
    ret = t.detach() if isinstance(t, torch.Tensor) else t
    ret = ret.cpu() if gpu is not None else ret  # bring it back to cpu
    return ret.numpy()