From 6379fd69009944cfe04516cd2170026ee524bc4c Mon Sep 17 00:00:00 2001
From: Xingguo Li
Date: Mon, 24 Nov 2025 21:16:41 +0000
Subject: [PATCH] Add softmax support for int8 in Cortex-M (dim=-1)

- Integrate the CMSIS-NN softmax kernel into the Cortex-M backend
- Add fusion pass and tests for quantized softmax
- Run lint cleanup passes
- Resolve merge conflicts

Change-Id: I0ec19f011069fa1482e2de2ab62b9e7d7f56b2a8
Signed-off-by: Xingguo Li
---
 backends/cortex_m/CMakeLists.txt              |   1 +
 backends/cortex_m/ops/op_softmax.cpp          | 148 ++++++++++++++++++
 backends/cortex_m/ops/operators.py            |  66 ++++++++
 backends/cortex_m/ops/operators.yaml          |   6 +
 .../passes/quantized_op_fusion_pass.py        |  87 +++++++++-
 .../cortex_m/quantizer/operator_configs.py    |  11 ++
 .../quantizer/quantization_configs.py         |  24 ++-
 backends/cortex_m/quantizer/quantizer.py      |  78 ++++++++-
 backends/cortex_m/test/ops/test_softmax.py    |  80 ++++++++++
 backends/cortex_m/test/tester.py              |   3 +-
 10 files changed, 500 insertions(+), 4 deletions(-)
 create mode 100644 backends/cortex_m/ops/op_softmax.cpp
 create mode 100644 backends/cortex_m/test/ops/test_softmax.py

diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt
index ac330d4b015..f730d9154cb 100644
--- a/backends/cortex_m/CMakeLists.txt
+++ b/backends/cortex_m/CMakeLists.txt
@@ -61,6 +61,7 @@ set(_cortex_m_kernels__srcs
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_maximum.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp
 )

diff --git a/backends/cortex_m/ops/op_softmax.cpp b/backends/cortex_m/ops/op_softmax.cpp
new file mode 100644
index 00000000000..2bf7488adcb
--- /dev/null
+++ b/backends/cortex_m/ops/op_softmax.cpp
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2025 Arm Limited and/or its affiliates.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "cortex_m_ops_common.h"
+
+#include <cstdint>
+#include <limits>
+
+// Include CMSIS-NN headers with C linkage
+extern "C" {
+#include "arm_nnfunctions.h"
+}
+
+namespace cortex_m {
+namespace native {
+
+namespace {
+
+constexpr int32_t kCmsisSoftmaxZeroPoint = -128;
+
+inline bool is_int8_tensor(const Tensor& tensor) {
+  return tensor.scalar_type() == ScalarType::Char;
+}
+
+inline bool is_last_dim(const Tensor& tensor, int64_t dim) {
+  const auto rank = tensor.dim();
+  const int64_t positive_dim = dim >= 0 ? dim : dim + rank;
+  return positive_dim == static_cast<int64_t>(rank - 1);
+}
+
+inline int64_t normalize_dim(const Tensor& tensor, int64_t dim) {
+  const auto rank = tensor.dim();
+  const int64_t positive_dim = dim >= 0 ? dim : dim + rank;
+  return positive_dim;
+}
+
+} // namespace
+
+Tensor& softmax_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    int64_t dim,
+    int64_t input_zero_point,
+    int64_t output_zero_point,
+    int64_t input_multiplier,
+    int64_t input_shift,
+    int64_t diff_min,
+    Tensor& out) {
+  if (!is_int8_tensor(input) || !is_int8_tensor(out)) {
+    ET_LOG(
+        Error,
+        "softmax_out: only int8 tensors are supported (input=%d, out=%d)",
+        static_cast<int>(input.scalar_type()),
+        static_cast<int>(out.scalar_type()));
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  if (!is_last_dim(input, dim)) {
+    ET_LOG(
+        Error,
+        "softmax_out: only last-dimension softmax is supported (dim=%lld, rank=%zu)",
+        static_cast<long long>(dim),
+        static_cast<size_t>(input.dim()));
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  const int32_t input_zp_val = static_cast<int32_t>(input_zero_point);
+  const int32_t output_zp_val = static_cast<int32_t>(output_zero_point);
+  (void)input_zp_val; // Zero-point difference cancels out during subtraction.
+
+  validate_single_quant_params(
+      Scalar(input_zp_val),
+      Scalar(input_multiplier),
+      Scalar(input_shift),
+      "softmax input");
+
+  const auto positive_dim = normalize_dim(input, dim);
+  const int64_t row_size64 = input.size(positive_dim);
+  if (row_size64 <= 0 || row_size64 > std::numeric_limits<int32_t>::max()) {
+    ET_LOG(
+        Error,
+        "softmax_out: row size must fit in int32 (row_size=%lld)",
+        static_cast<long long>(row_size64));
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  const int32_t row_size = static_cast<int32_t>(row_size64);
+  const int64_t num_rows64 = input.numel() / row_size64;
+  if (num_rows64 <= 0 || num_rows64 > std::numeric_limits<int32_t>::max()) {
+    ET_LOG(
+        Error,
+        "softmax_out: num_rows must fit in int32 (num_rows=%lld)",
+        static_cast<long long>(num_rows64));
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+  const int32_t num_rows = static_cast<int32_t>(num_rows64);
+
+  const int8_t* input_data = input.const_data_ptr<int8_t>();
+  int8_t* output_data = out.mutable_data_ptr<int8_t>();
+
+  if (num_rows <= 0 || row_size <= 0) {
+    ET_LOG(
+        Error,
+        "softmax_out: invalid args (dim=%ld, rows=%d, row_size=%d)",
+        static_cast<long>(dim),
+        num_rows,
+        row_size);
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  const int32_t input_multiplier_val = static_cast<int32_t>(input_multiplier);
+  const int32_t input_shift_val = static_cast<int32_t>(input_shift);
+  const int32_t diff_min_val = static_cast<int32_t>(diff_min);
+
+  if (output_zp_val != kCmsisSoftmaxZeroPoint) {
+    ET_LOG(
+        Error,
+        "softmax_out: expected output zero_point=%d (got zero_point=%d)",
+        kCmsisSoftmaxZeroPoint,
+        output_zp_val);
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  arm_softmax_s8(
+      input_data,
+      num_rows,
+      row_size,
+      input_multiplier_val,
+      input_shift_val,
+      diff_min_val,
+      output_data);
+
+  return out;
+}
+
+} // namespace native
+} // namespace cortex_m
diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py
index 291615f613a..c5317f1e439 100644
--- a/backends/cortex_m/ops/operators.py
+++ b/backends/cortex_m/ops/operators.py
@@ -5,6 +5,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import math from math import prod from typing import Sequence @@ -15,6 +16,10 @@ requantize_cmsis, SHIFT_INT8, ) +from executorch.backends.cortex_m.quantizer.quantization_configs import ( + CMSIS_SOFTMAX_SCALE, + CMSIS_SOFTMAX_ZERO_POINT, +) from executorch.exir.dialects._ops import ops as exir_ops # To provide the implementation of the operators @@ -24,6 +29,8 @@ # New operator library with a custom namespace to allow fusion etc. lib = Library("cortex_m", "DEF") +SOFTMAX_INPUT_INTEGER_BITS = 5 + ### # dequantize_per_tensor ### @@ -401,6 +408,65 @@ def quantized_linear_impl( return output +# =================================================================== +# SOFTMAX OPERATION DEFINITION +# =================================================================== + +lib.define( + "softmax(Tensor input, int dim, int input_zero_point, int output_zero_point, int input_multiplier, int input_shift, int diff_min) -> Tensor" +) +lib.define( + "softmax.out(Tensor input, int dim, int input_zero_point, int output_zero_point, int input_multiplier, int input_shift, int diff_min, *, Tensor(a!) out) -> Tensor(a!)" +) + + +@register_fake("cortex_m::softmax") +def softmax_meta( + input: torch.Tensor, + dim: int, + input_zero_point: int, + output_zero_point: int, + input_multiplier: int, + input_shift: int, + diff_min: int, +) -> torch.Tensor: + return torch.empty_like(input, dtype=torch.int8) + + +@impl(lib, "softmax", "CompositeExplicitAutograd") +def softmax_impl( + input: torch.Tensor, + dim: int, + input_zero_point: int, + output_zero_point: int, + input_multiplier: int, + input_shift: int, + diff_min: int, +) -> torch.Tensor: + del diff_min # not used in reference path + if input.dtype != torch.int8: + raise TypeError( + f"cortex_m.softmax: expected int8 input tensor, got {input.dtype}" + ) + if output_zero_point != CMSIS_SOFTMAX_ZERO_POINT: + raise ValueError( + f"cortex_m.softmax: expected output_zero_point {CMSIS_SOFTMAX_ZERO_POINT}, got {output_zero_point}" + ) + + real_multiplier = float(input_multiplier) / float(1 << 31) + real_multiplier = math.ldexp(real_multiplier, input_shift) + input_scale = real_multiplier / float(1 << (31 - SOFTMAX_INPUT_INTEGER_BITS)) + if input_scale <= 0: + raise ValueError( + f"cortex_m.softmax: derived non-positive input scale {input_scale}" + ) + + input_fp = (input.to(torch.int32) - int(input_zero_point)).float() * input_scale + probs = torch.softmax(input_fp, dim=dim) + quantized = torch.round(probs / CMSIS_SOFTMAX_SCALE) + int(output_zero_point) + return quantized.clamp(-128, 127).to(torch.int8) + + # =================================================================== # TRANSPOSE OPERATION DEFINITION # =================================================================== diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 0b0b2f5c715..420dbf76d5e 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -47,6 +47,12 @@ - arg_meta: null kernel_name: cortex_m::quantized_linear_out +- func: cortex_m::softmax.out(Tensor input, int dim, int input_zero_point, int output_zero_point, int input_multiplier, int input_shift, int diff_min, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::softmax_out + - func: cortex_m::transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!) 
variants: function kernels: diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index c84e66dd7d9..be85a58203a 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -5,6 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math from typing import Dict import torch @@ -13,6 +14,10 @@ quantize_multiplier_aot, SHIFT_INT8, ) +from executorch.backends.cortex_m.quantizer.quantization_configs import ( + CMSIS_SOFTMAX_SCALE, + CMSIS_SOFTMAX_ZERO_POINT, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -31,6 +36,8 @@ class QuantizedOpFusionPass(ExportPass): Supports multiple binary operations with backward compatibility for add. """ + _SOFTMAX_INPUT_INTEGER_BITS = 5 + def _get_add_replacement(self, args, meta): if ( meta.data.get("input_qparams", {}) == {} @@ -101,6 +108,81 @@ def _get_mul_replacement(self, args, meta): return exir_ops.edge.cortex_m.quantized_mul.default, args + def _compute_softmax_params(self, input_scale: float) -> tuple[int, int, int]: + """ + Convert the incoming per-tensor input scale into the CMSIS fixed-point + parameters expected by `arm_softmax_s8`. + + 1. Clamp the real multiplier to the Q31 range using the fixed number of + input integer bits mandated by CMSIS. + 2. Feed that multiplier through `quantize_multiplier_aot` to get the + (multiplier, shift) pair arm_softmax_s8 expects. + 3. Derive `diff_min`, the CMSIS threshold for early bailout when + differences saturate, using the same multiplier/shift values. + """ + real_multiplier = min( + input_scale * (1 << (31 - self._SOFTMAX_INPUT_INTEGER_BITS)), + float((1 << 31) - 1), + ) + input_multiplier, input_shift = quantize_multiplier_aot(real_multiplier) + diff_min_term = ( + ((1 << self._SOFTMAX_INPUT_INTEGER_BITS) - 1) + * math.ldexp(1.0, 31 - self._SOFTMAX_INPUT_INTEGER_BITS) + / math.ldexp(1.0, input_shift) + ) + diff_min = -int(math.floor(diff_min_term)) + return int(input_multiplier), int(input_shift), diff_min + + def _get_softmax_replacement(self, args, meta): + if ( + meta.data.get("input_qparams", {}) == {} + or meta.data.get("output_qparams", {}) == {} + ): + return exir_ops.edge.aten._softmax.default, args + + input_qparams = meta["input_qparams"][0] + output_qparams = meta["output_qparams"][0] + + half_to_float = args[2] if len(args) > 2 else False + if half_to_float: + return exir_ops.edge.aten._softmax.default, args + + input_multiplier, input_shift, diff_min = self._compute_softmax_params( + float(input_qparams.scale) + ) + + output_scale_attr = getattr(output_qparams, "scale", None) + output_zp_attr = getattr(output_qparams, "zp", None) + if output_scale_attr is None or output_zp_attr is None: + raise AssertionError("Softmax requires output quantization parameters.") + + output_scale_val = float(output_scale_attr) + output_zp_val = int(output_zp_attr) + if not math.isclose( + output_scale_val, CMSIS_SOFTMAX_SCALE, rel_tol=0.0, abs_tol=1e-12 + ): + raise AssertionError( + "Softmax output scale must match CMSIS (1/256). " + f"Got {output_scale_val}." + ) + if output_zp_val != CMSIS_SOFTMAX_ZERO_POINT: + raise AssertionError( + "Softmax output zero-point must match CMSIS (-128). " + f"Got {output_zp_val}." 
+ ) + + new_args = ( + args[0], + args[1], + int(input_qparams.zp), + output_zp_val, + input_multiplier, + input_shift, + diff_min, + ) + + return exir_ops.edge.cortex_m.softmax.default, new_args + def _get_minimum_replacement(self, args, meta): if args[0].data.dtype != torch.int8: return exir_ops.edge.aten.minimum.default, args @@ -135,6 +217,8 @@ def call_operator( op, args = self._get_add_replacement(args, meta) case exir_ops.edge.aten.mul.Tensor: op, args = self._get_mul_replacement(args, meta) + case exir_ops.edge.aten._softmax.default: + op, args = self._get_softmax_replacement(args, meta) case exir_ops.edge.aten.minimum.default: op, args = self._get_minimum_replacement(args, meta) case exir_ops.edge.aten.maximum.default: @@ -144,4 +228,5 @@ def call_operator( case _: pass - return super().call_operator(op, args, {}, meta) + result = super().call_operator(op, args, {}, meta) + return result diff --git a/backends/cortex_m/quantizer/operator_configs.py b/backends/cortex_m/quantizer/operator_configs.py index dadee30fa41..5458eefdbbb 100644 --- a/backends/cortex_m/quantizer/operator_configs.py +++ b/backends/cortex_m/quantizer/operator_configs.py @@ -12,6 +12,7 @@ from executorch.backends.cortex_m.quantizer.quantization_configs import ( INT8_PER_CHANNEL_CONFIG, INT8_PER_TENSOR_CONFIG, + SOFTMAX_PER_TENSOR_CONFIG, ) from torchao.quantization.pt2e.quantizer import OperatorConfig @@ -47,6 +48,11 @@ [torch.ops.aten.conv2d.default, torch.ops.aten.clamp_.default], ] +SOFTMAX_OP_PATTERNS = [ + [torch.ops.aten._softmax.default], + [torch.ops.aten.softmax.int], +] + # ----------------- OPERATOR CONFIG PRESETS ----------------- INT8_BINARY_OPS_OPERATOR_CONFIG = OperatorConfig( INT8_PER_TENSOR_CONFIG, BINARY_OP_PATTERNS @@ -61,3 +67,8 @@ INT8_PER_CHANNEL_CONFIG, CONV_OP_PATTERNS, ) + +INT8_SOFTMAX_OPERATOR_CONFIG = OperatorConfig( + SOFTMAX_PER_TENSOR_CONFIG, + SOFTMAX_OP_PATTERNS, +) diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py index c6600241b6d..7ec84c61bf4 100644 --- a/backends/cortex_m/quantizer/quantization_configs.py +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -5,6 +5,7 @@ import torch +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from torchao.quantization.pt2e import ( HistogramObserver, MinMaxObserver, @@ -12,7 +13,7 @@ ) from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, - QuantizationConfig, + FixedQParamsQuantizationSpec, QuantizationSpec, ) @@ -43,6 +44,19 @@ ch_axis=0, ) +# Constants shared by Cortex-M quantized operators. +CMSIS_SOFTMAX_SCALE: float = 1.0 / 256.0 +CMSIS_SOFTMAX_ZERO_POINT: int = -128 + +SOFTMAX_OUTPUT_FIXED_QSPEC = FixedQParamsQuantizationSpec( + dtype=torch.int8, + scale=CMSIS_SOFTMAX_SCALE, + zero_point=CMSIS_SOFTMAX_ZERO_POINT, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, +) + def _derive_bias_qparams_fn( obs_or_fqs, @@ -97,3 +111,11 @@ def _get_int32_per_channel_bias_qspec(node): INT8_WEIGHT_PER_CHANNEL_QSPEC, _get_int32_per_channel_bias_qspec, ) + + +SOFTMAX_PER_TENSOR_CONFIG = QuantizationConfig( + INT8_ACTIVATION_PER_TENSOR_QSPEC, + SOFTMAX_OUTPUT_FIXED_QSPEC, + None, + None, +) diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index 185a39b9eae..5ed866f67b5 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. 
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
@@ -21,6 +21,8 @@
     INT8_BINARY_OPS_OPERATOR_CONFIG,
     INT8_CONV_OPERATOR_CONFIG,
     INT8_LINEAR_OPERATOR_CONFIG,
+    INT8_SOFTMAX_OPERATOR_CONFIG,
+    SOFTMAX_OP_PATTERNS,
 )
 from executorch.backends.cortex_m.quantizer.quantization_configs import (
     INT8_PER_TENSOR_CONFIG,
@@ -86,6 +88,76 @@ def nchw_filter(self, node: Optional[Node]) -> bool:
 
         return not is_channels_last(tensor)
 
+    @staticmethod
+    def _resolve_int(value: Any) -> Optional[int]:
+        """Best-effort conversion of FX node arguments to ints."""
+        if isinstance(value, int):
+            return value
+        if hasattr(value, "item"):
+            try:
+                return int(value.item())  # type: ignore[arg-type]
+            except Exception:
+                return None
+        if hasattr(value, "meta"):
+            meta_val = value.meta.get("val")
+            return CortexMQuantizer._resolve_int(meta_val)
+        return None
+
+    def _extract_dim(self, node: Node) -> Optional[int]:
+        """Return the dim argument from a softmax node when statically known."""
+        dim_arg = None
+        if len(node.args) > 1:
+            dim_arg = node.args[1]
+        elif "dim" in node.kwargs:
+            dim_arg = node.kwargs["dim"]
+
+        if dim_arg is None:
+            return -1
+
+        return self._resolve_int(dim_arg)
+
+    def softmax_memory_format_filter(self, node: Optional[Node]) -> bool:
+        """
+        Return False (let the softmax be quantized) only when either
+        - the tensor is contiguous (default layout) and the softmax dim is the last logical dim, or
+        - the tensor is channels_last and the softmax dim is the channel dim.
+        Any other combination returns True, so the node is skipped and the op stays in ATen form.
+        """
+        if node is None:
+            return False
+        if [node.target] not in SOFTMAX_OP_PATTERNS:
+            return False
+
+        tensor = get_first_fake_tensor(node)
+        if tensor is None:
+            return True
+
+        dim = self._extract_dim(node)
+        if dim is None:
+            return True
+
+        rank = tensor.dim()
+        if rank == 0:
+            return True
+
+        positive_dim = dim if dim >= 0 else dim + rank
+        if positive_dim < 0 or positive_dim >= rank:
+            return True
+
+        is_channels_last = False
+        if rank == 4:
+            is_channels_last = tensor.is_contiguous(memory_format=torch.channels_last)
+
+        if is_channels_last:
+            channel_dim = 1 if rank >= 2 else rank - 1
+            if positive_dim != channel_dim:
+                return True
+        else:
+            if positive_dim != rank - 1:
+                return True
+
+        return False
+
     def __init__(self) -> None:
         quantizers: List[Quantizer] = [
             OperatorConfigQuantizer(
@@ -95,6 +167,10 @@ def __init__(self) -> None:
             OperatorConfigQuantizer(
                 INT8_CONV_OPERATOR_CONFIG, filter_fn=self.nchw_filter
             ),
+            OperatorConfigQuantizer(
+                INT8_SOFTMAX_OPERATOR_CONFIG,
+                filter_fn=self.softmax_memory_format_filter,
+            ),
             InputQuantizer(INT8_PER_TENSOR_CONFIG),
             OutputQuantizer(INT8_PER_TENSOR_CONFIG),
             SharedQspecQuantizer(),
diff --git a/backends/cortex_m/test/ops/test_softmax.py b/backends/cortex_m/test/ops/test_softmax.py
new file mode 100644
index 00000000000..cef4060f25a
--- /dev/null
+++ b/backends/cortex_m/test/ops/test_softmax.py
@@ -0,0 +1,80 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from executorch.backends.arm.test.common import parametrize
+from executorch.backends.cortex_m.test.tester import (
+    CortexMTester,
+    McuTestCase,
+    ramp_tensor,
+)
+
+
+class CortexMSoftmax(torch.nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.dim = dim
+
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten__softmax_default": 1,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_softmax_default": 1,
+    }
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.softmax(x, dim=self.dim)
+
+
+test_cases = {
+    "rank1": McuTestCase(
+        CortexMSoftmax(dim=-1),
+        (ramp_tensor(-4, 4, (16,)),),
+    ),
+    "rank2": McuTestCase(
+        CortexMSoftmax(dim=-1),
+        (ramp_tensor(-8, 8, (4, 8)),),
+    ),
+    "rank3": McuTestCase(
+        CortexMSoftmax(dim=-1),
+        (ramp_tensor(-2, 2, (2, 3, 4)),),
+    ),
+    "dim_not_last": McuTestCase(
+        CortexMSoftmax(dim=1),
+        (ramp_tensor(-2, 2, (2, 3, 4)),),
+    ),
+}
+
+
+xfail_cases_dialect = {
+    "dim_not_last": (
+        "Softmax stays in ATen when dim is not the last dimension, so the dialect op counts do not match",
+        Exception,
+    ),
+}
+xfail_cases_impl = {
+    "dim_not_last": (
+        "Softmax on Cortex-M currently supports only the last dimension",
+        Exception,
+    ),
+}
+
+
+@parametrize("test_case", test_cases, xfails=xfail_cases_dialect)
+def test_dialect_softmax(test_case):
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_dialect(
+        test_case.model.ops_before_transforms,
+        test_case.model.ops_after_transforms,
+        qtol=2,
+    )
+
+
+@parametrize("test_case", test_cases, xfails=xfail_cases_impl)
+def test_implementation_softmax(test_case):
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_implementation(qtol=2)
diff --git a/backends/cortex_m/test/tester.py b/backends/cortex_m/test/tester.py
index ce5f16195c0..ca6a0377218 100644
--- a/backends/cortex_m/test/tester.py
+++ b/backends/cortex_m/test/tester.py
@@ -41,7 +41,8 @@ def __init__(self):
             torch.ops.aten.hardsigmoid_.default,
             torch.ops.aten.hardswish.default,
             torch.ops.aten.hardswish_.default,
-        ]
+        ],
+        _check_ir_validity=False,
     )
     super().__init__(config)
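
For readers who want to sanity-check the fixed-point bookkeeping above without cross-referencing CMSIS-NN, the short Python sketch below reproduces the derivation outside the pass. It is illustrative only: quantize_multiplier is a simplified stand-in for the backend's quantize_multiplier_aot helper (exact rounding may differ), softmax_params mirrors what QuantizedOpFusionPass._compute_softmax_params is assumed to do with SOFTMAX_INPUT_INTEGER_BITS = 5, and int8_softmax_reference shows how the fixed 1/256 scale and -128 zero-point map softmax probabilities back to int8.

# Illustrative sketch (not part of the patch): derive arm_softmax_s8 parameters
# from an int8 input scale, plus a float reference for the quantized kernel.
# NOTE: quantize_multiplier is a simplified stand-in for quantize_multiplier_aot;
# rounding details may differ from the backend helper.
import math

import torch

SOFTMAX_INPUT_INTEGER_BITS = 5
CMSIS_SOFTMAX_SCALE = 1.0 / 256.0
CMSIS_SOFTMAX_ZERO_POINT = -128


def quantize_multiplier(real_multiplier: float) -> tuple[int, int]:
    """Split a positive real multiplier into a Q31 mantissa and a base-2 shift."""
    mantissa, exponent = math.frexp(real_multiplier)  # real = mantissa * 2**exponent
    q31 = int(round(mantissa * (1 << 31)))
    if q31 == (1 << 31):  # rounding pushed the mantissa up to 1.0
        q31 //= 2
        exponent += 1
    return q31, exponent


def softmax_params(input_scale: float) -> tuple[int, int, int]:
    """Mirror of the (assumed) _compute_softmax_params logic in the fusion pass."""
    real_multiplier = min(
        input_scale * (1 << (31 - SOFTMAX_INPUT_INTEGER_BITS)),
        float((1 << 31) - 1),
    )
    multiplier, shift = quantize_multiplier(real_multiplier)
    diff_min_term = (
        ((1 << SOFTMAX_INPUT_INTEGER_BITS) - 1)
        * math.ldexp(1.0, 31 - SOFTMAX_INPUT_INTEGER_BITS)
        / math.ldexp(1.0, shift)
    )
    return multiplier, shift, -int(math.floor(diff_min_term))


def int8_softmax_reference(x_int8: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    """Float reference for what the int8 kernel approximates (last-dim softmax)."""
    x_fp = (x_int8.to(torch.int32) - zero_point).float() * scale
    probs = torch.softmax(x_fp, dim=-1)
    q = torch.round(probs / CMSIS_SOFTMAX_SCALE) + CMSIS_SOFTMAX_ZERO_POINT
    return q.clamp(-128, 127).to(torch.int8)


if __name__ == "__main__":
    # Example: derive the triple baked into a lowered cortex_m::softmax node.
    print(softmax_params(0.05))

With this stand-in, softmax_params(0.05) evaluates to (1717986918, 22, -496); the exact mantissa may differ by one unit in the last place depending on the rounding used by quantize_multiplier_aot.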