Commit dced28f

steltze, stzelepi, JanFSchulte, and vloncar authored
FIFO depth optimizer for Vitis backend (#1037)
* Init depthwise resource implementation for streaming interface
* Init fifo optimization file for vitis backend
* Register fifo opt flow in vitis backend
* Init changes in build_prj.tcl and modification files in vitis writer
* Fix vitis writer by adding project.tcl modifer
* Fix build_prj.tcl to synthesize with the large FIFOs
* Fix if statement in cosim tcl script
* Clean the optimizer file
* Implement the optmized depths parsing
* Implement setter for new depths
* Fix csv file name parsing
* Fix name parsing, deeply hardcoded for now
* Clean documentation and files
* Remove unused function
* Add documentation and runtime checks
* Add documentation
* Include extracting optimized depths
* Fix documentation
* Add function to override Vivado test bench
* Fix hls4ml docs
* Undo changes in sepconv stream
* Format code
* Run pre-commit
* Remove unused imports
* Run pre-commit
* Remove comment
* Fix typo and documentation
* Remove commented out code
* Init unit test
* Use proper model for unit test to profile fifos
* Fix json generator to include before and after depths
* Set up full test
* Set up exception tests
* Clean test
* Fix full test
* Clean test
* Run precommit
* Force the cosimulation to execute twice
* Skip tests
* Update documentation
* Fix conflict, use built-in os function
* Setup onnx pytest
* Rebase and fix optimizer after main branch changes
* Update documentation
* Run precommit
* Fix qonnx test by optimizing away the input quantization
* Run precommit
* Address review comments
* Fix c-test for loop
* Correct comment
* Streamlining some changes to better fit the codebase (but mostly cosmetic)

---------

Co-authored-by: stzelepi <stylianos.tzelepis@cern.ch>
Co-authored-by: Jan-Frederik Schulte <jschulte@cern.ch>
Co-authored-by: Vladimir Loncar <vloncar@users.noreply.github.com>
1 parent 5fdcb18 commit dced28f

8 files changed

+467
-20
lines changed


docs/advanced/fifo_depth.rst

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,29 @@ FIFO Buffer Depth Optimization
55
With the ``io_stream`` IO type, each layer is connected with the subsequent layer through first-in first-out (FIFO) buffers.
66
The implementation of the FIFO buffers contributes to the overall resource utilization of the design, impacting in particular the BRAM or LUT utilization.
77
Because neural networks can generally have complex architectures, it is hard to know a priori the correct depth of each FIFO buffer.
8-
By default ``hls4ml`` choses the most conservative possible depth for each FIFO buffer, which can result in a an unnecessary overutilization of resources.
8+
By default ``hls4ml`` chooses the most conservative possible depth for each FIFO buffer, which can result in an unnecessary over-utilization of resources.
99

10-
In order to reduce the impact on the resources used for FIFO buffer implementation, an optimization has been developed in `#509 <https://github.com/fastmachinelearning/hls4ml/pull/509>`_ that correctly sizes the depth of the FIFO buffers by analyzing the RTL cosimulation.
11-
We implemented this FIFO buffer resizing as a :py:class:`~hls4ml.backends.vivado.passes.fifo_depth_optimization` optimizer pass.
10+
In order to reduce the impact on the resources used for FIFO buffer implementation, an optimization flow has been developed that correctly sizes the depth
11+
of the FIFO buffers by analyzing the RTL co-simulation. This feature is currently available in the ``Vitis`` and ``Vivado`` backends.
12+
13+
In the ``Vivado`` backend, FIFO buffer resizing is implemented as a :py:class:`~hls4ml.backends.vivado.passes.fifo_depth_optimization` optimizer pass.
1214
Through RTL simulation with large FIFO buffers (by default set to a depth of 100,000), we estimate the maximum occupation of each FIFO.
1315
Once the maximum depth is determined, the optimizer pass sets the FIFO buffer depth to that value plus 1.
1416

15-
As an example, we show below how to use the optimizer pass, inspired by this `GitHub Gist <https://gist.github.com/nicologhielmetti/3a268be32755448920e9f7d5c78a76d8>`_.
16-
First, we can define a simple neural network in Keras
17+
Below we show an example of the use of the FIFO depth optimization. First, we can define a simple neural network in Keras:
1718

1819
.. code-block:: Python
1920
2021
from tensorflow.keras.layers import Dense
2122
from tensorflow.keras.models import Sequential
2223
2324
model = Sequential()
24-
model.add(Dense(64, input_shape=(16,), name='fc1', activation='relu')
25+
model.add(Dense(64, input_shape=(16,), name='fc1', activation='relu'))
2526
model.add(Dense(32, name='fc2', activation='relu'))
2627
model.add(Dense(32, name='fc3', activation='relu'))
27-
model.add(Dense(5, name='fc3', activation='softmax'))
28+
model.add(Dense(5, name='fc4', activation='softmax'))
2829
29-
Then, we can convert the model, including the flow
30+
Then, we can convert the model, including the flow:
3031

3132
.. code-block:: Python
3233
@@ -47,3 +48,17 @@ Then, we can convert the model, including the flow
4748
hls_model.build(reset=False, csim=True, synth=True, cosim=True)
4849
4950
For more details and results, see `H. Borras et al., "Open-source FPGA-ML codesign for the MLPerf Tiny Benchmark" (2022) <https://arxiv.org/abs/2206.11791>`_.
51+
52+
Similarly, the FIFO buffers can be optimized while using the ``Vitis`` backend with the following changes:
53+
54+
.. code-block:: Python
55+
56+
config['Flows'] = ['vitis:fifo_depth_optimization']
57+
hls4ml.model.optimizer.get_optimizer('vitis:fifo_depth_optimization').configure(profiling_fifo_depth=100_000)
58+
59+
hls_model = hls4ml.converters.convert_from_keras_model(model,
60+
io_type='io_stream',
61+
hls_config=config,
62+
output_dir='hls4mlprj_fifo_depth_opt',
63+
part='xc7z020clg400-1',
64+
backend='Vitis')
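
The sizing rule described above for the ``Vivado`` pass (profile with very deep FIFOs, then set each depth to the observed maximum occupancy plus 1) can be sketched in isolation. This is only an illustration: ``size_fifo_depths`` and the occupancy numbers are hypothetical and are not part of hls4ml.

```python
def size_fifo_depths(max_occupancy, margin=1):
    """Return a depth for each FIFO: its observed maximum occupancy plus a margin."""
    return {name: occupancy + margin for name, occupancy in max_occupancy.items()}


# Hypothetical occupancies, standing in for values observed while profiling
# with deep FIFOs; real values come from the RTL co-simulation.
observed = {'layer2_out': 1, 'layer4_out': 17, 'layer6_out': 32}
sized = size_fifo_depths(observed)
print(sized)
```

The margin is kept as a parameter only to make the "plus 1" visible; the documentation above does not describe it as configurable.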
Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
1+
import json
2+
import zipfile
3+
4+
from hls4ml.model.optimizer.optimizer import ConfigurableOptimizerPass, ModelOptimizerPass
5+
6+
7+
def initialize_large_fifos(model, profiling_fifo_depth):
8+
"""Set all FIFO depths equal to a large value so that they can be profiled.
9+
10+
Args:
11+
model (ModelGraph): The model to which FIFO depth optimization is applied.
12+
profiling_fifo_depth (int): A large non-negative integer, must be larger than the max expected depth of the FIFOs.
13+
14+
Returns:
15+
Dict[str, int]: A dictionary containing FIFO names as keys and their initial depths as values is returned for
16+
comparison with the optimized depths.
17+
"""
18+
19+
# filter all the output variables and keep only the internal FIFOs, excluding output objects that are not FIFOs and the
20+
# input and output FIFOs as they can't be profiled and are implementation dependent, i.e. AXI Stream, AXI Master or
21+
# connected to another IP
22+
vars_to_profile = {
23+
output_variable_name: output_variable
24+
for output_variable_name, output_variable in model.output_vars.items()
25+
if ('StreamVariable' in str(type(output_variable)))
26+
and output_variable != model.get_output_variables()[0]
27+
and output_variable != model.get_input_variables()[0]
28+
}
29+
30+
# initialize all the fifos to `profiling_fifo_depth` so that they will be automatically implemented in BRAMs and so
31+
# they will be profiled. Alternatively, "config_dataflow -override_user_fifo_depth profiling_fifo_depth" can be
32+
# used inside build_prj.tcl to override all FIFO depths with the specified value
33+
initial_fifo_depths = {}
34+
for output_variable in vars_to_profile.values():
35+
if output_variable.pragma:
36+
initial_fifo_depths[output_variable.name] = int(output_variable.pragma[1])
37+
output_variable.pragma = (output_variable.pragma[0], profiling_fifo_depth)
38+
return initial_fifo_depths
39+
40+
41+
def execute_cosim_to_profile_fifos(model):
42+
"""Execute a co-simulation with a test-bench that calls the top function to properly profile the max FIFO depths.
43+
Note that the top function needs to execute **at least twice**, so user-provided input must have at least two samples.
44+
45+
Args:
46+
model (ModelGraph): The model to which FIFO depth optimization is applied.
47+
"""
48+
model.write()
49+
50+
model.build(
51+
reset=False,
52+
csim=False,
53+
synth=True,
54+
cosim=True,
55+
validation=False,
56+
export=False,
57+
vsynth=False,
58+
fifo_opt=True,
59+
)
60+
61+
62+
def get_vitis_optimized_fifo_depths(model):
63+
"""Parse the files generated by the co-simulation to retrieve the optimized depths for the FIFOs.
64+
Attention, only the FIFOs between the layers are profiled!
65+
66+
Args:
67+
model (ModelGraph): The model to which FIFO depth optimization is applied.
68+
69+
Returns:
70+
Dict[str, int]: A dictionary that contains the FIFO names as keys and the optimized depths as values.
71+
"""
72+
# channel.zip is generated after the co-simulation and contains the chan_status*.csv files
73+
# in the chan_status*.csv files the max depth achieved during co-simulation can be found at the last (4th) line
74+
path_to_zip_file = (
75+
model.config.get_output_dir()
76+
+ '/'
77+
+ model.config.get_project_name()
78+
+ '_prj'
79+
+ '/solution1/.autopilot/db/channel_depth_info/'
80+
)
81+
82+
with zipfile.ZipFile(f'{path_to_zip_file}channel.zip', 'r') as zip_ref:
83+
zip_ref.extractall(path_to_zip_file)
84+
85+
# the channel_info.csv file contains the mapping of each fifo name (e.g. layer4_out_U) to the respective
86+
# chan_status*.csv file
87+
names_file_path = (
88+
model.config.get_output_dir()
89+
+ '/'
90+
+ model.config.get_project_name()
91+
+ '_prj'
92+
+ '/solution1/.autopilot/db/channel_info.csv'
93+
)
94+
95+
csv_fifo_depth_files = {}
96+
with open(names_file_path) as names_file:
97+
for line in names_file:
98+
layer_name = line.split(',')[1]
99+
csv_file_name = line.split(',')[3][:-1]
100+
csv_fifo_depth_files[layer_name] = csv_file_name
101+
102+
optimized_fifo_depths = {}
103+
for layer_name, file_name in csv_fifo_depth_files.items():
104+
with open(path_to_zip_file + file_name) as chan_status_file:
105+
lines = chan_status_file.readlines()
106+
optimized_fifo_depths[layer_name[:-2]] = int(
107+
lines[-1]
108+
) # remove "_U" from the layer name string and keep the last line of the file that contains the max depth
109+
110+
return optimized_fifo_depths
111+
112+
113+
def generate_depths_file(model, initial_fifo_depths, optimized_fifo_depths):
114+
"""Generate a json file with the names of the FIFOs, the initial depths set by hls4ml and their optimized depths,
115+
for post-processing. The json file is not used by the rest of the pipeline, it is only produced for the user.
116+
117+
Args:
118+
model (ModelGraph): The model to which FIFO depth optimization is applied.
119+
initial_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the initial
120+
depths as values.
121+
optimized_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the optimized
122+
depths as values.
123+
"""
124+
depths = {}
125+
for fifo_name in initial_fifo_depths.keys():
126+
depths[fifo_name] = {}
127+
depths[fifo_name]['initial'] = initial_fifo_depths[fifo_name]
128+
depths[fifo_name]['optimized'] = optimized_fifo_depths[fifo_name]
129+
130+
with open(model.config.get_output_dir() + '/fifo_depths.json', 'w') as f:
131+
json.dump(depths, f, indent=4)
132+
133+
134+
def set_optimized_fifo_depths(model, optimized_fifo_depths):
135+
"""Set the new optimized FIFO depths.
136+
137+
Args:
138+
model (ModelGraph): The model to which FIFO depth optimization is applied.
139+
optimized_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the optimized
140+
depths as values.
141+
"""
142+
143+
# iterate through the layer output FIFOs
144+
for output_variable in model.output_vars.values():
145+
if 'StreamVariable' in str(type(output_variable)):
146+
if output_variable.pragma:
147+
148+
if output_variable.name not in optimized_fifo_depths.keys():
149+
continue
150+
151+
filtered_depth = optimized_fifo_depths[output_variable.name]
152+
output_variable.pragma = (output_variable.pragma[0], filtered_depth)
153+
154+
155+
class FifoDepthOptimization(ConfigurableOptimizerPass, ModelOptimizerPass):
156+
def __init__(self):
157+
# use `profiling_fifo_depth = 0` to keep the default fifo depth
158+
# consider replacing 100_000 either with a very large value that exceeds any total BRAM storage space
159+
# or via vitis 2023.2 c-simulation
160+
self.profiling_fifo_depth = 100_000
161+
162+
def transform(self, model):
163+
"""Perform FIFO depth optimization between the FIFOs of all layers to reduce resource utilization as the
164+
initial FIFOs set by hls4ml might be larger than required. At the end of the optimization the FIFOs will
165+
have the largest depths achieved during co-simulation without causing any deadlocks between the layers
166+
(producer-consumer), thus no additional delays between the layers. In some cases, this optimization
167+
might lead to bigger FIFOs than initially set by the hls4ml tool in order to prevent deadlocks.
168+
169+
Args:
170+
model (ModelGraph): The model to which FIFO depth optimization is applied.
171+
172+
Raises:
173+
ValueError: If the FIFO depth for profiling provided by the user is not a positive integer.
174+
RuntimeError: If the IO type is not set to "io_stream".
175+
176+
Returns:
177+
bool: The execution state of the Optimizer Pass
178+
"""
179+
180+
if not isinstance(self.profiling_fifo_depth, int) or self.profiling_fifo_depth <= 0:
181+
raise ValueError('The FIFO depth for profiling (profiling_fifo_depth variable) must be a positive integer.')
182+
183+
# check axi-stream or io-stream
184+
if not (model.config.get_config_value('IOType') == 'io_stream'):
185+
raise RuntimeError('To use this optimization you have to set `IOType` field to `io_stream` in the HLS config.')
186+
187+
initial_fifo_depths = initialize_large_fifos(model, self.profiling_fifo_depth)
188+
execute_cosim_to_profile_fifos(model)
189+
optimized_fifo_depths = get_vitis_optimized_fifo_depths(model)
190+
generate_depths_file(model, initial_fifo_depths, optimized_fifo_depths)
191+
set_optimized_fifo_depths(model, optimized_fifo_depths)
192+
193+
print('FIFO optimization completed')
194+
195+
return False
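
The parsing step in `get_vitis_optimized_fifo_depths` can be tried without running a co-simulation. The sketch below mirrors only what the code above relies on: field 1 of each `channel_info.csv` line is the FIFO name, field 3 is the status file name, and the last line of each status file is the maximum depth. The helper `read_max_depths` is hypothetical, and the file contents are made up for the demo, so the real Vitis HLS files may carry different values in the other fields.

```python
import tempfile
from pathlib import Path


def read_max_depths(channel_info_csv, chan_status_dir):
    """Map FIFO names to the max depth stored on the last line of each status file."""
    depths = {}
    with open(channel_info_csv) as names_file:
        for line in names_file:
            fields = line.rstrip('\n').split(',')
            fifo_name, status_file = fields[1], fields[3]
            status_lines = (Path(chan_status_dir) / status_file).read_text().splitlines()
            # Drop the trailing "_U" only when it is actually present
            key = fifo_name[:-2] if fifo_name.endswith('_U') else fifo_name
            depths[key] = int(status_lines[-1])
    return depths


# Fabricated files: only the field positions and the "last line" convention
# are taken from the code above, everything else is placeholder data.
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    (tmp / 'channel_info.csv').write_text(
        '0,layer2_out_U,0,chan_status0.csv\n1,layer4_out_U,0,chan_status1.csv\n'
    )
    (tmp / 'chan_status0.csv').write_text('0\n0\n0\n17\n')
    (tmp / 'chan_status1.csv').write_text('0\n0\n0\n3\n')
    demo_depths = read_max_depths(tmp / 'channel_info.csv', tmp)

print(demo_depths)
```

Unlike the slicing with `[:-1]` and `[:-2]` used in the pass, this variant strips the newline and the suffix explicitly, which is one way to avoid truncating names when a line has no trailing newline.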

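`generate_depths_file` writes `fifo_depths.json` for the user only, so a small post-processing script is a natural companion. The following is a minimal sketch, assuming the layout produced above (`{name: {'initial': ..., 'optimized': ...}}`); `summarize_depths` and the sample values are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def summarize_depths(json_path):
    """Return (name, initial, optimized) rows, largest depth reduction first."""
    with open(json_path) as f:
        depths = json.load(f)
    rows = [(name, d['initial'], d['optimized']) for name, d in depths.items()]
    return sorted(rows, key=lambda row: row[1] - row[2], reverse=True)


# Made-up report; a real one is written to <output_dir>/fifo_depths.json
with tempfile.TemporaryDirectory() as tmp:
    report = Path(tmp) / 'fifo_depths.json'
    report.write_text(json.dumps({
        'layer2_out': {'initial': 64, 'optimized': 1},
        'layer4_out': {'initial': 32, 'optimized': 40},
    }))
    demo_rows = summarize_depths(report)

for name, initial, optimized in demo_rows:
    print(f'{name}: {initial} -> {optimized}')
```

The second sample entry grows rather than shrinks, matching the note in the `transform` docstring that the optimized FIFOs can end up larger than the initial ones.
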
hls4ml/backends/vitis/vitis_backend.py

Lines changed: 30 additions & 3 deletions
@@ -34,6 +34,13 @@ def _register_flows(self):
3434

3535
self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name)
3636

37+
# Register the fifo depth optimization flow which is different from the one for vivado
38+
fifo_depth_opt_passes = [
39+
'vitis:fifo_depth_optimization'
40+
] + writer_passes # After optimization, a new project will be written
41+
42+
register_flow('fifo_depth_optimization', fifo_depth_opt_passes, requires=['vitis:ip'], backend=self.name)
43+
3744
def create_initial_config(
3845
self,
3946
part='xcvu13p-flga2577-2-e',
@@ -76,7 +83,18 @@ def create_initial_config(
7683

7784
return config
7885

79-
def build(self, model, reset=False, csim=True, synth=True, cosim=False, validation=False, export=False, vsynth=False):
86+
def build(
87+
self,
88+
model,
89+
reset=False,
90+
csim=True,
91+
synth=True,
92+
cosim=False,
93+
validation=False,
94+
export=False,
95+
vsynth=False,
96+
fifo_opt=False,
97+
):
8098
if 'linux' in sys.platform:
8199
found = os.system('command -v vitis_hls > /dev/null')
82100
if found != 0:
@@ -87,8 +105,17 @@ def build(self, model, reset=False, csim=True, synth=True, cosim=False, validati
87105
os.system(
88106
(
89107
'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
90-
'validation={validation} export={export} vsynth={vsynth}"'
91-
).format(reset=reset, csim=csim, synth=synth, cosim=cosim, validation=validation, export=export, vsynth=vsynth)
108+
'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
109+
).format(
110+
reset=reset,
111+
csim=csim,
112+
synth=synth,
113+
cosim=cosim,
114+
validation=validation,
115+
export=export,
116+
vsynth=vsynth,
117+
fifo_opt=fifo_opt,
118+
)
92119
)
93120
os.chdir(curr_dir)
94121
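
To see what the extended `build` hands to `build_prj.tcl`, the command string can be assembled without invoking `vitis_hls`. This is a sketch only: `format_build_command` is a hypothetical helper, and its defaults (everything `False`) are chosen for the demo and differ from the defaults of `build` shown above.

```python
def format_build_command(**flags):
    """Assemble the vitis_hls command line from boolean flags (nothing is executed)."""
    order = ('reset', 'csim', 'synth', 'cosim', 'validation', 'export', 'vsynth', 'fifo_opt')
    options = ' '.join(f'{name}={flags.get(name, False)}' for name in order)
    return f'vitis_hls -f build_prj.tcl "{options}"'


# Same flag values that execute_cosim_to_profile_fifos passes to model.build
cmd = format_build_command(synth=True, cosim=True, fifo_opt=True)
print(cmd)
```

Python booleans are formatted as `True` and `False`, so the new `fifo_opt` option reaches the Tcl script in the same form as the existing options.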

hls4ml/templates/vivado/build_prj.tcl

Lines changed: 5 additions & 1 deletion
@@ -179,6 +179,7 @@ if {$opt(csim)} {
179179

180180
if {$opt(synth)} {
181181
puts "***** C/RTL SYNTHESIS *****"
182+
182183
set time_start [clock clicks -milliseconds]
183184
csynth_design
184185
set time_end [clock clicks -milliseconds]
@@ -195,7 +196,10 @@ if {$opt(cosim)} {
195196

196197
if {$opt(fifo_opt)} {
197198
puts "\[hls4ml\] - FIFO optimization started"
198-
add_vcd_instructions_tcl
199+
200+
if {[string equal "$backend" "vivado"] || [string equal $backend "vivadoaccelerator"]} {
201+
add_vcd_instructions_tcl
202+
}
199203
}
200204

201205
remove_recursive_log_wave

hls4ml/templates/vivado/myproject_test.cpp

Lines changed: 7 additions & 5 deletions
@@ -77,14 +77,16 @@ int main(int argc, char **argv) {
7777
fpr.close();
7878
} else {
7979
std::cout << "INFO: Unable to open input/predictions file, using default input." << std::endl;
80+
const unsigned NUM_TEST_SAMPLES = 5;
81+
for (unsigned i = 0; i < NUM_TEST_SAMPLES; i++) {
82+
// hls-fpga-machine-learning insert zero
8083

81-
// hls-fpga-machine-learning insert zero
82-
83-
// hls-fpga-machine-learning insert top-level-function
84+
// hls-fpga-machine-learning insert top-level-function
8485

85-
// hls-fpga-machine-learning insert output
86+
// hls-fpga-machine-learning insert output
8687

87-
// hls-fpga-machine-learning insert tb-output
88+
// hls-fpga-machine-learning insert tb-output
89+
}
8890
}
8991

9092
fout.close();

hls4ml/writer/vitis_writer.py

Lines changed: 25 additions & 0 deletions
@@ -1,5 +1,6 @@
11
import glob
22
import os
3+
from pathlib import Path
34
from shutil import copy
45

56
from hls4ml.writer.vivado_writer import VivadoWriter
@@ -24,10 +25,34 @@ def write_nnet_utils_overrides(self, model):
2425
for h in headers:
2526
copy(srcpath + h, dstpath + h)
2627

28+
def write_board_script_override(self, model):
29+
'''
30+
Override the backend name and clock uncertainty in the generated project.tcl for Vitis
31+
'''
32+
33+
###################
34+
# project.tcl
35+
###################
36+
37+
prj_tcl_file = Path(f'{model.config.get_output_dir()}/project.tcl')
38+
with open(prj_tcl_file) as f:
39+
prj_tcl_contents = f.readlines()
40+
for line_num, line in enumerate(prj_tcl_contents):
41+
if 'set backend' in line:
42+
prj_tcl_contents[line_num] = 'set backend "vitis"\n'
43+
if 'set clock_uncertainty' in line:
44+
prj_tcl_contents[line_num] = 'set clock_uncertainty {}\n'.format(
45+
model.config.get_config_value('ClockUncertainty', '27%')
46+
)
47+
48+
with open(prj_tcl_file, 'w') as f:
49+
f.writelines(prj_tcl_contents)
50+
2751
def write_hls(self, model):
2852
"""
2953
Write the HLS project. Calls the steps from VivadoWriter, adapted for Vitis
3054
"""
3155
super().write_hls(model)
3256
self.write_nnet_utils_overrides(model)
57+
self.write_board_script_override(model)
3358
self.write_tar(model)
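
The line-override approach used in `write_board_script_override` can be exercised on a throwaway file. The sketch below is an assumption-laden stand-in: `override_tcl_settings` is a hypothetical free function, and the sample `project.tcl` contents are invented for the demo rather than copied from the hls4ml template.

```python
import tempfile
from pathlib import Path


def override_tcl_settings(tcl_path, backend, clock_uncertainty):
    """Rewrite the 'set backend' and 'set clock_uncertainty' lines of a Tcl file in place."""
    tcl_path = Path(tcl_path)
    lines = tcl_path.read_text().splitlines(keepends=True)
    for line_num, line in enumerate(lines):
        if 'set backend' in line:
            lines[line_num] = f'set backend "{backend}"\n'
        if 'set clock_uncertainty' in line:
            lines[line_num] = f'set clock_uncertainty {clock_uncertainty}\n'
    tcl_path.write_text(''.join(lines))


with tempfile.TemporaryDirectory() as tmp:
    prj_tcl = Path(tmp) / 'project.tcl'
    # Invented starting contents, for illustration only
    prj_tcl.write_text('set backend "vivado"\nset clock_uncertainty 12.5%\nset part "xc7z020clg400-1"\n')
    override_tcl_settings(prj_tcl, 'vitis', '27%')
    result = prj_tcl.read_text()

print(result)
```

Lines that match neither pattern are written back unchanged, which is why the `set part` line survives the rewrite.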
