Skip to content

Commit 670bc11

Browse files
NXP backend: added aten.mul support (#15971)
### Summary
Adds support for the `aten.mul` operator to the NXP backend.
### Test plan
Tests can be run manually using `pytest -c /dev/null backends/nxp/tests/`.
cc @robert-kalmar @JakeStevens @digantdesai
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 2f501c5 commit 670bc11

File tree

8 files changed

+365
-13
lines changed

8 files changed

+365
-13
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
exir_ops.edge.aten.max_pool2d.default: MaxPool2dConverter, # noqa F405
3939
exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405
4040
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
41+
exir_ops.edge.aten.mul.Tensor: MulTensorConverter, # noqa F405
4142
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
4243
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
4344
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mm_converter import (
3838
MMConverter,
3939
)
40+
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mul_tensor_converter import (
41+
MulTensorConverter,
42+
)
4043
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.permute_copy_converter import (
4144
PermuteCopyConverter,
4245
)
@@ -67,27 +70,28 @@
6770
)
6871

6972
__all__ = [
73+
"AbsConverter",
74+
"AdaptiveAvgPool2dConverter",
7075
"AddMMConverter",
76+
"AddTensorConverter",
77+
"AvgPool2dConverter",
7178
"CatConverter",
79+
"CloneConverter",
80+
"ConstantPadNDConverter",
7281
"ConvolutionConverter",
82+
"HardTanhConverter",
83+
"MaxPool2dConverter",
84+
"MeanDimConverter",
7385
"MMConverter",
86+
"MulTensorConverter",
7487
"PermuteCopyConverter",
75-
"SoftmaxConverter",
76-
"ViewCopyConverter",
77-
"QDQPerTensorDequantizeConverter",
7888
"QDQPerChannelDequantizeConverter",
89+
"QDQPerTensorDequantizeConverter",
7990
"QDQQuantizeConverter",
80-
"ConstantPadNDConverter",
8191
"ReLUConverter",
82-
"MeanDimConverter",
83-
"MaxPool2dConverter",
84-
"AvgPool2dConverter",
85-
"AddTensorConverter",
86-
"SubTensorConverter",
87-
"CloneConverter",
88-
"AbsConverter",
89-
"AdaptiveAvgPool2dConverter",
90-
"HardTanhConverter",
9192
"SigmoidConverter",
93+
"SoftmaxConverter",
94+
"SubTensorConverter",
9295
"TanhConverter",
96+
"ViewCopyConverter",
9397
]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.nxp.backend.ir.converter.conversion.common import (
7+
node_uses_shape_broadcasting,
8+
)
9+
from executorch.backends.nxp.backend.ir.converter.node_converter import (
10+
CustomDelegationOptions,
11+
NodeConverter,
12+
)
13+
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
14+
mul_options,
15+
)
16+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
17+
from torch.fx import Node
18+
from torch.nn import Parameter
19+
20+
21+
class MulTensorConverter(NodeConverter):
    """Node converter that maps a `mul.Tensor` node onto the NeutronIR `Mul` operator.

    The two static predicates decide whether a given node may be delegated;
    `convert` performs the actual translation.
    """

    @staticmethod
    def _is_supported_on_target(
        node: Node,
        neutron_target_spec: NeutronTargetSpec,
        parameters_mapping: dict[str, Parameter],
        custom_delegation_options: CustomDelegationOptions,
    ) -> bool:
        """Return whether the target can execute this multiplication.

        :param node: The `mul.Tensor` node being examined.
        :param neutron_target_spec: Target description (not consulted here).
        :param parameters_mapping: Model parameters by name (not consulted here).
        :param custom_delegation_options: Delegation options (not consulted here).
        :return: False when the node relies on shape broadcasting. Otherwise
            True only if some output dimension is a multiple of 8 or every
            output dimension equals 1.
        """
        # Shape broadcasting may require the addition of `Transpose` ops
        # during conversion, so such nodes are rejected up front.
        if node_uses_shape_broadcasting(node):
            return False

        output_shape = node.meta["val"].shape

        # Neutron can only convert the operator when at least one dimension is
        # divisible by the number of MACs (hard-coded as 8 here), ...
        if any(dim % 8 == 0 for dim in output_shape):
            return True

        # ... or when all dimensions are equal to one.
        return all(dim == 1 for dim in output_shape)

    @staticmethod
    def _is_supported_in_IR(
        node: Node,
        parameters_mapping: dict[str, Parameter],
        custom_delegation_options: CustomDelegationOptions,
    ) -> bool:
        """Return whether the node has exactly the two positional arguments expected.

        :param node: The `mul.Tensor` node being examined.
        :param parameters_mapping: Model parameters by name (not consulted here).
        :param custom_delegation_options: Delegation options (not consulted here).
        :return: True if `node.args` holds exactly two entries, False otherwise.
        """
        return len(node.args) == 2

    # mul.Tensor Node format: (Tensor self, Tensor other, *)
    def convert(self, node: Node):
        """Convert 'mul_tensor' operator to NeutronIR 'Mul'.

        :param node: The `mul.Tensor` node to convert.
        """
        self.assert_convertible(node)

        # Build the operator together with its input/output tensors, mark it
        # as a `Mul`, and hand it over to the model builder.
        mul_op = self._create_tflite_op_with_io_tensors(node)
        mul_op.builtin_options = mul_options.Mul()
        self.builder.append_operators([mul_op])

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]):
208208
exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter, # noqa F405
209209
exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405
210210
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
211+
exir_ops.edge.aten.mul.Tensor: MulTensorConverter, # noqa F405
211212
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
212213
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
213214
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
MaxPoolPattern,
3030
MeanDimPattern,
3131
MmPattern,
32+
MulTensorPattern,
3233
NodeArgsIdx,
3334
PadPattern,
3435
PermutePattern,
@@ -208,6 +209,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec):
208209
NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
209210
NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
210211
NeutronAtenQuantizer(MmPattern(self), static_qconfig),
212+
NeutronAtenQuantizer(MulTensorPattern(), static_qconfig),
211213
NeutronAtenQuantizer(PadPattern(), static_qconfig),
212214
NeutronAtenQuantizer(PermutePattern(), static_qconfig),
213215
NeutronAtenQuantizer(ReluPattern(), static_qconfig),

backends/nxp/quantizer/patterns.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,49 @@ def get_anchors(
673673
)
674674

675675

676+
class MulTensorPattern(QuantizationPattern):
    """
    Quantization pattern for Mul Tensor quantization.

    Both operands (argument indices 0 and 1) are quantized with the same fixed
    int8 parameters. The set of distinct producer nodes may contain 1 or 2
    nodes (presumably 1 when both operands are the same tensor - confirm
    against `torch.fx.Node.all_input_nodes`).
    """

    def partition_types(self) -> list[torch.nn.Module]:
        """Return the operator types this pattern matches (`aten.mul.Tensor`)."""
        return [torch.ops.aten.mul.Tensor]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors | None:
        """Build the quantization anchors for a matched `mul.Tensor` partition.

        :param gm: The graph module being annotated (not consulted here).
        :param fused_partition: Matched partitions; the last node of the first
            partition is taken as the multiplication node.
        :return: Anchors that place the fixed qspec on both inputs of the node
            and mark the node itself as the output anchor.
        """
        node = fused_partition[0].nodes[-1]

        qspec = FixedQParamsQuantizationSpec(
            dtype=torch.int8,
            scale=1.0 / 256.0,
            zero_point=0,
            quant_min=-128,
            quant_max=127,
            qscheme=torch.per_tensor_affine,
        )

        # The "Mul" operator in Neutron IR requires a specific scale and zero_point
        # (defined above) for its inputs.
        # Input nodes that have already been annotated by their own patterns
        # did not take the requirements of "Mul" into account, so their existing
        # "quantization_annotation" is overwritten.
        # Inputs without an annotation are skipped instead of raising
        # `KeyError`; the qspec on the Mul input edges below still applies.
        for input_node in node.all_input_nodes:
            annotation = input_node.meta.get("quantization_annotation")
            if annotation is not None:
                annotation.output_qspec = qspec

        return PartitionAnchors(
            inputs=[(node, NodeArgsIdx(0), qspec), (node, NodeArgsIdx(1), qspec)],
            weights=[],
            biases=[],
            output=[
                (node,),
            ],
        )
717+
718+
676719
class PadPattern(SharedSpecPattern):
677720
"""
678721
Quantizer for Pad operator.

0 commit comments

Comments
 (0)