-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfinal_model.py
More file actions
58 lines (41 loc) · 2.01 KB
/
final_model.py
File metadata and controls
58 lines (41 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, LSTM, SimpleRNN, GRU, Masking
from keras.optimizers import RMSprop, SGD, Adam, Nadam
from keras.callbacks import ModelCheckpoint, EarlyStopping
from data import getReadyData
import numpy as np
import pickle
def _checkpoint_path(cell_size, n_cell, dropout, activation, optimizer, lr, decay):
    """Return the HDF5 checkpoint path encoding this run's hyperparameters.

    Layout: './hand_made_models/<cell_size>_<n_cell>_GRU_<dropout>_<activation>_
    <optimizer>_<lr>_<decay>.hdf5'. `kernel_init` and `epochs` are not part of
    the name, so runs that differ only in those overwrite each other.
    """
    return './hand_made_models/{}_{}_GRU_{}_{}_{}_{}_{}.hdf5'.format(
        cell_size, n_cell, dropout, activation, optimizer, lr, decay)


def _build_optimizer(name, lr, decay):
    """Instantiate the Keras optimizer called *name* (case-insensitive).

    Raises:
        ValueError: if *name* is not 'adam', 'rmsprop' or 'sgd'.
    """
    # Limited to optimizers whose Keras 2.x constructors take both `lr` and
    # `decay`. Nadam is left out: its Keras 2.x constructor has historically
    # used `schedule_decay` instead of `decay` -- verify before adding it.
    factories = {'adam': Adam, 'rmsprop': RMSprop, 'sgd': SGD}
    try:
        factory = factories[name.lower()]
    except KeyError:
        raise ValueError('Unsupported optimizer {!r}; expected one of {}'.format(
            name, sorted(factories)))
    return factory(lr=lr, decay=decay)


def _build_gru_model(cell_size, n_cell, dropout, activation, kernel_init):
    """Build the (uncompiled) stacked-GRU classifier.

    Layout: Masking -> (n_cell - 1) sequence-returning GRU layers -> final GRU
    -> Dropout -> Dense(2) -> Activation. At least one GRU layer is always
    built, even when n_cell < 1.
    """
    model = Sequential()
    # Keras Masking skips timesteps whose features all equal 10.0 (presumably
    # the padding value produced by getReadyData -- TODO confirm). Input shape
    # is fixed at 20 timesteps x 2 features.
    model.add(Masking(10.0, input_shape=(20, 2)))
    for _ in range(n_cell - 1):
        model.add(GRU(cell_size, return_sequences=True, kernel_initializer=kernel_init))
    model.add(GRU(cell_size, kernel_initializer=kernel_init))
    model.add(Dropout(dropout))
    model.add(Dense(2))
    # NOTE(review): the caller's default 'sigmoid' is paired with
    # categorical_crossentropy; 'softmax' is the usual pairing -- confirm intent.
    model.add(Activation(activation))
    return model


def hand_model(cell_size, n_cell, epochs=10, dropout=0.5, activation='sigmoid', optimizer='adam', lr=0.03, decay=0.09, kernel_init='glorot_uniform'):
    """Train a stacked-GRU classifier on getReadyData() and persist the results.

    Args:
        cell_size: units per GRU layer.
        n_cell: number of stacked GRU layers (at least one is always built).
        epochs: maximum number of training epochs; EarlyStopping (patience 10
            on val_loss) may stop training earlier.
        dropout: dropout rate applied after the last GRU layer.
        activation: activation applied to the 2-unit Dense output.
        optimizer: 'adam', 'rmsprop' or 'sgd' (case-insensitive).
        lr: learning rate passed to the optimizer.
        decay: learning-rate decay passed to the optimizer.
        kernel_init: kernel initializer for every GRU layer.

    Side effects:
        - saves the best-val_loss model to ./hand_made_models/<hyperparams>.hdf5,
          creating the directory if needed;
        - saves the test-set evaluation of that model to final_eval.npy;
        - pickles the Keras history dict to history_cells<cell_size>_depth<n_cell>.p.

    Raises:
        ValueError: if *optimizer* is not a supported name.
    """
    import os

    x_train, x_val, x_test, y_train, y_val, y_test = getReadyData()

    saved_model_path = _checkpoint_path(cell_size, n_cell, dropout, activation,
                                        optimizer, lr, decay)
    # ModelCheckpoint does not create missing parent directories.
    model_dir = os.path.dirname(saved_model_path)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    model = _build_gru_model(cell_size, n_cell, dropout, activation, kernel_init)
    model.compile(optimizer=_build_optimizer(optimizer, lr, decay),
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy'])
    model.summary()

    history = model.fit(
        x_train, y_train,
        batch_size=1024,
        epochs=epochs,
        # Keras documents validation_data as an (x_val, y_val) tuple.
        validation_data=(x_val, y_val),
        callbacks=[
            ModelCheckpoint(saved_model_path, monitor='val_loss',
                            verbose=2, save_best_only=True),
            EarlyStopping(monitor='val_loss', patience=10),
        ],
    )

    # The final in-memory weights can be up to `patience` epochs past the best
    # val_loss epoch. Evaluate the checkpointed best model instead, so that
    # final_eval.npy describes the artifact that was actually saved.
    if os.path.exists(saved_model_path):
        model.load_weights(saved_model_path)
    np.save('final_eval.npy', model.evaluate(x_test, y_test))

    # pickle emits bytes, so the file must be opened in binary mode.
    with open('history_cells{}_depth{}.p'.format(cell_size, n_cell), mode='wb') as f:
        pickle.dump(history.history, f)
def main():
    """Run the final experiment: 5 stacked GRU layers of 64 units each.

    The epoch budget is set very high; training length is in practice
    governed by the EarlyStopping callback inside hand_model.
    """
    hand_model(cell_size=64, n_cell=5, epochs=10000)


if __name__ == '__main__':
    main()