Skip to content

Commit 69ccf26

Browse files
committed
Set strategy for pointwise conv
1 parent a3cca72 commit 69ccf26

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

hls4ml/backends/vivado/passes/pointwise.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ def transform(self, model, node):
8282
expand_axis = tuple(range(int(dim[0])))
8383
pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
8484
pw_node.weights['bias'].data = node.weights['bias'].data
85+
# Set strategy to ensure lowercase string is passed to the template
86+
if model.config.is_resource_strategy(pw_node):
87+
pw_node.set_attr('strategy', 'resource')
88+
else:
89+
pw_node.set_attr('strategy', 'latency')
8590
model.replace_node(node, pw_node)
8691

8792
return True

test/pytest/test_pointwiseconv.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,30 @@ def test_pointwiseconv2d(chans, padds, strides, backend, io_type, strategy):
127127

128128
assert 'Pointwise' in list(hls_model.graph.values())[1].class_name
129129
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=0.001)
130+
131+
132+
@pytest.mark.parametrize('strategy', ['Latency', 'Resource'])
133+
def test_pointwise_config(strategy):
134+
model = tf.keras.models.Sequential()
135+
input_shape = (8, 8, 3)
136+
model.add(
137+
Conv2D(
138+
filters=8,
139+
kernel_size=(1, 1),
140+
input_shape=input_shape,
141+
kernel_initializer='normal',
142+
use_bias=False,
143+
name='conv2d_1x1',
144+
)
145+
)
146+
147+
model.compile(optimizer='adam', loss='mse')
148+
149+
config = hls4ml.utils.config_from_keras_model(model, granularity='name')
150+
config['Model']['Strategy'] = strategy
151+
config['LayerName']['conv2d_1x1']['Strategy'] = strategy # Will fail if the strategy is not lowercase
152+
output_dir = str(test_root_path / f'hls4mlprj_pointwise2d_config_{strategy}')
153+
154+
hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir)
155+
# Model will fail to compile if strategy was set incorrectly
156+
hls_model.compile()

0 commit comments

Comments
 (0)