5 changes: 1 addition & 4 deletions backends/nxp/_passes/remove_getitem_pass.py
@@ -7,10 +7,7 @@

import torch

from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NXP_NODE_FORMAT,
)
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

7 changes: 7 additions & 0 deletions backends/nxp/backend/edge_helper.py
@@ -184,3 +184,10 @@ def get_non_qdq_users(node: Node) -> list[Node]:
res.extend(list(dequant_node.users))

return res


def is_channels_last_dim_order(dim_order: list[int]) -> bool:
if len(dim_order) < 3:
return False

return list(dim_order) == [0] + list(range(2, len(dim_order))) + [1]
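
For reference (not part of the change set), a minimal sketch of how the new helper behaves on a few concrete dim-orders, assuming the usual ExecuTorch convention where a dim-order is a permutation of 0..rank-1:

# Illustrative sketch only, not part of the diff.
from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order

assert is_channels_last_dim_order([0, 2, 3, 1])      # 4D channels last (NHWC-style)
assert is_channels_last_dim_order([0, 2, 3, 4, 1])   # 5D channels last (NDHWC-style)
assert not is_channels_last_dim_order([0, 1, 2, 3])  # contiguous / channels first
assert not is_channels_last_dim_order([0, 1])        # rank < 3 is never channels last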
24 changes: 23 additions & 1 deletion backends/nxp/backend/edge_program_converter.py
@@ -74,9 +74,11 @@ def convert_program(
:return: TFLite flatbuffers as bytes.
"""
parameters_mapping = self.map_inputs_to_parameters(edge_program)
dim_order_map = self.map_nodes_to_dim_order(edge_program)

cc = self.build_conversion_context(
parameters_mapping,
dim_order_map,
neutron_target_spec,
conversion_config,
custom_delegation_options,
@@ -174,15 +176,35 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]:

return result_map

@staticmethod
def map_nodes_to_dim_order(edge_program: ExportedProgram) -> dict[str, tuple[int, ...]]:
"""
Create a mapping between node names and their dim-orders.

:param edge_program: EdgeProgram instance.
:return: Mapping from node name to dim-order.
"""

return {
n.name: val.dim_order()
for n in edge_program.graph.nodes
if hasattr(val := n.meta.get("val", None), "dim_order")
}

@staticmethod
def build_conversion_context(
parameters_mapping: dict,
dim_order_map: dict[str, ...],
neutron_target_spec: NeutronTargetSpec,
conversion_config: ConversionConfig = _default_conversion_config,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
) -> ConversionContext:
tflite_builder = AtenModelBuilderDirector(
3, "TFLite from EdgeProgram", neutron_target_spec, conversion_config
3,
"TFLite from EdgeProgram",
neutron_target_spec,
dim_order_map,
conversion_config,
)

# Add "sentinel" buffer (defined in schema.fbs)
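
To illustrate what the new map contains (hypothetical node names, not taken from the diff), each entry pairs a node's name with the dim-order of its `val` meta tensor:

# Hypothetical example of a dim_order_map produced by map_nodes_to_dim_order;
# the node names are made up, the values come from each node's val.dim_order().
dim_order_map = {
    "x": (0, 2, 3, 1),                         # 4D graph input exported in channels last
    "aten_convolution_default": (0, 1, 2, 3),  # contiguous intermediate tensor
}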
40 changes: 32 additions & 8 deletions backends/nxp/backend/ir/converter/builder/model_builder.py
@@ -8,13 +8,13 @@

from copy import deepcopy
from itertools import chain
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union

import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator
import executorch.backends.nxp.backend.ir.logger as logger
import executorch.backends.nxp.backend.ir.tflite_generator.tflite_model as tflite_model

import numpy as np
from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
from executorch.backends.nxp.backend.ir.converter.builder import (
quantization_verification,
@@ -65,23 +65,25 @@ class ModelBuilder:

_tfl_model: tflite_model.Model

_tensor_name_map: Dict # Mapping 'str' to 'tflT.Tensor'
_tensor_name_map: dict # Mapping 'str' to 'tflT.Tensor'

# Maps BuiltinOperator to a Dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM'
# Maps BuiltinOperator to a dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM'
# have their 'version' prepended with its name, for example "FlexErf_1".
op_code_type_index_map: Dict[BuiltinOperator, Dict[Union[str, int], int]]
op_code_type_index_map: dict[BuiltinOperator, dict[Union[str, int], int]]

_nchw_tensor_version: Dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is
_nchw_tensor_version: dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is
# equal, but in NCHW format

_skipped_output_map: Dict # Mapping 'tflT.Tensor' objects that were outputs
_skipped_output_map: dict # Mapping 'tflT.Tensor' objects that were outputs
# of skipped operators, to 'tflT.Tensor' outputs of
# previous operators

_zeros_tensor_map: Dict # Mapping 'string' shapes to 'tflT.Tensor' objects
_zeros_tensor_map: dict # Mapping 'string' shapes to 'tflT.Tensor' objects

neutron_target_spec: NeutronTargetSpec

dim_order_map: dict # Mapping tensor names to their ExecuTorch `dim_order`.

conversion_config: ConversionConfig

_default_conversion_config = ConversionConfig()
@@ -91,11 +93,13 @@ def __init__(
model_version: int,
model_description: str,
neutron_target_spec: NeutronTargetSpec,
dim_order_map: dict[str, ...],
conversion_config: ConversionConfig = _default_conversion_config,
) -> None:
self._tfl_model = tflite_model.Model(model_version, model_description)
self.neutron_target_spec = neutron_target_spec
self.conversion_config = conversion_config
self.dim_order_map = dim_order_map

self.op_code_type_index_map = {}
self._tensor_name_map = {}
@@ -358,6 +362,16 @@ def _make_inputs_channels_first(self):
for input_tensor in self.get_sub_graph().inputs.tmp_inputs:

if input_tensor.tensor_format.is_channels_last():
# The input must be permuted.

if is_channels_last_dim_order(
self.dim_order_map.get(input_tensor.name, [])
):
# Do NOT insert a Transpose, as the input will already be provided in the channels last format
# during runtime.
new_inputs.append(input_tensor)
continue

# Create a Transpose operator and replace the graph input

new_input_shape = translator.channels_last_shape_to_channels_first(
@@ -408,6 +422,16 @@ def _make_outputs_channels_first(self):

for output_tensor in self.get_sub_graph().outputs.tmp_outputs:
if output_tensor.tensor_format.is_channels_last():
# The output must be permuted.

if is_channels_last_dim_order(
self.dim_order_map.get(output_tensor.name, [])
):
# Do NOT insert a Transpose, as the output will be required to be in the channels last format
# during runtime.
new_outputs.append(output_tensor)
continue

# Add a Transpose operator, to make the output channels first

shape = output_tensor.shape.vector
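
A condensed sketch (not part of the diff) of the decision the builder now makes for a channels-last graph input or output: a Transpose is inserted only when ExecuTorch hands over the data in the contiguous dim-order. `is_channels_last_dim_order` and `dim_order_map` are the objects introduced above; the function below is a hypothetical reduction of the real graph-rewriting code.

from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order

def transpose_needed(tensor_name: str, dim_order_map: dict) -> bool:
    # If ExecuTorch already provides (or expects) the tensor in the channels last
    # memory layout, the TFLite graph can consume it directly and no Transpose
    # operator has to be inserted; otherwise the builder must permute the data.
    return not is_channels_last_dim_order(dim_order_map.get(tensor_name, []))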
@@ -20,6 +20,7 @@
Relu = exir_ops.edge.aten.relu.default
Sigmoid = exir_ops.edge.aten.sigmoid.default
Tanh = exir_ops.edge.aten.tanh.default
CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default


def insert_qdq_pair_after_node(
@@ -102,6 +103,9 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
MM: [
ViewCopy,
],
ViewCopy: [
CloneDimOrder,
],
}

def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
1 change: 1 addition & 0 deletions backends/nxp/neutron_partitioner.py
@@ -79,6 +79,7 @@ class QDQCluster:
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
]

def __init__(self):
100 changes: 83 additions & 17 deletions backends/nxp/runtime/NeutronBackend.cpp
@@ -1,5 +1,5 @@
/*
* Copyright 2024 NXP
* Copyright 2024-2025 NXP
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
@@ -10,6 +10,7 @@
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>

#include "NeutronDriver.h"
#include "NeutronErrors.h"
@@ -19,7 +20,6 @@ using namespace std;
namespace torch {
namespace executor {
namespace neutron {

// All the memory need to be aligned with 16
#define BUFFER_ALIGNMENT 16
#define ALIGN_SIZE(size) \
@@ -378,18 +378,45 @@ class NeutronBackend final : public PyTorchBackendInterface {
// Transpose inputs if needed.
for (int i = 0; i < cfg->numInputs; i++) {
auto arg = args[cfg->inputMap[i]]->toTensor();
auto dim_order = arg.dim_order().data();

if (cfg->inputTranspositionFlags[i] &&
multipleChannelsPresent(arg.sizes())) {
// The input must be transposed.
if (arg.sizes().size() < 3) {
ET_LOG(Error, "Unable to transpose 1D and 2D input to channel last");
return Error::InvalidProgram;
}
// Allocate buffer, the allocator is reset after each PTE instruction.
void* buffer = context.allocate(arg.nbytes(), 16);
transposeInput(
arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size());
cfg->dcfg.inputs[i] = buffer;

if (is_channels_last_dim_order(dim_order, arg.dim())) {
// The tensor is already permuted.
ET_LOG(Info, "Using channels last dim order for input %d.\n", i);
cfg->dcfg.inputs[i] = arg.const_data_ptr();
} else if (is_contiguous_dim_order(dim_order, arg.dim())) {
// Transpose the data to channels last.

ET_LOG(Info, "Transposing input %d to channels last.\n", i);

// Allocate buffer, the allocator is reset after each PTE instruction.
void* buffer = context.allocate(arg.nbytes(), 16);
transposeInput(
arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size());
cfg->dcfg.inputs[i] = buffer;
} else {
// Unexpected dim-order.
ET_LOG(Error, "Input %d uses unsupported dim-order.", i);
return Error::InvalidProgram;
}
} else {
// The input matches the ExecuTorch format, so no transposition is
// needed.

if (!is_contiguous_dim_order(dim_order, arg.dim())) {
// Unexpected dim-order.
ET_LOG(Error, "Input %d uses unsupported dim-order.", i);
return Error::InvalidProgram;
}

cfg->dcfg.inputs[i] = arg.const_data_ptr();
}
}
@@ -398,12 +425,35 @@
// Redirect outputs if needed before transposition.
for (int i = 0; i < cfg->numOutputs; i++) {
auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
auto dim_order = arg.dim_order().data();

if (cfg->outputTranspositionFlags[i] &&
multipleChannelsPresent(arg.sizes())) {
// Allocate buffer, the allocator is reset after each PTE instruction.
void* buffer = context.allocate(arg.nbytes(), 16);
cfg->dcfg.outputs[i] = buffer;
// The output will have to be transposed.

if (is_channels_last_dim_order(dim_order, arg.dim())) {
// The tensor will already be correctly permuted. No transposition
// needed.
cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
} else if (is_contiguous_dim_order(dim_order, arg.dim())) {
// Allocate buffer, the allocator is reset after each PTE instruction.
void* buffer = context.allocate(arg.nbytes(), 16);
cfg->dcfg.outputs[i] = buffer;
} else {
// Unexpected dim-order.
ET_LOG(Error, "Output %d uses unsupported dim-order.", i);
return Error::InvalidProgram;
}
} else {
// The tensor should match the ExecuTorch required format, so no
// transposition is needed.

if (!is_contiguous_dim_order(dim_order, arg.dim())) {
// Unexpected dim-order.
ET_LOG(Error, "Output %d uses unsupported dim-order.", i);
return Error::InvalidProgram;
}

cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
}
}
@@ -427,18 +477,35 @@
// Transpose outputs.
for (int i = 0; i < cfg->numOutputs; i++) {
auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();

if (cfg->outputTranspositionFlags[i] &&
multipleChannelsPresent(arg.sizes())) {
// The output must be transposed.

if (arg.sizes().size() < 3) {
ET_LOG(
Error, "Unable to transpose 1D and 2D output to channel first");
return Error::InvalidProgram;
}
transposeOutput(
cfg->dcfg.outputs[i],
arg.mutable_data_ptr(),
arg.sizes(),
arg.element_size());

auto dim_order = arg.dim_order().data();
if (is_channels_last_dim_order(dim_order, arg.dim())) {
// The rest of the model expects the `channels_last` dim order, which
// the data already matches.
ET_LOG(Info, "Using channels last dim order for output %d.\n", i);
} else if (is_contiguous_dim_order(dim_order, arg.dim())) {
// Transpose the data to channels first.
ET_LOG(Info, "Transposing output %d to channels first.\n", i);
transposeOutput(
cfg->dcfg.outputs[i],
arg.mutable_data_ptr(),
arg.sizes(),
arg.element_size());
} else {
// Unexpected dim-order.
ET_LOG(Error, "Output %d uses unsupported dim-order.", i);
return Error::InvalidProgram;
}
}
}

@@ -467,7 +534,6 @@ auto backend = NeutronBackend();
Backend backend_id{"NeutronBackend", &backend};
static auto registered = register_backend(backend_id);
} // namespace

} // namespace neutron
} // namespace executor
} // namespace torch