88
99from copy import deepcopy
1010from itertools import chain
11- from typing import Dict , List , Optional , Union
11+ from typing import List , Optional , Union
1212
1313import executorch .backends .nxp .backend .ir .converter .conversion .translator as translator
1414import executorch .backends .nxp .backend .ir .logger as logger
1515import executorch .backends .nxp .backend .ir .tflite_generator .tflite_model as tflite_model
16-
1716import numpy as np
17+ from executorch .backends .nxp .backend .edge_helper import is_channels_last_dim_order
1818from executorch .backends .nxp .backend .ir .conversion_config import ConversionConfig
1919from executorch .backends .nxp .backend .ir .converter .builder import (
2020 quantization_verification ,
@@ -65,23 +65,25 @@ class ModelBuilder:
6565
6666 _tfl_model : tflite_model .Model
6767
68- _tensor_name_map : Dict # Mapping 'str' to 'tflT.Tensor'
68+ _tensor_name_map : dict # Mapping 'str' to 'tflT.Tensor'
6969
70- # Maps BuiltinOperator to a Dict , mapping version to index. Operators of type 'BuiltinOperator.CUSTOM'
70+ # Maps BuiltinOperator to a dict , mapping version to index. Operators of type 'BuiltinOperator.CUSTOM'
7171 # have their 'version' prepended with its name, for example "FlexErf_1".
72- op_code_type_index_map : Dict [BuiltinOperator , Dict [Union [str , int ], int ]]
72+ op_code_type_index_map : dict [BuiltinOperator , dict [Union [str , int ], int ]]
7373
74- _nchw_tensor_version : Dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is
74+ _nchw_tensor_version : dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is
7575 # equal, but in NCHW format
7676
77- _skipped_output_map : Dict # Mapping 'tflT.Tensor' objects that were outputs
77+ _skipped_output_map : dict # Mapping 'tflT.Tensor' objects that were outputs
7878 # of skipped operators, to 'tflT.Tensor' outputs of
7979 # previous operators
8080
81- _zeros_tensor_map : Dict # Mapping 'string' shapes to 'tflT.Tensor' objects
81+ _zeros_tensor_map : dict # Mapping 'string' shapes to 'tflT.Tensor' objects
8282
8383 neutron_target_spec : NeutronTargetSpec
8484
85+ dim_order_map : dict # Mapping tensor names to their ExecuTorch `dim_order`.
86+
8587 conversion_config : ConversionConfig
8688
8789 _default_conversion_config = ConversionConfig ()
@@ -91,11 +93,13 @@ def __init__(
9193 model_version : int ,
9294 model_description : str ,
9395 neutron_target_spec : NeutronTargetSpec ,
96+ dim_order_map : dict [str , ...],
9497 conversion_config : ConversionConfig = _default_conversion_config ,
9598 ) -> None :
9699 self ._tfl_model = tflite_model .Model (model_version , model_description )
97100 self .neutron_target_spec = neutron_target_spec
98101 self .conversion_config = conversion_config
102+ self .dim_order_map = dim_order_map
99103
100104 self .op_code_type_index_map = {}
101105 self ._tensor_name_map = {}
@@ -358,6 +362,16 @@ def _make_inputs_channels_first(self):
358362 for input_tensor in self .get_sub_graph ().inputs .tmp_inputs :
359363
360364 if input_tensor .tensor_format .is_channels_last ():
365+ # The input must be permuted.
366+
367+ if is_channels_last_dim_order (
368+ self .dim_order_map .get (input_tensor .name , [])
369+ ):
370+ # Do NOT insert a Transpose, as the input will already be provided in the channels last format
371+ # during runtime.
372+ new_inputs .append (input_tensor )
373+ continue
374+
361375 # Create a Transpose operator and replace the graph input
362376
363377 new_input_shape = translator .channels_last_shape_to_channels_first (
@@ -408,6 +422,16 @@ def _make_outputs_channels_first(self):
408422
409423 for output_tensor in self .get_sub_graph ().outputs .tmp_outputs :
410424 if output_tensor .tensor_format .is_channels_last ():
425+ # The output must be permuted.
426+
427+ if is_channels_last_dim_order (
428+ self .dim_order_map .get (output_tensor .name , [])
429+ ):
430+ # Do NOT insert a Transpose, as the output will be required to be in the channels last format
431+ # during runtime.
432+ new_outputs .append (output_tensor )
433+ continue
434+
411435 # Add a Transpose operator, to make the output channels first
412436
413437 shape = output_tensor .shape .vector
0 commit comments