Skip to content

Commit 6a27a54

Browse files
authored
Merge pull request #754 from fastmachinelearning/batchnorm_fix
Fix for BatchNormalization layers with `center=False` or `scale=False`
2 parents a031b6a + cadd0fa commit 6a27a54

File tree

4 files changed

+18
-7
lines changed

4 files changed

+18
-7
lines changed

hls4ml/converters/keras/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
     elif len(input_shapes[0]) == 4:
         layer['n_filt'] = input_shapes[0][3]

+    layer['use_gamma'] = keras_layer['config']['scale']
+    layer['use_beta'] = keras_layer['config']['center']
+
     return layer, [shape for shape in input_shapes[0]]

hls4ml/converters/pytorch/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def parse_batchnorm_layer(pytorch_layer, layer_name, input_shapes, data_reader,

     # batchnorm para
     layer['epsilon'] = pytorch_layer.eps
+    layer['use_gamma'] = layer['use_beta'] = pytorch_layer.affine

     in_size = 1
     for dim in input_shapes[0][1:]:

hls4ml/model/layers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,9 @@ def _get_folded_weights(self):

         # wrap conv kernel and bias with bn parameters
         folded_kernel = inv * kernel
-        folded_bias = inv * (bias - moving_mean) + beta
+        folded_bias = inv * (bias - moving_mean)
+        if beta is not None:
+            folded_bias += beta

         return [folded_kernel, folded_bias]

@@ -832,6 +834,8 @@ class BatchNormalization(Layer):
         WeightAttribute('bias'),
         TypeAttribute('scale'),
         TypeAttribute('bias'),
+        Attribute('use_gamma', value_type=bool, default=True),
+        Attribute('use_beta', value_type=bool, default=True),
     ]

837841
def initialize(self):
@@ -840,13 +844,13 @@ def initialize(self):
         dims = inp.dim_names
         self.add_output_variable(shape, dims)

-        gamma = self.model.get_weights_data(self.name, 'gamma')
-        beta = self.model.get_weights_data(self.name, 'beta')
+        gamma = self.model.get_weights_data(self.name, 'gamma') if self.get_attr('use_gamma') else 1
+        beta = self.model.get_weights_data(self.name, 'beta') if self.get_attr('use_beta') else 0
         mean = self.model.get_weights_data(self.name, 'moving_mean')
         var = self.model.get_weights_data(self.name, 'moving_variance')

         scale = gamma / np.sqrt(var + self.get_attr('epsilon'))
-        bias = beta - gamma * mean / np.sqrt(var + self.get_attr('epsilon'))
+        bias = beta - scale * mean

         self.add_weights_variable(name='scale', var_name='s{index}', data=scale)
         self.add_weights_variable(name='bias', var_name='b{index}', data=bias)

test/pytest/test_batchnorm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,24 @@ def data():


 @pytest.fixture(scope='module')
-def model():
+def model(request):
     model = Sequential()
-    model.add(BatchNormalization(input_shape=(in_shape,)))
+    model.add(BatchNormalization(input_shape=(in_shape,), center=request.param, scale=request.param))
     model.compile()
     return model


 @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
 @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
+@pytest.mark.parametrize('model', [True, False], indirect=True)
 def test_batchnorm(model, data, backend, io_type):

     default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>'

+    center = model.layers[0].center
+    scale = model.layers[0].scale
     config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision, granularity='name')
-    output_dir = str(test_root_path / f'hls4mlprj_batchnorm_{backend}_{io_type}')
+    output_dir = str(test_root_path / f'hls4mlprj_batchnorm_{backend}_{io_type}_center{center}_scale{scale}')
     hls_model = hls4ml.converters.convert_from_keras_model(
         model, backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir
     )

0 commit comments

Comments (0)