Commit 9e3fc8d

Merge branch 'main' into split_pointwise_conv_by_rf_codegen

2 parents 6d84b80 + e778ed3

File tree

14 files changed: +215 −94 lines

hls4ml/backends/catapult/catapult_backend.py

Lines changed: 1 addition & 0 deletions

@@ -88,6 +88,7 @@ def _register_flows(self):
         init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
 
         streaming_passes = [
+            'catapult:inplace_stream_flatten',  # Inform downstream of the changed packsize in case of a skipped flatten
             'catapult:reshape_stream',
             'catapult:clone_output',
             'catapult:insert_zero_padding_before_conv1d',

hls4ml/backends/fpga/passes/clone.py

Lines changed: 49 additions & 30 deletions

@@ -1,4 +1,4 @@
-import numpy as np
+from math import prod
 
 from hls4ml.backends.template import FunctionCallTemplate
 from hls4ml.model.layers import Layer, register_layer
@@ -54,41 +54,60 @@ def match(self, node):
         if isinstance(node, Clone):
             return False
 
-        return True
+        # Not needed for io_parallel
+        io_type = node.model.config.get_config_value('IOType')
+        if io_type != 'io_stream':
+            return False
+
+        # Check if the output is used more than once
+        output_map = node.get_output_use_map()
+        in_output = node.name in node.model.outputs
+        for output in node.outputs:
+            if len(output_map[output]) + in_output > 1:
+                # A model output also needs a stream
+                return True
+
+        return False
 
     def transform(self, model, node):
-        if model.config.get_config_value('IOType') != 'io_stream':
-            return False
 
         output_map = node.get_output_use_map()
+        in_output = node.name in node.model.outputs
 
         transformed = False
         for output in node.outputs:
-            if len(output_map[output]) > 1:
-                if len(output_map[output]) > 3:
-                    print(
-                        'WARNING: Cloning output {} of {} ({}) more than 3 times not currently supported'.format(
-                            output, node.__class__.__name__, node.name
-                        )
-                    )
-                    return False
-                out_var = node.get_output_variable(output)
-                for i, layer in enumerate(output_map[output], 1):
-                    attrs = {'size': np.prod(out_var.shape)}
-                    idx = layer.inputs.index(output)
-                    layer.inputs[idx] = output + '_cpy' + str(i)
-
-                clone_layer: Clone = model.make_node(
-                    Clone,
-                    'clone_' + node.name,
-                    attrs,
-                    [output],
-                    [output + '_cpy' + str(i + 1) for i in range(len(output_map[output]))],
-                )
-                for i in range(len(output_map[output])):
-                    key = output + '_cpy' + str(i + 1)
-                    clone_layer.attributes[key].type = node.attributes['result_t']
-                model.insert_node(clone_layer)
-                transformed = True
+            n_outputs = len(output_map[output]) + in_output
+            if n_outputs == 1:
+                continue
+            if n_outputs > 3:
+                msg = f'ERROR: Cloning output {output} of {node.class_name}\
+                    ({node.name}) more than 3 times not currently supported'
+                raise ValueError(msg)
+
+            out_var = node.get_output_variable(output)
+            attrs = {'size': prod(out_var.shape)}
+
+            init_stream_idx = 1
+            if in_output:
+                # If the value is also used as a model output, add one extra stream
+                idx = node.model.outputs.index(node.name)
+                node.model.outputs[idx] = node.name + '_cpy1'
+                init_stream_idx = 2
+            for i, layer in enumerate(output_map[output], init_stream_idx):
+                idx = layer.inputs.index(output)
+                layer.inputs[idx] = output + f'_cpy{i}'
+
+            clone_layer: Clone = model.make_node(
+                Clone,
+                'clone_' + node.name,
+                attrs,
+                [output],
+                [output + '_cpy' + str(i + 1) for i in range(n_outputs)],
            )
+            for i in range(n_outputs):
+                key = output + '_cpy' + str(i + 1)
+                clone_layer.attributes[key].type = node.attributes['result_t']
+            model.insert_node(clone_layer)
+            transformed = True
 
         return transformed
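
The new stream-count bookkeeping leans on Python's bool-to-int coercion: `in_output` contributes one extra copy when the cloned tensor is also a model output, and `_cpy1` is reserved for that output stream. A minimal, self-contained sketch of the naming scheme (plain Python, not the hls4ml API; `fanout` and `is_model_output` are illustrative stand-ins for `len(output_map[output])` and `node.name in node.model.outputs`):

# Illustrative sketch of the copy-naming scheme in CloneOutput.
def clone_names(output, fanout, is_model_output):
    n_outputs = fanout + is_model_output  # bool coerces to 0 or 1
    if n_outputs > 3:
        raise ValueError(f'Cloning output {output} more than 3 times not currently supported')
    return [f'{output}_cpy{i + 1}' for i in range(n_outputs)]

# A tensor consumed by two layers and also exposed as a model output needs
# three stream copies; _cpy1 is the one routed to the model output.
assert clone_names('dense_out', 2, True) == ['dense_out_cpy1', 'dense_out_cpy2', 'dense_out_cpy3']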

hls4ml/backends/fpga/passes/inplace_parallel_reshape.py

Lines changed: 11 additions & 4 deletions

@@ -11,14 +11,21 @@ class InplaceParallelReshape(OptimizerPass):
     """
 
     def match(self, node):
-        return isinstance(node, Reshape)
-
-    def transform(self, model, node):
-        if model.config.get_config_value('IOType') != 'io_parallel':
+        if not isinstance(node, Reshape):
             return False
+        return node.model.config.get_config_value('IOType') == 'io_parallel'
 
+    def transform(self, model, node):
         outvar = node.get_output_variable()
         invar = node.get_input_variable()
         newoutvar = InplaceTensorVariable(outvar, invar)
         node.set_attr(node.outputs[0], newoutvar)
+        if node.name in model.outputs:
+            prev_node = node.get_input_node()
+            assert (
+                prev_node.name not in model.outputs
+            ), f"Cannot output node {node.name}: reshape is a no-op in io_parallel.\
+                As a result, the previous node {prev_node.name}'s output will be used as the\
+                model output. However, that node is already a model output."
+            model.outputs = [name if name != node.name else prev_node.name for name in model.outputs]
         return False
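
Since a reshape is a no-op for io_parallel, a reshape that is itself a model output must hand its output slot to its producer. A hedged sketch of just that list rewrite (plain Python; `outputs`, `reshape_name`, and `producer_name` are illustrative names, not the hls4ml API):

# Sketch: point a model output at the reshape's producer instead.
def redirect_output(outputs, reshape_name, producer_name):
    assert producer_name not in outputs, 'producer is already a model output'
    return [producer_name if name == reshape_name else name for name in outputs]

assert redirect_output(['conv_out', 'reshape_1'], 'reshape_1', 'dense_2') == ['conv_out', 'dense_2']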

hls4ml/backends/fpga/passes/inplace_stream_flatten.py

Lines changed: 11 additions & 4 deletions

@@ -11,13 +11,20 @@ class InplaceStreamFlatten(OptimizerPass):
     """
 
     def match(self, node):
-        # Reshape acts as a Flatten layer when the result has 1 dimension
-        return isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1
+        # Layers requiring flattened data can gather it from the stream; no repacking is needed.
+        # Reshape acts as a Flatten layer when the result has 1 dimension. Make it an inplace tensor when that happens.
 
-    def transform(self, model, node):
-        if model.config.get_config_value('IOType') != 'io_stream':
+        if node.model.config.get_config_value('IOType') != 'io_stream':
+            return False
+        if not (isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1):
+            # Not a flatten layer
             return False
+        if node.name in node.model.outputs:
+            # Used as a model output; the output shape must be preserved in this case
+            return False
+        return True
 
+    def transform(self, model, node):
         outvar = node.get_output_variable()
         invar = node.get_input_variable()
         newoutvar = InplaceTensorVariable(outvar, invar)

hls4ml/backends/fpga/passes/repack_stream.py

Lines changed: 3 additions & 1 deletion

@@ -49,7 +49,9 @@ class ReshapeStream(OptimizerPass):
 
     def match(self, node):
         # do not run optimizer pass for a flatten layer (1 output dimension)
-        return isinstance(node, Reshape) and len(node.get_output_variable().shape) > 1
+        if not isinstance(node, Reshape):
+            return False
+        return len(node.get_output_variable().shape) > 1 or node.name in node.model.outputs
 
     def transform(self, model, node):
         if model.config.get_config_value('IOType') != 'io_stream':
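
Read together with the InplaceStreamFlatten diff above, the two match() conditions appear complementary for Reshape nodes under io_stream: a flatten that is also a model output stays with ReshapeStream (so the output stream keeps its declared shape), while every other flatten becomes an inplace tensor. A toy truth table checking that reading (plain Python; `ndim_out` and `is_model_output` stand in for the real attribute lookups):

# Toy truth table for how the two passes split Reshape nodes in io_stream.
def inplace_stream_flatten_match(ndim_out, is_model_output):
    return ndim_out == 1 and not is_model_output

def reshape_stream_match(ndim_out, is_model_output):
    return ndim_out > 1 or is_model_output

for ndim in (1, 3):
    for as_output in (False, True):
        flat = inplace_stream_flatten_match(ndim, as_output)
        repack = reshape_stream_match(ndim, as_output)
        assert flat != repack  # exactly one pass claims every Reshape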

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 6 additions & 1 deletion

@@ -48,7 +48,12 @@ def _register_flows(self):
         initializers = self._get_layer_initializers()
         init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
 
-        streaming_passes = ['quartus:reshape_stream', 'quartus:clone_output']
+        streaming_passes = [
+            'quartus:inplace_stream_flatten',  # Inform downstream of the changed packsize in case of a skipped flatten
+            'quartus:reshape_stream',
+            'quartus:clone_output',
+        ]
+
         streaming_flow = register_flow('streaming', streaming_passes, requires=[init_flow], backend=self.name)
 
         quartus_types = [

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 1 addition & 0 deletions

@@ -78,6 +78,7 @@ def _register_flows(self):
         init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
 
         streaming_passes = [
+            'vivado:inplace_stream_flatten',  # Inform downstream of the changed packsize in case of a skipped flatten
            'vivado:reshape_stream',
            'vivado:clone_output',
            'vivado:insert_zero_padding_before_conv1d',

hls4ml/model/graph.py

Lines changed: 47 additions & 39 deletions

@@ -506,6 +506,8 @@ def insert_node(self, node, before=None, input_idx=0):
 
         if next_node is not None:
             next_node.inputs[input_idx] = node.outputs[0]
+        else:
+            self.outputs = [node.outputs[0] if name == prev_node.outputs[0] else name for name in self.outputs]
 
         new_graph = OrderedDict()
         for k, v in self.graph.items():
@@ -514,47 +516,57 @@ def insert_node(self, node, before=None, input_idx=0):
             new_graph[node.name] = node
 
         self.graph = new_graph
-        self._update_model_outputs()
 
     def remove_node(self, node, rewire=True):
-        """Remove a node from a graph.
+        """Removes a node from the graph.
 
-        By default, this function can connect the outputs of previous node to the input of next one.
-        Note that when removing a leaf node `rewire` should be set to `False`.
+        By default, this function connects the outputs of the previous
+        node to the inputs of the next node. If the removed node has multiple
+        input/output tensors, an exception is raised.
 
         Args:
-            node (Layer): The node to remove
-            rewire (bool, optional): If `True`, connects the outputs of the previous node
-                to the inputs of the next node
+            node (Layer): The node to remove.
+            rewire (bool, optional): Deprecated, has no effect.
 
         Raises:
-            Exception: If an attempt is made to rewire a leaf node or a node with multiple
-                inputs/outputs.
+            Exception: If an attempt is made to rewire a node with
+                multiple inputs/outputs.
 
+        Note:
+            The `rewire` parameter is deprecated and has no effect.
         """
-        if rewire:
-            inputs = [inp for inp in node.inputs if inp]
-            outputs = [outp for outp in node.outputs if outp]
-            if len(inputs) > 1 or len(outputs) > 1:
-                raise Exception('Cannot rewire a node with multiple inputs/outputs')
-            prev_node = node.get_input_node(node.inputs[0])
+
+        inputs = [inp for inp in node.inputs if inp]
+        outputs = [outp for outp in node.outputs if outp]
+
+        if len(inputs) > 1 or len(outputs) > 1:
+            raise Exception('Cannot delete a node with multiple inputs/outputs')
+
+        if len(inputs) == 1:
+            # Connect the node's input to the model outputs, if needed
+            if node.name in self.outputs:
+                msg = f'Removing leaf node {node.name} will connect its input node {inputs[0]} to output, but it already is.'
+                assert inputs[0] not in self.outputs, msg
+                self.outputs = [inputs[0] if name == node.name else name for name in self.outputs]
+
+        if len(outputs) == 1 and len(inputs) == 1:
+            inp_var = node.get_input_variable()
+            out_var = node.get_output_variable()
+
+            # fmt: off
+            assert (np.prod(inp_var.shape) == np.prod(out_var.shape)), \
+                f'Input and output shapes do not match for {node.name}: {inp_var.shape} -> {out_var.shape}'
+            # fmt: on
+
         next_nodes = [x for x in self.graph.values() if node.outputs[0] in x.inputs]
-        if prev_node is not None:
-            if len(next_nodes) > 0:
-                for next_node in next_nodes:
-                    for i, _ in enumerate(next_node.inputs):
-                        if node.outputs[0] == next_node.inputs[i]:
-                            next_node.inputs[i] = prev_node.outputs[0]
-                            break
-            else:
-                if not node.outputs[0] in self.outputs:
-                    raise Exception('Cannot rewire a node without child')
-        else:
-            raise Exception('Cannot rewire a node without a parent')
+        for next_node in next_nodes:
+            # Connect the node's input to the next node(s)
+            for i, nxt_inp in enumerate(next_node.inputs):
+                if outputs[0] == nxt_inp:
+                    next_node.inputs[i] = inputs[0]
 
         del self.output_vars[node.outputs[0]]
         del self.graph[node.name]
-        self._update_model_outputs()
 
     def replace_node(self, old_node, new_node):
         """Replace an existing node in the graph with a new one.
@@ -584,7 +596,11 @@ def replace_node(self, old_node, new_node):
                 node.outputs[i] = repl[n]
 
         self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items())
-        self._update_model_outputs()
+
+        old_name = old_node.name
+        if old_name in self.outputs:
+            new_name = new_node.name
+            self.outputs = [new_name if name == old_name else name for name in self.outputs]
 
     def split_node(self, old_node, new_node1, new_node2):
         """Replace an existing node in the graph with two nodes in sequence.
@@ -622,17 +638,9 @@ def split_node(self, old_node, new_node1, new_node2):
             else:
                 new_graph[key] = value
         self.graph = new_graph
-        self._update_model_outputs()
-
-    def _update_model_outputs(self):
-        '''Update the model outputs
-
-        All node outputs and inputs are found. The model outputs are set to all node outputs
-        that are not also node inputs.
-        '''
-        node_outputs = [out for node in self.graph.values() for out in node.outputs]
-        node_inputs = [inp for node in self.graph.values() for inp in node.inputs]
-        self.outputs = [out for out in node_outputs if out not in node_inputs]
+
+        if old_node.name in self.outputs:
+            self.outputs = [new_node2.name if name == old_node.name else name for name in self.outputs]
 
     def next_layer(self):
         self.index += 1
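
The net effect of these hunks is that `self.outputs` is now maintained incrementally at each graph mutation instead of being recomputed by the removed `_update_model_outputs()` helper, presumably because a recomputation from the graph cannot preserve a user-specified output list. A stand-alone sketch of the remove_node rewiring under its single-input/single-output contract (dicts stand in for hls4ml Layer objects; this is not the real API):

# Toy graph: each node maps to its input and output tensor names.
graph = {
    'inp':   {'inputs': [], 'outputs': ['inp_out']},
    'noop':  {'inputs': ['inp_out'], 'outputs': ['noop_out']},
    'dense': {'inputs': ['noop_out'], 'outputs': ['dense_out']},
}

def remove_node(graph, name):
    node = graph.pop(name)
    (inp,), (out,) = node['inputs'], node['outputs']  # single input/output only
    for other in graph.values():
        # Rewire any consumer of the removed node's output to its input
        other['inputs'] = [inp if i == out else i for i in other['inputs']]

remove_node(graph, 'noop')
assert graph['dense']['inputs'] == ['inp_out']  # consumer rewired to the producer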

hls4ml/model/optimizer/passes/linear.py

Lines changed: 0 additions & 1 deletion

@@ -40,7 +40,6 @@ def transform(self, model, node):
         # if the activation has a quantizer (usually from a QONNX Quant node), set the previous node's output precision
         if quantizer is not None:
            prev_node.set_attr("quantizer", quantizer)
-            prev_node.types['result_t'] = quantizer.hls_type
            prev_node.get_output_variable().type.precision = quantizer.hls_type
         model.remove_node(node)
         return True

hls4ml/model/optimizer/passes/merge_const.py

Lines changed: 0 additions & 1 deletion

@@ -54,7 +54,6 @@ def transform(self, model, node):
            const_node0.set_attr('quantizer', quantizer)  # overwrite the quantizer
         if quantizer:
            const_node0.set_attr('quantizer', quantizer)
-            const_node0.types['result_t'] = quantizer.hls_type
            const_node0.get_output_variable().type.precision = quantizer.hls_type
         const_node0.set_attr('value', new_val)
