This repository was archived by the owner on Aug 20, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path: train.py
More file actions
47 lines (35 loc) · 1.23 KB
/
train.py
File metadata and controls
47 lines (35 loc) · 1.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# -*- coding: utf-8 -*-
import argparse
from chainer import optimizers
from progressbar import ProgressBar
import random
import six.moves.cPickle as pickle
from alexnet import forward, model
from util import load_image, num_to_label, walk_dir
# Train the AlexNet model on labeled images found under --data_dir,
# then pickle the trained model to AlexNet_epoch_<epoch>.pickle.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data_dir', type=str, default='data')
    parser.add_argument('-e', '--epoch', type=int, default=100)
    args = parser.parse_args()

    # init optimizer: Adam over the globally-defined AlexNet model
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    # load data: walk_dir invokes the callback per (dir index, file path);
    # we collect (label, image) pairs. append, not extend([...]), for a
    # single element.
    data = []
    walk_dir(args.data_dir, lambda i, f: data.append((num_to_label(i), load_image(f))))

    # learn: one full pass over the (re-shuffled) dataset per epoch
    for epoch in range(args.epoch):
        random.shuffle(data)  # reshuffle each epoch so batches differ
        pbar = ProgressBar(len(data))
        for step, (label, img) in enumerate(data, 1):
            optimizer.zero_grads()
            loss, acc = forward(model, img, label, train=True)
            loss.backward()
            optimizer.update()
            pbar.update(step)
        # parenthesized single-arg print works on both Python 2 and 3
        print('%s 回繰り返し学習を行った' % (epoch + 1))
    print('ヨッシャ! 学習おわったでw')

    # dump model: context manager guarantees the file handle is closed
    # (the original leaked the handle from an inline open()).
    # NOTE(review): pickling a live model object is fragile across library
    # versions — consider chainer serializers if available. protocol -1 =
    # highest available pickle protocol.
    with open('AlexNet_epoch_%s.pickle' % args.epoch, 'wb') as fh:
        pickle.dump(model, fh, -1)