Skip to content

Commit 1894819

Browse files
committed
Alpha Relase
1 parent ea5af2a commit 1894819

File tree

2 files changed

+15
-23
lines changed

2 files changed

+15
-23
lines changed

run/run_CNN.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import theano.tensor as T
2424
from theano.tensor.shared_randomstreams import RandomStreams
2525

26-
from utils.load_conf import load_model,load_conv_spec,load_mlp_spec,load_data_spec
26+
from utils.load_conf import load_model,load_conv_spec,load_data_spec
2727
from io_modules.file_reader import read_dataset
2828
from utils.learn_rates import LearningRate
2929
from utils.utils import parse_activation
@@ -43,8 +43,10 @@ def runCNN(arg):
4343
else :
4444
model_config = load_model(arg,'CNN')
4545

46-
conv_config,conv_layer_config,mlp_config = load_conv_spec(model_config['nnet_spec'],model_config['batch_size'],
47-
model_config['input_shape'])
46+
conv_config,conv_layer_config,mlp_config = load_conv_spec(
47+
model_config['nnet_spec'],
48+
model_config['batch_size'],
49+
model_config['input_shape'])
4850

4951
data_spec = load_data_spec(model_config['data_spec']);
5052

@@ -59,6 +61,7 @@ def runCNN(arg):
5961
createDir(model_config['wdir']);
6062
#create working dir
6163

64+
batch_size = model_config['batch_size'];
6265
cnn = CNN(numpy_rng,theano_rng,conv_layer_configs = conv_layer_config, batch_size = batch_size,
6366
n_outs=model_config['n_outs'],hidden_layers_sizes=mlp_config['layers'],
6467
conv_activation = conv_activation,hidden_activation = hidden_activation,
@@ -68,7 +71,6 @@ def runCNN(arg):
6871

6972
#learning rate, batch-size and momentum
7073
lrate = LearningRate.get_instance(model_config['l_rate_method'],model_config['l_rate']);
71-
batch_size = model_config['batch_size'];
7274
momentum = model_config['momentum']
7375

7476
train_sets, train_xy, train_x, train_y = read_dataset(data_spec['training'],model_config['batch_size'])

utils/load_conf.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,15 @@ def load_model(input_file,nnetType=None):
2424
logger.critical(" 'nnetType' is missing in model properties file..")
2525
exit(1)
2626

27-
if checkConfig(data,nnetType):
27+
requiredKeys = ['data_spec','wdir','processes','nnet_spec','output_file','n_outs']
28+
if not isKeysPresents(data,requiredKeys):
29+
logger.critical(" the mandatory arguments are missing in model properties file..")
30+
exit(1)
31+
32+
if data.has_key('n_ins') or data.has_key('input_shape'):
33+
pass
34+
else:
35+
logger.error('Neither n_ins nor input_shape is present')
2836
logger.critical(" the mandatory arguments are missing in model properties file..")
2937
exit(1)
3038

@@ -69,19 +77,6 @@ def correctPath(data,keys,basePath):
6977
data[key] = makeAbsolute(data[key],basePath)
7078
return data
7179

72-
def checkConfig(data,nnetType):
73-
requiredKeys = [
74-
'data_spec','wdir','processes',
75-
'nnet_spec','output_file','n_outs'
76-
]
77-
if isKeysPresents(data,requiredKeys):
78-
return False
79-
if data.has_key('n_ins') or data.has_key('input_shape'):
80-
return True
81-
else:
82-
logger.error('Neither n_ins nor input_shape is present')
83-
return False
84-
8580
def isKeysPresents(data,requiredKeys):
8681
for key in requiredKeys:
8782
if not data.has_key(key):
@@ -114,11 +109,6 @@ def load_data_spec(input_file):
114109
return data
115110

116111

117-
def load_mlp_spec(input_file):
118-
logger.info("Loading mlp properties from %s ...",input_file)
119-
return load_json(input_file);
120-
121-
122112
#############################################################################
123113
#CNN
124114
#############################################################################

0 commit comments

Comments
 (0)