Skip to content

Commit 304fd74

Browse files
committed
NXP backend: Add dim order support to NeutronBackend.
1 parent 2f501c5 commit 304fd74

File tree

7 files changed

+165
-39
lines changed

7 files changed

+165
-39
lines changed

backends/nxp/_passes/remove_getitem_pass.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77

88
import torch
99

10-
from executorch.backends.nxp.backend.node_format_inference import (
11-
NodeFormat,
12-
NXP_NODE_FORMAT,
13-
)
10+
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
1411
from executorch.exir.dialects._ops import ops as exir_ops
1512
from executorch.exir.pass_base import ExportPass, PassResult
1613

backends/nxp/backend/edge_helper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,10 @@ def get_non_qdq_users(node: Node) -> list[Node]:
184184
res.extend(list(dequant_node.users))
185185

186186
return res
187+
188+
189+
def is_channels_last_dim_order(dim_order: list[int]) -> bool:
190+
if len(dim_order) < 3:
191+
return False
192+
193+
return list(dim_order) == [0] + list(range(2, len(dim_order))) + [1]

backends/nxp/backend/edge_program_converter.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,11 @@ def convert_program(
7474
:return: TFLite flatbuffers as bytes.
7575
"""
7676
parameters_mapping = self.map_inputs_to_parameters(edge_program)
77+
dim_order_map = self.map_nodes_to_dim_order(edge_program)
7778

7879
cc = self.build_conversion_context(
7980
parameters_mapping,
81+
dim_order_map,
8082
neutron_target_spec,
8183
conversion_config,
8284
custom_delegation_options,
@@ -174,15 +176,35 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet
174176

175177
return result_map
176178

179+
@staticmethod
180+
def map_nodes_to_dim_order(edge_program: ExportedProgram) -> dict[str, Parameter]:
181+
"""
182+
Create mapping between node names and their dim-orders.
183+
184+
:param edge_program: EdgeProgram instance.
185+
:return: Mapping from node name to dim-order.
186+
"""
187+
188+
return {
189+
n.name: val.dim_order()
190+
for n in edge_program.graph.nodes
191+
if hasattr(val := n.meta.get("val", None), "dim_order")
192+
}
193+
177194
@staticmethod
178195
def build_conversion_context(
179196
parameters_mapping: dict,
197+
dim_order_map: dict[str, ...],
180198
neutron_target_spec: NeutronTargetSpec,
181199
conversion_config: ConversionConfig = _default_conversion_config,
182200
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
183201
) -> ConversionContext:
184202
tflite_builder = AtenModelBuilderDirector(
185-
3, "TFLite from EdgeProgram", neutron_target_spec, conversion_config
203+
3,
204+
"TFLite from EdgeProgram",
205+
neutron_target_spec,
206+
dim_order_map,
207+
conversion_config,
186208
)
187209

188210
# Add "sentinel" buffer (defined in schema.fbs)

backends/nxp/backend/ir/converter/builder/model_builder.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88

99
from copy import deepcopy
1010
from itertools import chain
11-
from typing import Dict, List, Optional, Union
11+
from typing import List, Optional, Union
1212

1313
import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator
1414
import executorch.backends.nxp.backend.ir.logger as logger
1515
import executorch.backends.nxp.backend.ir.tflite_generator.tflite_model as tflite_model
16-
1716
import numpy as np
17+
from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order
1818
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
1919
from executorch.backends.nxp.backend.ir.converter.builder import (
2020
quantization_verification,
@@ -65,23 +65,25 @@ class ModelBuilder:
6565

6666
_tfl_model: tflite_model.Model
6767

68-
_tensor_name_map: Dict # Mapping 'str' to 'tflT.Tensor'
68+
_tensor_name_map: dict # Mapping 'str' to 'tflT.Tensor'
6969

70-
# Maps BuiltinOperator to a Dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM'
70+
# Maps BuiltinOperator to a dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM'
7171
# have their 'version' prepended with its name, for example "FlexErf_1".
72-
op_code_type_index_map: Dict[BuiltinOperator, Dict[Union[str, int], int]]
72+
op_code_type_index_map: dict[BuiltinOperator, dict[Union[str, int], int]]
7373

74-
_nchw_tensor_version: Dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is
74+
_nchw_tensor_version: dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is
7575
# equal, but in NCHW format
7676

77-
_skipped_output_map: Dict # Mapping 'tflT.Tensor' objects that were outputs
77+
_skipped_output_map: dict # Mapping 'tflT.Tensor' objects that were outputs
7878
# of skipped operators, to 'tflT.Tensor' outputs of
7979
# previous operators
8080

81-
_zeros_tensor_map: Dict # Mapping 'string' shapes to 'tflT.Tensor' objects
81+
_zeros_tensor_map: dict # Mapping 'string' shapes to 'tflT.Tensor' objects
8282

8383
neutron_target_spec: NeutronTargetSpec
8484

85+
dim_order_map: dict # Mapping tensor names to their ExecuTorch `dim_order`.
86+
8587
conversion_config: ConversionConfig
8688

8789
_default_conversion_config = ConversionConfig()
@@ -91,11 +93,13 @@ def __init__(
9193
model_version: int,
9294
model_description: str,
9395
neutron_target_spec: NeutronTargetSpec,
96+
dim_order_map: dict[str, ...],
9497
conversion_config: ConversionConfig = _default_conversion_config,
9598
) -> None:
9699
self._tfl_model = tflite_model.Model(model_version, model_description)
97100
self.neutron_target_spec = neutron_target_spec
98101
self.conversion_config = conversion_config
102+
self.dim_order_map = dim_order_map
99103

100104
self.op_code_type_index_map = {}
101105
self._tensor_name_map = {}
@@ -358,6 +362,16 @@ def _make_inputs_channels_first(self):
358362
for input_tensor in self.get_sub_graph().inputs.tmp_inputs:
359363

360364
if input_tensor.tensor_format.is_channels_last():
365+
# The input must be permuted.
366+
367+
if is_channels_last_dim_order(
368+
self.dim_order_map.get(input_tensor.name, [])
369+
):
370+
# Do NOT insert a Transpose, as the input will already be provided in the channels last format
371+
# during runtime.
372+
new_inputs.append(input_tensor)
373+
continue
374+
361375
# Create a Transpose operator and replace the graph input
362376

363377
new_input_shape = translator.channels_last_shape_to_channels_first(
@@ -408,6 +422,16 @@ def _make_outputs_channels_first(self):
408422

409423
for output_tensor in self.get_sub_graph().outputs.tmp_outputs:
410424
if output_tensor.tensor_format.is_channels_last():
425+
# The output must be permuted.
426+
427+
if is_channels_last_dim_order(
428+
self.dim_order_map.get(output_tensor.name, [])
429+
):
430+
# Do NOT insert a Transpose, as the output will be required to be in the channels last format
431+
# during runtime.
432+
new_outputs.append(output_tensor)
433+
continue
434+
411435
# Add a Transpose operator, to make the output channels first
412436

413437
shape = output_tensor.shape.vector

backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Relu = exir_ops.edge.aten.relu.default
2121
Sigmoid = exir_ops.edge.aten.sigmoid.default
2222
Tanh = exir_ops.edge.aten.tanh.default
23+
CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default
2324

2425

2526
def insert_qdq_pair_after_node(
@@ -102,6 +103,9 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
102103
MM: [
103104
ViewCopy,
104105
],
106+
ViewCopy: [
107+
CloneDimOrder,
108+
],
105109
}
106110

107111
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class QDQCluster:
7979
exir_ops.edge.aten.relu.default,
8080
exir_ops.edge.aten.sigmoid.default,
8181
exir_ops.edge.aten.tanh.default,
82+
exir_ops.edge.dim_order_ops._clone_dim_order.default,
8283
]
8384

8485
def __init__(self):

0 commit comments

Comments
 (0)