Skip to content

Commit 27f935d

Browse files
authored
Merge pull request #887 from calad0i/softmax_fix
Fix bit overflow with softmax
2 parents 0d48aff + 3db1e83 commit 27f935d

File tree

4 files changed

+97
-35
lines changed

4 files changed

+97
-35
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import warnings
2+
3+
from hls4ml.model.layers import Layer, Softmax
4+
from hls4ml.model.optimizer import OptimizerPass
5+
6+
7+
class FixSoftmaxTableSize(OptimizerPass):
    """Optimizer pass that shrinks a Softmax lookup table that is too large.

    The table size is capped at ``2 ** min(input_bw, table_bw)``, where
    ``input_bw`` is the width of the input layer's ``result_t`` precision and
    ``table_bw`` is the width of the Softmax layer's ``inv_table_t`` precision.
    On the Quartus backend the exponent is reduced by one more bit.
    """

    def match(self, node):
        """Return True when ``node`` is a ``Softmax`` layer."""
        return isinstance(node, Softmax)

    def transform(self, model, node: Layer):
        """Clamp the ``table_size`` attribute of ``node`` and warn when it changes.

        Args:
            model: The model being optimized; only ``model.config.config['Backend']``
                is read here.
            node: The matched Softmax layer.

        Returns:
            Always ``False``. NOTE(review): presumably this tells the optimizer
            that the graph structure was not modified - confirm against the
            ``OptimizerPass`` contract.

        Raises:
            RuntimeError: If the input node of ``node`` is not a ``Layer``.
        """
        inp_layer = node.get_input_node()  # type: ignore
        if not isinstance(inp_layer, Layer):
            raise RuntimeError(f'Softmax layer {node.name} does not have an input layer')

        input_bw: int = inp_layer.get_attr('result_t').precision.width  # type: ignore
        table_bw: int = node.get_attr('inv_table_t').precision.width  # type: ignore
        table_size = int(node.get_attr('table_size'))  # type: ignore

        backend = model.config.config['Backend']

        # The Quartus (Intel) backend needs one spare bit for the table index.
        # The original author observed that simulation crashes with a
        # segmentation fault otherwise; the root cause is not known.
        backend_limitation = -1 if backend == 'Quartus' else 0

        # Largest table that can be indexed without overflowing the bits available.
        max_table_size = 2 ** (min(input_bw, table_bw) + backend_limitation)

        if max_table_size < table_size:
            # The table is too large w.r.t. the input and table bitwidths: reduce it
            # to avoid undefined behavior when cutting indices from a fixed-point
            # number. The attribute is stored as a string, as in the original code.
            node.set_attr('table_size', str(max_table_size))
            if 2**input_bw < table_size:
                warnings.warn(
                    (
                        f"Softmax layer {node.name} table size is too large for input "
                        f"bitwidth {input_bw}. Setting table size to {max_table_size}. "
                        "To avoid this warning, please increase input bitwidth or "
                        "decrease table size."
                    ),
                    stacklevel=1,
                )
            if 2**table_bw < table_size:
                warnings.warn(
                    (
                        f"Softmax layer {node.name} table size is too large for table "
                        f"bitwidth {table_bw}. Setting table size to {max_table_size}. "
                        "To avoid this warning, please increase table bitwidth or "
                        "decrease table size."
                    ),
                    stacklevel=1,
                )
            if backend == 'Quartus':
                warnings.warn(
                    (
                        "Quartus backend's table size is limited to 2^min(input_bw-1,table_bw-1)"
                        " instead of 2^min(input_bw,table_bw)."
                    ),
                    stacklevel=1,
                )
        return False
62+
63+
64+
def register_softmax__table_size_fix(backend):
    """Register ``FixSoftmaxTableSize`` on ``backend`` as 'fix_softmax_table_size'."""
    pass_name = 'fix_softmax_table_size'
    backend.register_pass(pass_name, FixSoftmaxTableSize)

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def _register_flows(self):
7272
'quartus:inplace_parallel_reshape',
7373
'quartus:inplace_stream_flatten',
7474
'quartus:skip_softmax',
75+
'quartus:fix_softmax_table_size',
7576
]
7677
optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)
7778

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def _register_flows(self):
108108
'vivado:inplace_parallel_reshape',
109109
'vivado:inplace_stream_flatten',
110110
'vivado:skip_softmax',
111+
'vivado:fix_softmax_table_size',
111112
]
112113
optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)
113114

test/pytest/test_softmax.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,50 +10,46 @@
1010
test_root_path = Path(__file__).parent
1111

1212

13-
def flat_distribution(shape):
14-
return np.random.rand(*shape)
15-
16-
17-
def high_accuracy_distribution(shape):
18-
'''Start with a flat distribution, then pick a random member of each row to amplify'''
19-
x = np.random.rand(*shape)
20-
imax = np.random.randint(0, shape[1], size=shape[0])
21-
x[:, imax] *= 10
22-
return x
23-
24-
2513
@pytest.fixture()
26-
def generate_data(function, input_shape):
27-
return function((1000, *input_shape))
14+
def generate_data(input_shape):
    """Generate 5000 random softmax inputs of shape ``input_shape``.

    Values are drawn from a normal distribution (mean 0, std 2). Roughly 5% of
    the entries are then scaled by 5 and shifted by 10 to produce outliers, and
    the result is clipped to the range [-32, 31].
    """
    shape = (5000, *input_shape)
    d = np.random.normal(0, 2, shape)
    # np.random.randint(0, 1, shape) has an exclusive upper bound and always
    # returns 0, which selected every entry. Draw uniform floats instead so
    # that about 5% of the entries are selected.
    modify_entries = np.random.uniform(0, 1, shape) < 0.05
    d[modify_entries] = d[modify_entries] * 5 + 10
    return np.clip(d, -32, 31)
2820

2921

3022
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
31-
@pytest.mark.parametrize('strategy', ['stable', 'argmax'])
23+
@pytest.mark.parametrize('strategy', ['stable', 'latency', 'argmax'])
3224
@pytest.mark.parametrize(
33-
'function,input_shape,io_type',
25+
'input_bits,input_shape,table_bits,io_type',
3426
[
35-
(flat_distribution, (8,), 'io_parallel'),
36-
(high_accuracy_distribution, (8,), 'io_parallel'),
37-
(flat_distribution, (8,), 'io_stream'),
38-
(high_accuracy_distribution, (8,), 'io_stream'),
39-
(flat_distribution, (8, 8, 3), 'io_stream'),
40-
(high_accuracy_distribution, (8, 8, 3), 'io_stream'),
27+
('16,6', (8,), '18,8', 'io_parallel'),
28+
('16,6', (8,), '18,8', 'io_stream'),
29+
('16,6', (8,), '9,6', 'io_parallel'),
30+
('16,6', (8,), '9,6', 'io_stream'),
31+
('9,6', (8,), '18,8', 'io_parallel'),
32+
('9,6', (8,), '18,8', 'io_stream'),
33+
('16,6', (8, 8, 3), '18,8', 'io_stream'),
4134
],
4235
)
43-
def test_softmax(backend, strategy, generate_data, input_shape, io_type, function):
36+
def test_softmax(backend, strategy, generate_data, input_bits, input_shape, table_bits, io_type):
4437
X = generate_data
4538
model = tf.keras.models.Sequential()
4639
model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softmax', name='softmax'))
4740
model.compile()
4841

49-
f_type = 'ac_fixed<18,8,true,AC_RND,AC_SAT>' if backend == 'Quartus' else 'ap_fixed<18,8,AP_RND,AP_SAT>'
42+
table_type = f'fixed<{table_bits}, RND, SAT>'
43+
5044
cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
5145
cfg['LayerName']['softmax']['Strategy'] = strategy
52-
cfg['LayerName']['softmax']['inv_table_t'] = f_type
53-
cfg['LayerName']['softmax']['exp_table_t'] = f_type
46+
cfg['LayerName']['softmax']['inv_table_t'] = table_type
47+
cfg['LayerName']['softmax']['exp_table_t'] = table_type
48+
cfg['LayerName']['softmax_input']['Precision']['result'] = f'fixed<{input_bits}>'
5449

55-
odir = str(test_root_path / 'hls4mlprj_softmax_{}_{}_{}_{}_{}').format(
56-
backend, io_type, strategy, function.__name__, str(input_shape)
50+
odir = str(
51+
test_root_path
52+
/ f'hls4mlprj_softmax_{backend}_{io_type}_{strategy}_{input_shape}_input-bits={input_bits}_table-bits={table_bits}'
5753
)
5854
hls_model = hls4ml.converters.convert_from_keras_model(
5955
model, hls_config=cfg, io_type=io_type, output_dir=odir, backend=backend
@@ -73,9 +69,9 @@ def test_softmax(backend, strategy, generate_data, input_shape, io_type, functio
7369
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
7470
def test_softmax_skipped(backend, io_type):
7571
X = np.random.rand(100, 10)
76-
model = tf.keras.models.Sequential()
77-
model.add(tf.keras.layers.Dense(14, input_shape=(10,), name='dense'))
78-
model.add(tf.keras.layers.Activation(activation='softmax', name='softmax'))
72+
dense = tf.keras.layers.Dense(14, input_shape=(10,), name='dense')
73+
softmax = tf.keras.layers.Activation(activation='softmax', name='softmax')
74+
model = tf.keras.models.Sequential([dense, softmax])
7975
model.compile()
8076

8177
cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
@@ -92,7 +88,6 @@ def test_softmax_skipped(backend, io_type):
9288
assert len(hls_layers) == 2
9389

9490
# Verify hls4ml output is equal to Dense output
95-
y_keras = model.predict(X)
96-
y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)
97-
keras_trace = hls4ml.model.profiling.get_ymodel_keras(model, X)
98-
np.testing.assert_allclose(y_hls4ml, keras_trace['dense'], rtol=0, atol=2e-2)
91+
y_keras_dense = dense(X).numpy() # type: ignore
92+
y_hls4ml = hls_model.predict(X).reshape(y_keras_dense.shape) # type: ignore
93+
np.testing.assert_allclose(y_hls4ml, y_keras_dense, rtol=0, atol=2e-2)

0 commit comments

Comments
 (0)