diff --git a/backends/nxp/_passes/remove_getitem_pass.py b/backends/nxp/_passes/remove_getitem_pass.py
index e2759e25e36..316cc13f49c 100644
--- a/backends/nxp/_passes/remove_getitem_pass.py
+++ b/backends/nxp/_passes/remove_getitem_pass.py
@@ -7,10 +7,7 @@
 
 import torch
 
-from executorch.backends.nxp.backend.node_format_inference import (
-    NodeFormat,
-    NXP_NODE_FORMAT,
-)
+from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
diff --git a/backends/nxp/backend/edge_helper.py b/backends/nxp/backend/edge_helper.py
index 1e80d272bad..d78997ea4a6 100644
--- a/backends/nxp/backend/edge_helper.py
+++ b/backends/nxp/backend/edge_helper.py
@@ -184,3 +184,10 @@ def get_non_qdq_users(node: Node) -> list[Node]:
             res.extend(list(dequant_node.users))
 
     return res
+
+
+def is_channels_last_dim_order(dim_order: list[int]) -> bool:
+    if len(dim_order) < 3:
+        return False
+
+    return list(dim_order) == [0] + list(range(2, len(dim_order))) + [1]
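For reference, the new `is_channels_last_dim_order` helper recognizes exactly the permutation that moves the channel dimension (index 1) to the end. A quick illustrative check (not part of the change itself):

```python
from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order

# Rank 4 (NCHW): the channels last permutation is (0, 2, 3, 1).
assert is_channels_last_dim_order([0, 2, 3, 1])

# Rank 5 (NCDHW): the channels last 3d permutation is (0, 2, 3, 4, 1).
assert is_channels_last_dim_order([0, 2, 3, 4, 1])

# Contiguous orders and ranks below 3 are rejected.
assert not is_channels_last_dim_order([0, 1, 2, 3])
assert not is_channels_last_dim_order([0, 1])
```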
+ """ + + return { + n.name: val.dim_order() + for n in edge_program.graph.nodes + if hasattr(val := n.meta.get("val", None), "dim_order") + } + @staticmethod def build_conversion_context( parameters_mapping: dict, + dim_order_map: dict[str, ...], neutron_target_spec: NeutronTargetSpec, conversion_config: ConversionConfig = _default_conversion_config, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, ) -> ConversionContext: tflite_builder = AtenModelBuilderDirector( - 3, "TFLite from EdgeProgram", neutron_target_spec, conversion_config + 3, + "TFLite from EdgeProgram", + neutron_target_spec, + dim_order_map, + conversion_config, ) # Add "sentinel" buffer (defined in schema.fbs) diff --git a/backends/nxp/backend/ir/converter/builder/model_builder.py b/backends/nxp/backend/ir/converter/builder/model_builder.py index cfd80d8e300..87b1e55bcf9 100755 --- a/backends/nxp/backend/ir/converter/builder/model_builder.py +++ b/backends/nxp/backend/ir/converter/builder/model_builder.py @@ -8,13 +8,13 @@ from copy import deepcopy from itertools import chain -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator import executorch.backends.nxp.backend.ir.logger as logger import executorch.backends.nxp.backend.ir.tflite_generator.tflite_model as tflite_model - import numpy as np +from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder import ( quantization_verification, @@ -65,23 +65,25 @@ class ModelBuilder: _tfl_model: tflite_model.Model - _tensor_name_map: Dict # Mapping 'str' to 'tflT.Tensor' + _tensor_name_map: dict # Mapping 'str' to 'tflT.Tensor' - # Maps BuiltinOperator to a Dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM' + # Maps BuiltinOperator to a dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM' # have their 'version' prepended with its name, for example "FlexErf_1". - op_code_type_index_map: Dict[BuiltinOperator, Dict[Union[str, int], int]] + op_code_type_index_map: dict[BuiltinOperator, dict[Union[str, int], int]] - _nchw_tensor_version: Dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is + _nchw_tensor_version: dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is # equal, but in NCHW format - _skipped_output_map: Dict # Mapping 'tflT.Tensor' objects that were outputs + _skipped_output_map: dict # Mapping 'tflT.Tensor' objects that were outputs # of skipped operators, to 'tflT.Tensor' outputs of # previous operators - _zeros_tensor_map: Dict # Mapping 'string' shapes to 'tflT.Tensor' objects + _zeros_tensor_map: dict # Mapping 'string' shapes to 'tflT.Tensor' objects neutron_target_spec: NeutronTargetSpec + dim_order_map: dict # Mapping tensor names to their ExecuTorch `dim_order`. 
diff --git a/backends/nxp/backend/ir/converter/builder/model_builder.py b/backends/nxp/backend/ir/converter/builder/model_builder.py
index cfd80d8e300..87b1e55bcf9 100755
--- a/backends/nxp/backend/ir/converter/builder/model_builder.py
+++ b/backends/nxp/backend/ir/converter/builder/model_builder.py
@@ -8,13 +8,13 @@
 
 from copy import deepcopy
 from itertools import chain
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union
 
 import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator
 import executorch.backends.nxp.backend.ir.logger as logger
 import executorch.backends.nxp.backend.ir.tflite_generator.tflite_model as tflite_model
-
 import numpy as np
+from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order
 from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
 from executorch.backends.nxp.backend.ir.converter.builder import (
     quantization_verification,
@@ -65,23 +65,25 @@ class ModelBuilder:
     _tfl_model: tflite_model.Model
 
-    _tensor_name_map: Dict  # Mapping 'str' to 'tflT.Tensor'
+    _tensor_name_map: dict  # Mapping 'str' to 'tflT.Tensor'
 
-    # Maps BuiltinOperator to a Dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM'
+    # Maps BuiltinOperator to a dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM'
     # have their 'version' prepended with its name, for example "FlexErf_1".
-    op_code_type_index_map: Dict[BuiltinOperator, Dict[Union[str, int], int]]
+    op_code_type_index_map: dict[BuiltinOperator, dict[Union[str, int], int]]
 
-    _nchw_tensor_version: Dict  # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is
+    _nchw_tensor_version: dict  # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is
     # equal, but in NCHW format
 
-    _skipped_output_map: Dict  # Mapping 'tflT.Tensor' objects that were outputs
+    _skipped_output_map: dict  # Mapping 'tflT.Tensor' objects that were outputs
     # of skipped operators, to 'tflT.Tensor' outputs of
     # previous operators
 
-    _zeros_tensor_map: Dict  # Mapping 'string' shapes to 'tflT.Tensor' objects
+    _zeros_tensor_map: dict  # Mapping 'string' shapes to 'tflT.Tensor' objects
 
     neutron_target_spec: NeutronTargetSpec
 
+    dim_order_map: dict  # Mapping tensor names to their ExecuTorch `dim_order`.
+
     conversion_config: ConversionConfig
 
     _default_conversion_config = ConversionConfig()
@@ -91,11 +93,13 @@ def __init__(
         self,
         model_version: int,
         model_description: str,
         neutron_target_spec: NeutronTargetSpec,
+        dim_order_map: dict[str, tuple[int, ...]],
         conversion_config: ConversionConfig = _default_conversion_config,
     ) -> None:
         self._tfl_model = tflite_model.Model(model_version, model_description)
         self.neutron_target_spec = neutron_target_spec
         self.conversion_config = conversion_config
+        self.dim_order_map = dim_order_map
         self.op_code_type_index_map = {}
         self._tensor_name_map = {}
@@ -358,6 +362,16 @@ def _make_inputs_channels_first(self):
 
         for input_tensor in self.get_sub_graph().inputs.tmp_inputs:
             if input_tensor.tensor_format.is_channels_last():
+                # The input must be permuted.
+
+                if is_channels_last_dim_order(
+                    self.dim_order_map.get(input_tensor.name, [])
+                ):
+                    # Do NOT insert a Transpose, as the input will already be provided in the channels
+                    # last format at runtime.
+                    new_inputs.append(input_tensor)
+                    continue
+
                 # Create a Transpose operator and replace the graph input
 
                 new_input_shape = translator.channels_last_shape_to_channels_first(
@@ -408,6 +422,16 @@ def _make_outputs_channels_first(self):
 
         for output_tensor in self.get_sub_graph().outputs.tmp_outputs:
             if output_tensor.tensor_format.is_channels_last():
+                # The output must be permuted.
+
+                if is_channels_last_dim_order(
+                    self.dim_order_map.get(output_tensor.name, [])
+                ):
+                    # Do NOT insert a Transpose, as the output will be required in the channels last
+                    # format at runtime.
+                    new_outputs.append(output_tensor)
+                    continue
+
                 # Add a Transpose operator, to make the output channels first
                 shape = output_tensor.shape.vector
 
diff --git a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py
index f32e09e78e0..14c4890a202 100644
--- a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py
+++ b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py
@@ -20,6 +20,7 @@
 Relu = exir_ops.edge.aten.relu.default
 Sigmoid = exir_ops.edge.aten.sigmoid.default
 Tanh = exir_ops.edge.aten.tanh.default
+CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default
 
 
 def insert_qdq_pair_after_node(
@@ -102,6 +103,9 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
         MM: [
             ViewCopy,
         ],
+        ViewCopy: [
+            CloneDimOrder,
+        ],
     }
 
     def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py
index c0a27adfc7c..bb13a2fecac 100644
--- a/backends/nxp/neutron_partitioner.py
+++ b/backends/nxp/neutron_partitioner.py
@@ -79,6 +79,7 @@ class QDQCluster:
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.tanh.default,
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
     ]
 
     def __init__(self):
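On the builder side, the change boils down to a small predicate: a Transpose is only inserted for a channels-last graph tensor when the runtime will hand over contiguous data. A hypothetical distillation (the helper name is invented for illustration):

```python
from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order


def needs_io_transpose(tensor_name: str, dim_order_map: dict) -> bool:
    """Hypothetical summary of `_make_inputs_channels_first` and
    `_make_outputs_channels_first`: skip the Transpose when the tensor's
    recorded dim order is already channels last."""
    return not is_channels_last_dim_order(dim_order_map.get(tensor_name, []))
```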
diff --git a/backends/nxp/runtime/NeutronBackend.cpp b/backends/nxp/runtime/NeutronBackend.cpp
index a5c208738f3..4bf23324ef5 100644
--- a/backends/nxp/runtime/NeutronBackend.cpp
+++ b/backends/nxp/runtime/NeutronBackend.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright 2024 NXP
+ * Copyright 2024-2025 NXP
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 
 #include "NeutronDriver.h"
 #include "NeutronErrors.h"
@@ -19,7 +20,6 @@ using namespace std;
 namespace torch {
 namespace executor {
 namespace neutron {
-
 // All the memory need to be aligned with 16
 #define BUFFER_ALIGNMENT 16
 #define ALIGN_SIZE(size) \
@@ -378,18 +378,45 @@ class NeutronBackend final : public PyTorchBackendInterface {
     // Transpose inputs if needed.
     for (int i = 0; i < cfg->numInputs; i++) {
       auto arg = args[cfg->inputMap[i]]->toTensor();
+      auto dim_order = arg.dim_order().data();
+
       if (cfg->inputTranspositionFlags[i] &&
           multipleChannelsPresent(arg.sizes())) {
+        // The input must be transposed.
         if (arg.sizes().size() < 3) {
           ET_LOG(Error, "Unable to transpose 1D and 2D input to channel last");
           return Error::InvalidProgram;
         }
-        // Allocate buffer, the allocator is reset after each PTE instruction.
-        void* buffer = context.allocate(arg.nbytes(), 16);
-        transposeInput(
-            arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size());
-        cfg->dcfg.inputs[i] = buffer;
+
+        if (is_channels_last_dim_order(dim_order, arg.dim())) {
+          // The tensor is already permuted.
+          ET_LOG(Info, "Using channels last dim order for input %d.\n", i);
+          cfg->dcfg.inputs[i] = arg.const_data_ptr();
+        } else if (is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Transpose the data to channels last.
+          ET_LOG(Info, "Transposing input %d to channels last.\n", i);
+
+          // Allocate buffer, the allocator is reset after each PTE instruction.
+          void* buffer = context.allocate(arg.nbytes(), 16);
+          transposeInput(
+              arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size());
+          cfg->dcfg.inputs[i] = buffer;
+        } else {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Input %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
       } else {
+        // The input matches the ExecuTorch format, so no transposition is
+        // needed.
+        if (!is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Input %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
+
         cfg->dcfg.inputs[i] = arg.const_data_ptr();
       }
     }
@@ -398,12 +425,35 @@
     // Redirect outputs if needed before transposition.
     for (int i = 0; i < cfg->numOutputs; i++) {
       auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
+      auto dim_order = arg.dim_order().data();
+
       if (cfg->outputTranspositionFlags[i] &&
           multipleChannelsPresent(arg.sizes())) {
-        // Allocate buffer, the allocator is reset after each PTE instruction.
-        void* buffer = context.allocate(arg.nbytes(), 16);
-        cfg->dcfg.outputs[i] = buffer;
+        // The output will have to be transposed.
+        if (is_channels_last_dim_order(dim_order, arg.dim())) {
+          // The tensor will already be correctly permuted. No transposition
+          // needed.
+          cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
+        } else if (is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Allocate buffer, the allocator is reset after each PTE instruction.
+          void* buffer = context.allocate(arg.nbytes(), 16);
+          cfg->dcfg.outputs[i] = buffer;
+        } else {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Output %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
       } else {
+        // The tensor should match the ExecuTorch required format, so no
+        // transposition is needed.
+        if (!is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Output %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
+
         cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
       }
     }
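The runtime checks above use `is_channels_last_dim_order` / `is_contiguous_dim_order`, which ExecuTorch declares in `runtime/core/exec_aten/util/dim_order_util.h` (presumably the header the new `#include` pulls in). A loose Python mirror of the three-way classification (illustrative sketch only):

```python
def classify_dim_order(dim_order: tuple[int, ...]) -> str:
    """Hypothetical mirror of the runtime branching: contiguous tensors are
    transposed on the way in/out, channels last tensors are passed through,
    and anything else makes execute() return Error::InvalidProgram."""
    rank = len(dim_order)
    if dim_order == tuple(range(rank)):
        return "contiguous"
    if rank >= 3 and dim_order == (0, *range(2, rank), 1):
        return "channels_last"
    return "unsupported"
```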
@@ -427,18 +477,35 @@
     // Transpose outputs.
     for (int i = 0; i < cfg->numOutputs; i++) {
       auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
+
       if (cfg->outputTranspositionFlags[i] &&
           multipleChannelsPresent(arg.sizes())) {
+        // The output must be transposed.
+
         if (arg.sizes().size() < 3) {
           ET_LOG(
               Error, "Unable to transpose 1D and 2D output to channel first");
           return Error::InvalidProgram;
         }
-        transposeOutput(
-            cfg->dcfg.outputs[i],
-            arg.mutable_data_ptr(),
-            arg.sizes(),
-            arg.element_size());
+
+        auto dim_order = arg.dim_order().data();
+        if (is_channels_last_dim_order(dim_order, arg.dim())) {
+          // The rest of the model expects the `channels_last` dim order, which
+          // the data already matches.
+          ET_LOG(Info, "Using channels last dim order for output %d.\n", i);
+        } else if (is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Transpose the data to channels first.
+          ET_LOG(Info, "Transposing output %d to channels first.\n", i);
+          transposeOutput(
+              cfg->dcfg.outputs[i],
+              arg.mutable_data_ptr(),
+              arg.sizes(),
+              arg.element_size());
+        } else {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Output %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
       }
     }
@@ -467,7 +534,6 @@ auto backend = NeutronBackend();
 Backend backend_id{"NeutronBackend", &backend};
 static auto registered = register_backend(backend_id);
 } // namespace
-
 } // namespace neutron
 } // namespace executor
-} // namespace torch
+} // namespace torch
\ No newline at end of file
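Taken together, these changes let a delegate whose inputs and outputs already use the channels last dim order skip the transpose copies at runtime. A sketch of the user-visible effect, assuming the exported model keeps channels last at its boundary (the tensor setup below is illustrative, not from this diff):

```python
import torch

# A tensor handed to the Neutron delegate in channels last dim order is now
# consumed directly; previously a scratch buffer was allocated and the data
# transposed on every execute().
x = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)
assert x.dim_order() == (0, 2, 3, 1)
```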