-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodelTrainer.py
More file actions
executable file
·57 lines (51 loc) · 2.73 KB
/
Copy pathmodelTrainer.py
File metadata and controls
executable file
·57 lines (51 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# coding: utf-8
import os
import contextlib
class ModelTrainer:
    """Train a Keras-style model in rounds, evaluating after every round.

    Each round calls ``fit_generator`` for ``numberOfStepsBeforeTesting`` steps
    (a single epoch) and then ``evaluate_generator`` on the test generator.
    Whenever the test loss is lower than the best loss seen so far in the
    current ``train()`` call, the model is saved via ``storeModel()``.
    Rounds repeat until ``stopTraining()`` is called; the stop flag is only
    checked between rounds, so the current round always finishes.
    """

    # Class-level defaults so every attribute exists (as None) even before the
    # corresponding setter has been called.
    __model = None
    __trainGenerator = None
    __testGenerator = None  # child of keras.utils.Sequence
    __totalNumberOfStepsMade = None
    __numberOfStepsBeforeTesting = None
    __continueTraining = None
    __filenameToStoreBestModel = None
    __trainFiniteGeneratorForTestModel = None

    def __init__(self, numberOfStepsBeforeTesting, filenameToStoreBestModel=None):
        """Create a trainer.

        Args:
            numberOfStepsBeforeTesting: training steps per round, passed to
                ``fit_generator`` as ``steps_per_epoch``.
            filenameToStoreBestModel: path handed to ``model.save`` whenever
                the test loss improves; ``None`` disables saving.
        """
        self.__totalNumberOfStepsMade = 0
        self.__numberOfStepsBeforeTesting = numberOfStepsBeforeTesting
        self.__filenameToStoreBestModel = filenameToStoreBestModel

    def setGenerators(self, trainGen, testGen):
        """Set the generator used for training and the one used for testing."""
        self.__trainGenerator = trainGen
        self.__testGenerator = testGen

    def setTrainFiniteGeneratorForTestModel(self, trainFiniteGeneratorForTestModel):
        """Set an optional finite generator over training data.

        When set (not ``None``), it is evaluated each round before the test
        generator and its loss is printed as a sanity check.
        """
        self.__trainFiniteGeneratorForTestModel = trainFiniteGeneratorForTestModel

    def getModel(self):
        """Return the model being trained (``None`` if none has been set)."""
        return self.__model

    def setModel(self, model):
        """Set the model to train; it must provide ``fit_generator``,
        ``evaluate_generator`` and ``save``."""
        self.__model = model

    def storeModel(self):
        """Save the model to the configured filename, if one was given."""
        if self.__filenameToStoreBestModel is not None:
            self.__model.save(self.__filenameToStoreBestModel)

    def __runTrainingRound(self, showOutput):
        """Run one epoch of ``numberOfStepsBeforeTesting`` training steps.

        If *showOutput* is falsy, ``sys.stdout`` is redirected to
        ``os.devnull`` for the duration of the fit. Only Python-level stdout
        of this process is redirected; output written by native code or by
        worker processes may still appear.
        """
        with contextlib.ExitStack() as stack:
            if not showOutput:
                devnull = stack.enter_context(open(os.devnull, "w"))
                stack.enter_context(contextlib.redirect_stdout(devnull))
            self.__model.fit_generator(
                self.__trainGenerator,
                steps_per_epoch=self.__numberOfStepsBeforeTesting,
                epochs=1,
                use_multiprocessing=True,
            )

    def train(self, showModelOutputDuringTrainingSteps=False):
        """Run training rounds until ``stopTraining()`` is called.

        Args:
            showModelOutputDuringTrainingSteps: when falsy (the default),
                anything printed during ``fit_generator`` is discarded.
        """
        # Reset at the start: a stopTraining() issued before train() is ignored.
        self.__continueTraining = True
        self.__totalNumberOfStepsMade = 0
        minimumLoss = None  # best test loss seen during this train() call
        print("Number of steps before testing step: " + str(self.__numberOfStepsBeforeTesting))
        while self.__continueTraining:
            self.__runTrainingRound(showModelOutputDuringTrainingSteps)
            self.__totalNumberOfStepsMade += self.__numberOfStepsBeforeTesting
            print("Number of steps made: " + str(self.__totalNumberOfStepsMade))
            if self.__trainFiniteGeneratorForTestModel is not None:
                # Sanity check: can the model predict correctly at least on train data?
                loss = self.__model.evaluate_generator(self.__trainFiniteGeneratorForTestModel, use_multiprocessing=True)
                print("Current loss on train data: " + str(loss))
            loss = self.__model.evaluate_generator(self.__testGenerator, use_multiprocessing=True)
            print("Current loss: " + str(loss))
            # `is None` rather than `== None`: loss objects may overload ==.
            if minimumLoss is None or loss < minimumLoss:
                self.storeModel()
                minimumLoss = loss

    def stopTraining(self):
        """Request that ``train()`` stop after the current round finishes."""
        self.__continueTraining = False