Skip to content

Commit 8dcec2b

Browse files
committed
update
2 parents bf87266 + bb9dfe5 commit 8dcec2b

File tree

4 files changed

+10
-9
lines changed

4 files changed

+10
-9
lines changed

hls4ml/converters/keras/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
     elif len(input_shapes[0]) == 4:
         layer['n_filt'] = input_shapes[0][3]

+    layer['use_gamma'] = keras_layer['config']['scale']
+    layer['use_beta'] = keras_layer['config']['center']
+
     return layer, [shape for shape in input_shapes[0]]


hls4ml/converters/pytorch/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def parse_batchnorm_layer(pytorch_layer, layer_name, input_shapes, data_reader,

     # batchnorm para
     layer['epsilon'] = pytorch_layer.eps
+    layer['use_gamma'] = layer['use_beta'] = not pytorch_layer.affine

     in_size = 1
     for dim in input_shapes[0][1:]:

hls4ml/model/layers.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,8 @@ class BatchNormalization(Layer):
         WeightAttribute('bias'),
         TypeAttribute('scale'),
         TypeAttribute('bias'),
+        Attribute('use_gamma', value_type=bool, default=True),
+        Attribute('use_beta', value_type=bool, default=True),
     ]

     def initialize(self):
@@ -871,16 +873,11 @@ def initialize(self):
         dims = inp.dim_names
         self.add_output_variable(shape, dims)

-        gamma = self.model.get_weights_data(self.name, 'gamma')
-        beta = self.model.get_weights_data(self.name, 'beta')
+        gamma = self.model.get_weights_data(self.name, 'gamma') if self.get_attr('use_gamma') else 1
+        beta = self.model.get_weights_data(self.name, 'beta') if self.get_attr('use_beta') else 0
         mean = self.model.get_weights_data(self.name, 'moving_mean')
         var = self.model.get_weights_data(self.name, 'moving_variance')

-        if gamma is None:
-            gamma = np.ones(mean.shape)
-        if beta is None:
-            beta = np.zeros(mean.shape)
-
         scale = gamma / np.sqrt(var + self.get_attr('epsilon'))
         bias = beta - scale * mean


test/pytest/test_batchnorm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ def data():
 @pytest.fixture(scope='module')
 def model(request):
     model = Sequential()
-    model.add(BatchNormalization(input_shape=(in_shape,), center=request.param[0], scale=request.param[1]))
+    model.add(BatchNormalization(input_shape=(in_shape,), center=request.param, scale=request.param))
     model.compile()
     return model


 @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
 @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
-@pytest.mark.parametrize('model', ([True, True], [True, False], [False, True], [False, False]), indirect=True)
+@pytest.mark.parametrize('model', [True, False], indirect=True)
 def test_batchnorm(model, data, backend, io_type):

     default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>'

0 commit comments

Comments (0)