Commit 0c54fd0
Cortex_m backend: Support channels-broadcasting for ADD/MUL (#16131)
Adds support for broadcasting in the special case where one input contains only channels, which are broadcast onto every spatial element of the other tensor, e.g. [1, C, 1, 1] + [N, C, H, W] for channels-last tensors. This is needed for mobilenet_v3.

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
Parent: 152ea3e
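
For orientation, here is a minimal PyTorch sketch (not part of the commit; the shapes are illustrative) of the pattern the kernels now accept:

# Channel broadcast: a channels-only tensor applied to every spatial
# element of a 4D channels-last tensor.
import torch

N, C, H, W = 2, 8, 4, 4
x = torch.randn(N, C, H, W).to(memory_format=torch.channels_last)
per_channel = torch.randn(1, C, 1, 1)  # numel == C: "contains only channels"

out = x + per_channel  # [1, C, 1, 1] + [N, C, H, W] -> [N, C, H, W]
assert out.shape == (N, C, H, W)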

File tree: 8 files changed (+241 / -67 lines)

backends/cortex_m/ops/cortex_m_ops_common.h

Lines changed: 37 additions & 0 deletions

@@ -129,6 +129,43 @@ inline void validate_quantization_params(
       "Single quant Output");
 }

+inline bool is_channels_last_tensor(const Tensor& tensor) {
+  if (tensor.dim() != 4) {
+    return false;
+  }
+
+  // When channels or spatial dims are 1 the layout information is ambiguous.
+  if (tensor.size(1) == 1 || (tensor.size(2) == 1 && tensor.size(3) == 1)) {
+    return true;
+  }
+
+  constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = {
+      0, 2, 3, 1};
+  executorch::aten::ArrayRef<executorch::aten::DimOrderType>
+      channels_last_order(kChannelsLastDimOrder, 4);
+
+  return tensor.dim_order() == channels_last_order;
+}
+
+inline bool is_channel_broadcast(const Tensor& tensor1, const Tensor& tensor2) {
+  if (tensor1.dim() != tensor2.dim()) {
+    return false;
+  }
+
+  if (tensor1.dim() != 4) {
+    return false;
+  }
+
+  if (tensor1.size(1) != tensor2.size(1)) {
+    return false;
+  }
+
+  const bool tensor1_channels_only = tensor1.numel() == tensor1.size(1);
+  const bool tensor2_channels_only = tensor2.numel() == tensor2.size(1);
+
+  return tensor1_channels_only || tensor2_channels_only;
+}
+
 // Refer to CMSIS-NN 'arm_nn_requantize' implementation for details:
 // https://github.com/ARM-software/CMSIS-NN/blob/main/Include/arm_nnsupportfunctions.h#L1625
 // multiplier: Range {ARM_NN_Q31_MIN + 1, Q32_MAX}
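
The early return in is_channels_last_tensor rests on the fact that for degenerate shapes the two layouts coincide. A quick Python check of the same ambiguity (my illustration, not part of the diff):

# When C == 1 or H == W == 1, contiguous and channels-last strides describe
# the same memory order, so PyTorch reports the tensor as contiguous in
# both formats.
import torch

t = torch.randn(1, 8, 1, 1)
assert t.is_contiguous()                                   # NCHW reading
assert t.is_contiguous(memory_format=torch.channels_last)  # NHWC reading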

backends/cortex_m/ops/op_quantized_add.cpp

Lines changed: 53 additions & 28 deletions

@@ -33,7 +33,14 @@ Tensor& quantized_add_out(
     const Scalar& output_shift,
     Tensor& out) {
   // Validate tensor types and dim order
-  validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);
+  bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8);
+  validate_cmsis_nn_tensor_requirements(
+      input1_int8,
+      input2_int8,
+      out,
+      ScalarType::Char,
+      /*require_channels_last=*/channel_broadcast,
+      /*require_same_sizes=*/!channel_broadcast);

   // Validate quantization parameters
   validate_quantization_params(

@@ -62,6 +69,8 @@ Tensor& quantized_add_out(
   int32_t out_zp = extractScalarToInt32(output_zero_point);
   int32_t output_mult = extractScalarToInt32(output_multiplier);
   int output_shift_val = extractScalarToInt(output_shift);
+  int8_t* input1_ptr = input1_int8.data_ptr<int8_t>();
+  int8_t* input2_ptr = input2_int8.data_ptr<int8_t>();

   // Left shift to maximize precision
   const int32_t left_shift = 20;

@@ -87,33 +96,49 @@ Tensor& quantized_add_out(
   // addition. To preserve precision when rescaling the inputs, they are first
   // upscaled as much as possible, Hence the left_shift parameter required here.

-  // Call CMSIS-NN kernel with precomputed parameters
-  arm_cmsis_nn_status status = arm_elementwise_add_s8(
-      input1_int8.const_data_ptr<int8_t>(),
-      input2_int8.const_data_ptr<int8_t>(),
-      -static_cast<int32_t>(zp1),
-      input1_mult,
-      input1_shift_val,
-      -static_cast<int32_t>(zp2),
-      input2_mult,
-      input2_shift_val,
-      left_shift,
-      out.mutable_data_ptr<int8_t>(),
-      static_cast<int32_t>(out_zp),
-      output_mult,
-      output_shift_val,
-      activation_min,
-      activation_max,
-      static_cast<int32_t>(out.numel()));
-
-  if (status != ARM_CMSIS_NN_SUCCESS) {
-    ET_LOG(
-        Error,
-        "quantized_add_out: arm_elementwise_add_s8 failed with status [%d]",
-        status);
-
-    context.fail(Error::Internal); // Fail the execution context
-    return out;
+  int32_t adds_per_loop = 0;
+  if (channel_broadcast) {
+    if (input1_int8.numel() < input2_int8.numel()) {
+      std::swap<int32_t>(zp1, zp2);
+      std::swap<int32_t>(input1_mult, input2_mult);
+      std::swap<int>(input1_shift_val, input2_shift_val);
+      std::swap<int8_t*>(input1_ptr, input2_ptr);
+    }
+    adds_per_loop = input1_int8.size(1);
+  } else {
+    adds_per_loop = out.numel();
+  }
+
+  for (int32_t broadcast_offset = 0; broadcast_offset < out.numel();
+       broadcast_offset += adds_per_loop) {
+    // Call CMSIS-NN kernel with precomputed parameters
+    arm_cmsis_nn_status status = arm_elementwise_add_s8(
+        input1_ptr + broadcast_offset,
+        input2_ptr,
+        -static_cast<int32_t>(zp1),
+        input1_mult,
+        input1_shift_val,
+        -static_cast<int32_t>(zp2),
+        input2_mult,
+        input2_shift_val,
+        left_shift,
+        out.mutable_data_ptr<int8_t>() + broadcast_offset,
+        static_cast<int32_t>(out_zp),
+        output_mult,
+        output_shift_val,
+        activation_min,
+        activation_max,
+        adds_per_loop);
+
+    if (status != ARM_CMSIS_NN_SUCCESS) {
+      ET_LOG(
+          Error,
+          "quantized_add_out: arm_elementwise_add_s8 failed with status [%d]",
+          status);
+
+      context.fail(Error::Internal); // Fail the execution context
+      return out;
+    }
   }
   ET_LOG(
       Info,
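
The loop works because in channels-last memory every run of C consecutive elements holds one spatial position's channel vector, so stepping the larger tensor in C-sized chunks while reusing the channels-only tensor reproduces the broadcast. A sketch of that equivalence in plain PyTorch (illustrative shapes, not from the commit):

# Chunked per-channel addition over the flat channels-last buffer matches
# full broadcast semantics.
import torch

N, C, H, W = 2, 3, 2, 2
big = torch.randn(N, C, H, W).to(memory_format=torch.channels_last)
small = torch.randn(1, C, 1, 1)

flat_big = big.permute(0, 2, 3, 1).reshape(-1)  # channels-last memory order
flat_small = small.reshape(-1)                  # the C broadcast values
out_flat = torch.empty_like(flat_big)
for offset in range(0, flat_big.numel(), C):    # adds_per_loop == C
    out_flat[offset:offset + C] = flat_big[offset:offset + C] + flat_small

reference = big + small
assert torch.allclose(out_flat.reshape(N, H, W, C).permute(0, 3, 1, 2), reference)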

backends/cortex_m/ops/op_quantized_mul.cpp

Lines changed: 49 additions & 25 deletions

@@ -34,7 +34,15 @@ Tensor& quantized_mul_out(
     const Scalar& output_shift,
     Tensor& out) {
   // Validate tensor types and quantization parameters
-  validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out);
+
+  bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8);
+  validate_cmsis_nn_tensor_requirements(
+      input1_int8,
+      input2_int8,
+      out,
+      ScalarType::Char,
+      /*require_channels_last=*/channel_broadcast,
+      /*require_same_sizes=*/!channel_broadcast);

   const Scalar kIdentityMultiplier(/*value=*/1);
   const Scalar kZeroShift(/*value=*/0);

@@ -51,12 +59,26 @@ Tensor& quantized_mul_out(
       out);

   // Extract quantization parameters
-  const int32_t zp1 = extractScalarToInt32(input1_zero_point);
-  const int32_t zp2 = extractScalarToInt32(input2_zero_point);
+  int8_t* input1_ptr = input1_int8.data_ptr<int8_t>();
+  int8_t* input2_ptr = input2_int8.data_ptr<int8_t>();
+  int32_t zp1 = extractScalarToInt32(input1_zero_point);
+  int32_t zp2 = extractScalarToInt32(input2_zero_point);
   const int32_t out_zp = extractScalarToInt32(output_zero_point);
   const int32_t output_mult = extractScalarToInt32(output_multiplier);
   const int32_t output_shift_val = extractScalarToInt32(output_shift);

+  int32_t muls_per_loop = 0;
+
+  if (channel_broadcast) {
+    if (input1_int8.numel() < input2_int8.numel()) {
+      std::swap<int32_t>(zp1, zp2);
+      std::swap<int8_t*>(input1_ptr, input2_ptr);
+    }
+
+    muls_per_loop = input1_int8.size(1);
+  } else {
+    muls_per_loop = out.numel();
+  }
   // Note 1: The CMSIS-NN kernel implementation uses offsets which are always
   // added to the data, whereas zero_points are subtracted when dequantizing
   // (for the inputs) and added when quantizing (for the output). Hence the

@@ -72,29 +94,31 @@ Tensor& quantized_mul_out(
   // effective_scale = (scale_in1 * scale_in2 / scale_out)
   // Hence no input quantization params required here.

-  // Call CMSIS-NN elementwise multiply kernel
-  arm_cmsis_nn_status status = arm_elementwise_mul_s8(
-      input1_int8.const_data_ptr<int8_t>(),
-      input2_int8.const_data_ptr<int8_t>(),
-      -static_cast<int32_t>(zp1),
-      -static_cast<int32_t>(zp2),
-      out.mutable_data_ptr<int8_t>(),
-      static_cast<int32_t>(out_zp),
-      output_mult,
-      output_shift_val,
-      kInt8ActivationMin,
-      kInt8ActivationMax,
-      static_cast<int32_t>(out.numel()));
-
-  if (status != ARM_CMSIS_NN_SUCCESS) {
-    ET_LOG(
-        Error,
-        "quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
-        status);
-    context.fail(Error::Internal);
-    return out;
-  }
+  for (int32_t broadcast_offset = 0; broadcast_offset < out.numel();
+       broadcast_offset += muls_per_loop) {
+    // Call CMSIS-NN elementwise multiply kernel
+    arm_cmsis_nn_status status = arm_elementwise_mul_s8(
+        input1_ptr + broadcast_offset,
+        input2_ptr,
+        -static_cast<int32_t>(zp1),
+        -static_cast<int32_t>(zp2),
+        out.mutable_data_ptr<int8_t>() + broadcast_offset,
+        static_cast<int32_t>(out_zp),
+        output_mult,
+        output_shift_val,
+        kInt8ActivationMin,
+        kInt8ActivationMax,
+        muls_per_loop);

+    if (status != ARM_CMSIS_NN_SUCCESS) {
+      ET_LOG(
+          Error,
+          "quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]",
+          status);
+      context.fail(Error::Internal);
+      return out;
+    }
+  }
   return out;
 }
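
Unlike the add path, only the zero points and data pointers are swapped here: the per-input rescale for mul is folded into the output multiplier, and the zero-point-adjusted product is commutative. A trivial check of that invariant (my sketch, not from the diff):

# Swapping operands together with their zero points leaves the product
# unchanged, so the kernel may always treat the larger tensor as input1.
import torch

a = torch.randint(-128, 128, (16,), dtype=torch.int32)
b = torch.randint(-128, 128, (16,), dtype=torch.int32)
zp_a, zp_b = 3, -7
assert torch.equal((a - zp_a) * (b - zp_b), (b - zp_b) * (a - zp_a))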

backends/cortex_m/ops/operators.py

Lines changed: 19 additions & 12 deletions

@@ -11,6 +11,7 @@
 import torch
 import torch.nn.functional as F
 from executorch.backends.cortex_m.passes.passes_utils import (
+    is_channel_broadcast,
     requantize_cmsis,
     SHIFT_INT8,
 )

@@ -140,12 +141,15 @@ def quantized_add_meta(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
-    assert self.shape == other.shape, (
-        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+    assert self.shape == other.shape or is_channel_broadcast(self, other), (
+        "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — "
         f"got self.shape={self.shape}, other.shape={other.shape}"
     )
-    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
-    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
+    if self.numel() > other.numel():
+        output_tensor = self
+    else:
+        output_tensor = other
+    return torch.empty_like(output_tensor)

@@ -162,8 +166,8 @@ def quantized_add_impl(
     output_multiplier: int,
     output_shift: int,
 ) -> torch.Tensor:
-    assert self.shape == other.shape, (
-        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+    assert self.shape == other.shape or is_channel_broadcast(self, other), (
+        "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — "
         f"got self.shape={self.shape}, other.shape={other.shape}"
     )
     self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8

@@ -207,12 +211,15 @@ def quantized_mul_meta(
     output_shift: int,
 ) -> torch.Tensor:
     # Broadcast to output shape
-    assert self.shape == other.shape, (
-        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+    assert self.shape == other.shape or is_channel_broadcast(self, other), (
+        "Cortex-M quantized_mul: broadcasting is not yet supported except for channel dim — "
         f"got self.shape={self.shape}, other.shape={other.shape}"
     )
-    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
-    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
+    if self.numel() > other.numel():
+        output_tensor = self
+    else:
+        output_tensor = other
+    return torch.empty_like(output_tensor)

@@ -228,8 +235,8 @@ def quantized_mul_impl(
     # CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and
     # only uses the output multiplier/shift for rescaling. Mirror that here to
     # keep the composite implementation numerically aligned with the backend.
-    assert self.shape == other.shape, (
-        "Cortex-M quantized_mul: broadcasting is not yet supported — "
+    assert self.shape == other.shape or is_channel_broadcast(self, other), (
+        "Cortex-M quantized_mul: broadcasting is not yet supported except for channel dim — "
         f"got self.shape={self.shape}, other.shape={other.shape}"
     )
     self_int = self.to(torch.int32) - self_zero_point
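
The switch from torch.broadcast_shapes to torch.empty_like is deliberate: empty_like defaults to memory_format=torch.preserve_format, so the fake output inherits shape, dtype, and channels-last layout from the larger operand. A sketch with illustrative shapes:

# The meta output mirrors the larger input, keeping channels-last layout
# visible to downstream passes.
import torch

big = torch.empty(2, 8, 4, 4, dtype=torch.int8).to(memory_format=torch.channels_last)
small = torch.empty(1, 8, 1, 1, dtype=torch.int8)

out = torch.empty_like(big if big.numel() > small.numel() else small)
assert out.shape == (2, 8, 4, 4)
assert out.dtype == torch.int8
assert out.is_contiguous(memory_format=torch.channels_last)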

backends/cortex_m/passes/passes_utils.py

Lines changed: 31 additions & 0 deletions

@@ -193,3 +193,34 @@ def cleanup_nodes(nodes_to_erase, graph):
         print(f"Warning: {len(failed_nodes)} nodes could not be erased")

     return failed_nodes
+
+
+def is_channels_last(tensor: torch.Tensor) -> bool:
+    """Check if a 4D tensor is in channels last format."""
+    if tensor.ndim != 4:
+        return False
+
+    if tensor.shape[1] == 1 or tensor.shape[2] == tensor.shape[3] == 1:
+        return True
+
+    dim_order = list(tensor.dim_order())
+    return dim_order[0:2] == [0, 2]
+
+
+def is_channel_broadcast(tensor1: torch.Tensor, tensor2: torch.Tensor) -> bool:
+    """
+    Check if tensor1 is broadcasted to tensor2 along channel dimension.
+    Assumes tensor2 has shape [N, C, ...] and tensor1 has shape [N, 1, ...] or [1, C, ...].
+    """
+    if tensor1.dim() != tensor2.dim():
+        return False
+    if not is_channels_last(tensor1):
+        return False
+    if not is_channels_last(tensor2):
+        return False
+
+    channel_match = tensor1.size(1) == tensor2.size(1)
+    tensor1_channels_only = tensor1.numel() == tensor1.size(1)
+    tensor2_channels_only = tensor2.numel() == tensor2.size(1)
+
+    return channel_match and (tensor1_channels_only or tensor2_channels_only)
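
Hypothetical usage of the two new helpers (assumes an ExecuTorch checkout where the imports resolve):

import torch
from executorch.backends.cortex_m.passes.passes_utils import (
    is_channel_broadcast,
    is_channels_last,
)

x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)
c = torch.randn(1, 8, 1, 1)  # ambiguous layout, treated as channels last

assert is_channels_last(x) and is_channels_last(c)
assert is_channel_broadcast(c, x)   # channels-only vs. full tensor
assert not is_channel_broadcast(x, torch.randn(2, 8, 4, 4))  # NCHW partner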

backends/cortex_m/quantizer/quantizer.py

Lines changed: 8 additions & 2 deletions

@@ -11,6 +11,10 @@
 from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
+from executorch.backends.cortex_m.passes.passes_utils import (
+    is_channel_broadcast,
+    is_channels_last,
+)
 from executorch.backends.cortex_m.quantizer.operator_configs import (
     BINARY_OP_PATTERNS,
     CONV_OP_PATTERNS,

@@ -61,7 +65,9 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool:
         if len(node.all_input_nodes) == 2:
             t1 = get_first_fake_tensor(node.all_input_nodes[0])
             t2 = get_first_fake_tensor(node.all_input_nodes[1])
-            return t1.shape != t2.shape
+            return t1.shape != t2.shape and not (
+                is_channel_broadcast(t1, t2) and is_channels_last(t1)
+            )

         return False

@@ -78,7 +84,7 @@ def nchw_filter(self, node: Optional[Node]) -> bool:
         if tensor is None:
             return False

-        return not tensor.is_contiguous(memory_format=torch.channels_last)
+        return not is_channels_last(tensor)

     def __init__(self) -> None:
         quantizers: List[Quantizer] = [
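
The updated broadcasting_filter now flags shape mismatches only when they are not the supported channels-last channel broadcast. Its predicate in isolation (a sketch; the helper import assumes an ExecuTorch checkout):

import torch
from executorch.backends.cortex_m.passes.passes_utils import (
    is_channel_broadcast,
    is_channels_last,
)

def flagged(t1: torch.Tensor, t2: torch.Tensor) -> bool:
    # Mirrors the new filter: True for shape mismatches unless they are the
    # supported channels-last channel broadcast.
    return t1.shape != t2.shape and not (
        is_channel_broadcast(t1, t2) and is_channels_last(t1)
    )

full = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)
chans = torch.randn(1, 8, 1, 1)      # numel == C: supported
per_batch = torch.randn(2, 8, 1, 1)  # numel != C: still flagged

assert not flagged(full, chans)
assert flagged(full, per_batch)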
