@@ -1,5 +1,5 @@
 /*
- * Copyright 2024 NXP
+ * Copyright 2024-2025 NXP
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -10,6 +10,7 @@
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 
 #include "NeutronDriver.h"
 #include "NeutronErrors.h"
@@ -19,7 +20,6 @@ using namespace std;
 namespace torch {
 namespace executor {
 namespace neutron {
-
 // All the memory need to be aligned with 16
 #define BUFFER_ALIGNMENT 16
 #define ALIGN_SIZE(size) \
@@ -378,18 +378,45 @@ class NeutronBackend final : public PyTorchBackendInterface {
     // Transpose inputs if needed.
     for (int i = 0; i < cfg->numInputs; i++) {
       auto arg = args[cfg->inputMap[i]]->toTensor();
+      auto dim_order = arg.dim_order().data();
+
       if (cfg->inputTranspositionFlags[i] &&
           multipleChannelsPresent(arg.sizes())) {
+        // The input must be transposed.
         if (arg.sizes().size() < 3) {
           ET_LOG(Error, "Unable to transpose 1D and 2D input to channel last");
           return Error::InvalidProgram;
         }
-        // Allocate buffer, the allocator is reset after each PTE instruction.
-        void* buffer = context.allocate(arg.nbytes(), 16);
-        transposeInput(
-            arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size());
-        cfg->dcfg.inputs[i] = buffer;
+
+        if (is_channels_last_dim_order(dim_order, arg.dim())) {
+          // The tensor is already permuted.
+          ET_LOG(Info, "Using channels last dim order for input %d.\n", i);
+          cfg->dcfg.inputs[i] = arg.const_data_ptr();
+        } else if (is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Transpose the data to channels last.
+
+          ET_LOG(Info, "Transposing input %d to channels last.\n", i);
+
+          // Allocate buffer, the allocator is reset after each PTE instruction.
+          void* buffer = context.allocate(arg.nbytes(), 16);
+          transposeInput(
+              arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size());
+          cfg->dcfg.inputs[i] = buffer;
+        } else {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Input %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
       } else {
+        // The input matches the ExecuTorch format, so no transposition is
+        // needed.
+
+        if (!is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Input %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
+
         cfg->dcfg.inputs[i] = arg.const_data_ptr();
       }
     }
@@ -398,12 +425,35 @@ class NeutronBackend final : public PyTorchBackendInterface {
     // Redirect outputs if needed before transposition.
     for (int i = 0; i < cfg->numOutputs; i++) {
       auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
+      auto dim_order = arg.dim_order().data();
+
       if (cfg->outputTranspositionFlags[i] &&
           multipleChannelsPresent(arg.sizes())) {
-        // Allocate buffer, the allocator is reset after each PTE instruction.
-        void* buffer = context.allocate(arg.nbytes(), 16);
-        cfg->dcfg.outputs[i] = buffer;
+        // The output will have to be transposed.
+
+        if (is_channels_last_dim_order(dim_order, arg.dim())) {
+          // The tensor will already be correctly permuted. No transposition
+          // needed.
+          cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
+        } else if (is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Allocate buffer, the allocator is reset after each PTE instruction.
+          void* buffer = context.allocate(arg.nbytes(), 16);
+          cfg->dcfg.outputs[i] = buffer;
+        } else {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Output %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
       } else {
+        // The tensor should match the ExecuTorch required format, so no
+        // transposition is needed.
+
+        if (!is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Output %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
+
         cfg->dcfg.outputs[i] = arg.mutable_data_ptr();
       }
     }
@@ -427,18 +477,35 @@ class NeutronBackend final : public PyTorchBackendInterface {
     // Transpose outputs.
     for (int i = 0; i < cfg->numOutputs; i++) {
       auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor();
+
       if (cfg->outputTranspositionFlags[i] &&
           multipleChannelsPresent(arg.sizes())) {
+        // The output must be transposed.
+
         if (arg.sizes().size() < 3) {
           ET_LOG(
               Error, "Unable to transpose 1D and 2D output to channel first");
           return Error::InvalidProgram;
         }
-        transposeOutput(
-            cfg->dcfg.outputs[i],
-            arg.mutable_data_ptr(),
-            arg.sizes(),
-            arg.element_size());
+
+        auto dim_order = arg.dim_order().data();
+        if (is_channels_last_dim_order(dim_order, arg.dim())) {
+          // The rest of the model expects the `channels_last` dim order, which
+          // the data already matches.
+          ET_LOG(Info, "Using channels last dim order for output %d.\n", i);
+        } else if (is_contiguous_dim_order(dim_order, arg.dim())) {
+          // Transpose the data to channels first.
+          ET_LOG(Info, "Transposing output %d to channels first.\n", i);
+          transposeOutput(
+              cfg->dcfg.outputs[i],
+              arg.mutable_data_ptr(),
+              arg.sizes(),
+              arg.element_size());
+        } else {
+          // Unexpected dim-order.
+          ET_LOG(Error, "Output %d uses unsupported dim-order.", i);
+          return Error::InvalidProgram;
+        }
       }
     }
 
@@ -467,7 +534,6 @@ auto backend = NeutronBackend();
 Backend backend_id{"NeutronBackend", &backend};
 static auto registered = register_backend(backend_id);
 } // namespace
-
 } // namespace neutron
 } // namespace executor
-} // namespace torch
+} // namespace torch
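
For context on the dim-order checks this change introduces: ExecuTorch describes a tensor's memory layout as a permutation of its dimensions, and the backend now branches on that permutation for every NPU-bound input and output. The standalone sketch below mirrors my reading of that decision; the helper names (is_contiguous, is_channels_last, decide) and the hard-coded orders are illustrative only, not part of this commit or of the ExecuTorch API (the real checks live in dim_order_util.h).

// Standalone illustration (not from the commit): how a dim order decides
// whether the Neutron backend must transpose a tensor's data.
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Identity permutation {0, 1, ..., dims-1} => standard contiguous layout.
static bool is_contiguous(const std::uint8_t* dim_order, std::size_t dims) {
  for (std::size_t i = 0; i < dims; ++i) {
    if (static_cast<std::size_t>(dim_order[i]) != i) {
      return false;
    }
  }
  return true;
}

// Permutation {0, 2, 3, ..., 1} => channels-last layout (e.g. NHWC for 4-D).
static bool is_channels_last(const std::uint8_t* dim_order, std::size_t dims) {
  if (dims < 3 || dim_order[0] != 0 || dim_order[dims - 1] != 1) {
    return false;
  }
  for (std::size_t i = 1; i + 1 < dims; ++i) {
    if (static_cast<std::size_t>(dim_order[i]) != i + 1) {
      return false;
    }
  }
  return true;
}

enum class Action { kPassThrough = 0, kTranspose = 1, kReject = 2 };

// Mirrors the per-tensor decision in NeutronBackend::execute(): channels-last
// data is used as-is, contiguous data is transposed into a scratch buffer,
// and any other permutation is reported as an error.
static Action decide(const std::uint8_t* dim_order, std::size_t dims) {
  if (is_channels_last(dim_order, dims)) {
    return Action::kPassThrough;
  }
  if (is_contiguous(dim_order, dims)) {
    return Action::kTranspose;
  }
  return Action::kReject;
}

int main() {
  const std::uint8_t nchw[] = {0, 1, 2, 3}; // contiguous -> needs transpose
  const std::uint8_t nhwc[] = {0, 2, 3, 1}; // channels last -> pass through
  const std::uint8_t odd[] = {1, 0, 2, 3};  // anything else -> rejected
  std::printf(
      "nchw=%d nhwc=%d odd=%d\n",
      static_cast<int>(decide(nchw, 4)),
      static_cast<int>(decide(nhwc, 4)),
      static_cast<int>(decide(odd, 4)));
  return 0;
}

In short, a channels-last tensor is handed to the Neutron driver unchanged, a contiguous one is transposed into a scratch buffer allocated from the backend context, and any other permutation aborts execution with Error::InvalidProgram.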