Skip to content

Commit 09e5274

Browse files
dsrankinjmduartevloncar
authored
Fix for 2D conv layers in the special case of io_parallel with full parallelization (#760)
* adding fix for fully partition io_parallel * remove newline * pre-commit * update for vitis compliance * Get rid of unroll factor and follow up in 1D --------- Co-authored-by: Javier Duarte <jduarte@ucsd.edu> Co-authored-by: Vladimir Loncar <vloncar@users.noreply.github.com>
1 parent 51be56e commit 09e5274

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
3232

3333
PartitionLoop:
3434
for (int i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) {
35-
#pragma HLS PIPELINE II=CONFIG_T::reuse_factor
35+
#pragma HLS PIPELINE II=CONFIG_T::reuse_factor rewind
3636

3737
CONFIG_T::template fill_buffer<data_T, CONFIG_T>::fill_buffer(data, data_buf, i_part);
3838

@@ -45,9 +45,11 @@ void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
4545
// Do the matrix-multiply
4646
Product1:
4747
for (int i_in = 0; i_in < mult_n_in; i_in++) {
48+
#pragma HLS UNROLL
4849
cache = data_buf[i_pxl][i_in];
4950
Product2:
5051
for (int i_out = 0; i_out < mult_n_out; i_out++) {
52+
#pragma HLS UNROLL
5153
mult[i_in * mult_n_out + i_out] =
5254
CONFIG_T::mult_config::template product<data_T, typename CONFIG_T::mult_config::weight_t>::product(
5355
cache, weights[i_in * mult_n_out + i_out]);
@@ -57,21 +59,25 @@ void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
5759
// Initialize accumulator with input biases
5860
ResetAccum:
5961
for (int i_acc = 0; i_acc < mult_n_out; i_acc++) {
62+
#pragma HLS UNROLL
6063
acc[i_acc] = (typename CONFIG_T::accum_t)biases[i_acc];
6164
}
6265

6366
// Accumulate multiplication result
6467
Accum1:
6568
for (int i_in = 0; i_in < mult_n_in; i_in++) {
69+
#pragma HLS UNROLL
6670
Accum2:
6771
for (int i_out = 0; i_out < mult_n_out; i_out++) {
72+
#pragma HLS UNROLL
6873
acc[i_out] += mult[i_in * mult_n_out + i_out];
6974
}
7075
}
7176

7277
// Cast to "res_t" type
7378
Result:
7479
for (int i_res = 0; i_res < mult_n_out; i_res++) {
80+
#pragma HLS UNROLL
7581
*(res++) = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[i_res]);
7682
}
7783
}

hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ void conv_2d_latency_cl(
3333

3434
PartitionLoop:
3535
for (int i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) {
36-
#pragma HLS PIPELINE II=CONFIG_T::reuse_factor
36+
#pragma HLS PIPELINE II=CONFIG_T::reuse_factor rewind
3737

3838
CONFIG_T::template fill_buffer<data_T, CONFIG_T>::fill_buffer(data, data_buf, i_part);
3939

@@ -46,9 +46,11 @@ void conv_2d_latency_cl(
4646
// Do the matrix-multiply
4747
Product1:
4848
for (int i_in = 0; i_in < mult_n_in; i_in++) {
49+
#pragma HLS UNROLL
4950
cache = data_buf[i_pxl][i_in];
5051
Product2:
5152
for (int i_out = 0; i_out < mult_n_out; i_out++) {
53+
#pragma HLS UNROLL
5254
mult[i_in * mult_n_out + i_out] =
5355
CONFIG_T::mult_config::template product<data_T, typename CONFIG_T::mult_config::weight_t>::product(
5456
cache, weights[i_in * mult_n_out + i_out]);
@@ -58,21 +60,25 @@ void conv_2d_latency_cl(
5860
// Initialize accumulator with input biases
5961
ResetAccum:
6062
for (int i_acc = 0; i_acc < mult_n_out; i_acc++) {
63+
#pragma HLS UNROLL
6164
acc[i_acc] = (typename CONFIG_T::accum_t)biases[i_acc];
6265
}
6366

6467
// Accumulate multiplication result
6568
Accum1:
6669
for (int i_in = 0; i_in < mult_n_in; i_in++) {
70+
#pragma HLS UNROLL
6771
Accum2:
6872
for (int i_out = 0; i_out < mult_n_out; i_out++) {
73+
#pragma HLS UNROLL
6974
acc[i_out] += mult[i_in * mult_n_out + i_out];
7075
}
7176
}
7277

7378
// Cast to "res_t" type
7479
Result:
7580
for (int i_res = 0; i_res < mult_n_out; i_res++) {
81+
#pragma HLS UNROLL
7682
*(res++) = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[i_res]);
7783
}
7884
}

0 commit comments

Comments
 (0)