Skip to content

Commit bb9dfe5

Browse files
committed
Add gamma and beta attrs for batchnormalization. Account for no scale and offset
1 parent 8523b2b commit bb9dfe5

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

hls4ml/converters/keras/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
129129
elif len(input_shapes[0]) == 4:
130130
layer['n_filt'] = input_shapes[0][3]
131131

132+
layer['use_gamma'] = False if keras_layer['config']['scale'] == False else True
133+
layer['use_beta'] = False if keras_layer['config']['center'] == False else True
134+
132135
return layer, [shape for shape in input_shapes[0]]
133136

134137

hls4ml/converters/pytorch/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ def parse_batchnorm_layer(pytorch_layer, layer_name, input_shapes, data_reader,
6161
#batchnorm para
6262
layer['epsilon'] = pytorch_layer.eps
6363

64+
if pytorch_layer.affine is False:
65+
layer['use_gamma'], layer['use_beta'] = False, False
66+
else:
67+
layer['use_gamma'], layer['use_beta'] = True, True
68+
6469
in_size = 1
6570
for dim in input_shapes[0][1:]:
6671
in_size *= dim

hls4ml/model/layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -800,8 +800,8 @@ def initialize(self):
800800
dims = inp.dim_names
801801
self.add_output_variable(shape, dims)
802802

803-
gamma = self.model.get_weights_data(self.name, 'gamma')
804-
beta = self.model.get_weights_data(self.name, 'beta')
803+
gamma = 1 if self.attributes['use_gamma'] is False else self.model.get_weights_data(self.name, 'gamma')
804+
beta = 0 if self.attributes['use_beta'] is False else self.model.get_weights_data(self.name, 'beta')
805805
mean = self.model.get_weights_data(self.name, 'moving_mean')
806806
var = self.model.get_weights_data(self.name, 'moving_variance')
807807

0 commit comments

Comments
 (0)