1 change: 1 addition & 0 deletions backends/cortex_m/CMakeLists.txt
@@ -61,6 +61,7 @@ set(_cortex_m_kernels__srcs
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_maximum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp
)

148 changes: 148 additions & 0 deletions backends/cortex_m/ops/op_softmax.cpp
@@ -0,0 +1,148 @@
/*
* Copyright 2025 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "cortex_m_ops_common.h"

#include <cmath>
#include <cstdint>
#include <limits>

// Include CMSIS-NN headers with C linkage
extern "C" {
#include "arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

namespace {

constexpr int32_t kCmsisSoftmaxZeroPoint = -128;

inline bool is_int8_tensor(const Tensor& tensor) {
return tensor.scalar_type() == ScalarType::Char;
}

inline bool is_last_dim(const Tensor& tensor, int64_t dim) {
const auto rank = tensor.dim();
const int64_t positive_dim = dim >= 0 ? dim : dim + rank;
return positive_dim == static_cast<int64_t>(rank - 1);
}

inline int64_t normalize_dim(const Tensor& tensor, int64_t dim) {
const auto rank = tensor.dim();
const int64_t positive_dim = dim >= 0 ? dim : dim + rank;
return positive_dim;
}

} // namespace

Tensor& softmax_out(
KernelRuntimeContext& context,
const Tensor& input,
int64_t dim,
int64_t input_zero_point,
int64_t output_zero_point,
int64_t input_multiplier,
int64_t input_shift,
int64_t diff_min,
Tensor& out) {
if (!is_int8_tensor(input) || !is_int8_tensor(out)) {
ET_LOG(
Error,
"softmax_out: only int8 tensors are supported (input=%d, out=%d)",
static_cast<int>(input.scalar_type()),
static_cast<int>(out.scalar_type()));
context.fail(Error::InvalidArgument);
return out;
}

if (!is_last_dim(input, dim)) {
ET_LOG(
Error,
"softmax_out: only last-dimension softmax is supported (dim=%lld, rank=%zu)",
static_cast<long long>(dim),
static_cast<size_t>(input.dim()));
context.fail(Error::InvalidArgument);
return out;
}

const int32_t input_zp_val = static_cast<int32_t>(input_zero_point);
const int32_t output_zp_val = static_cast<int32_t>(output_zero_point);
(void)input_zp_val; // Unused by arm_softmax_s8: the input zero-point cancels when the row maximum is subtracted.

validate_single_quant_params(
Scalar(input_zp_val),
Scalar(input_multiplier),
Scalar(input_shift),
"softmax input");

const auto positive_dim = normalize_dim(input, dim);
const int64_t row_size64 = input.size(positive_dim);
if (row_size64 <= 0 || row_size64 > std::numeric_limits<int32_t>::max()) {
ET_LOG(
Error,
"softmax_out: row size must fit in int32 (row_size=%lld)",
static_cast<long long>(row_size64));
context.fail(Error::InvalidArgument);
return out;
}

const int32_t row_size = static_cast<int32_t>(row_size64);
const int64_t num_rows64 = input.numel() / row_size64;
if (num_rows64 <= 0 || num_rows64 > std::numeric_limits<int32_t>::max()) {
ET_LOG(
Error,
"softmax_out: num_rows must fit in int32 (num_rows=%lld)",
static_cast<long long>(num_rows64));
context.fail(Error::InvalidArgument);
return out;
}
const int32_t num_rows = static_cast<int32_t>(num_rows64);

const int8_t* input_data = input.const_data_ptr<int8_t>();
int8_t* output_data = out.mutable_data_ptr<int8_t>();

if (num_rows <= 0 || row_size <= 0) {
ET_LOG(
Error,
"softmax_out: invalid args (dim=%ld, rows=%d, row_size=%d)",
static_cast<long>(dim),
num_rows,
row_size);
context.fail(Error::InvalidArgument);
return out;
}

const int32_t input_multiplier_val = static_cast<int32_t>(input_multiplier);
const int32_t input_shift_val = static_cast<int32_t>(input_shift);
const int32_t diff_min_val = static_cast<int32_t>(diff_min);

if (output_zp_val != kCmsisSoftmaxZeroPoint) {
ET_LOG(
Error,
"softmax_out: expected output zero_point=%d (got zero_point=%d)",
kCmsisSoftmaxZeroPoint,
output_zp_val);
context.fail(Error::InvalidArgument);
return out;
}

arm_softmax_s8(
input_data,
num_rows,
row_size,
input_multiplier_val,
input_shift_val,
diff_min_val,
output_data);

return out;
}

} // namespace native
} // namespace cortex_m
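
The kernel treats every input as a batch of contiguous rows: all leading dimensions collapse into `num_rows` and the last (softmax) dimension becomes `row_size` before `arm_softmax_s8` is called. A minimal sketch of that flattening (illustrative Python, not part of the patch):

```python
# Sketch: how the kernel derives num_rows / row_size for a last-dim softmax.
# Shapes below are hypothetical and used only for illustration.
def flatten_for_softmax(shape, dim):
    rank = len(shape)
    dim = dim if dim >= 0 else dim + rank
    assert dim == rank - 1, "only last-dimension softmax is supported"
    row_size = shape[dim]
    num_rows = 1
    for s in shape[:-1]:
        num_rows *= s
    return num_rows, row_size

# A [2, 3, 8] tensor with dim=-1 is handed to arm_softmax_s8 as 6 rows of 8.
print(flatten_for_softmax((2, 3, 8), -1))  # (6, 8)
```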
66 changes: 66 additions & 0 deletions backends/cortex_m/ops/operators.py
@@ -5,6 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from math import prod
from typing import Sequence

@@ -15,6 +16,10 @@
requantize_cmsis,
SHIFT_INT8,
)
from executorch.backends.cortex_m.quantizer.quantization_configs import (
CMSIS_SOFTMAX_SCALE,
CMSIS_SOFTMAX_ZERO_POINT,
)
from executorch.exir.dialects._ops import ops as exir_ops

# To provide the implementation of the operators
@@ -24,6 +29,8 @@
# New operator library with a custom namespace to allow fusion etc.
lib = Library("cortex_m", "DEF")

SOFTMAX_INPUT_INTEGER_BITS = 5

###
# dequantize_per_tensor
###
@@ -401,6 +408,65 @@ def quantized_linear_impl(
return output


# ===================================================================
# SOFTMAX OPERATION DEFINITION
# ===================================================================

lib.define(
"softmax(Tensor input, int dim, int input_zero_point, int output_zero_point, int input_multiplier, int input_shift, int diff_min) -> Tensor"
)
lib.define(
"softmax.out(Tensor input, int dim, int input_zero_point, int output_zero_point, int input_multiplier, int input_shift, int diff_min, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::softmax")
def softmax_meta(
input: torch.Tensor,
dim: int,
input_zero_point: int,
output_zero_point: int,
input_multiplier: int,
input_shift: int,
diff_min: int,
) -> torch.Tensor:
return torch.empty_like(input, dtype=torch.int8)


@impl(lib, "softmax", "CompositeExplicitAutograd")
def softmax_impl(
input: torch.Tensor,
dim: int,
input_zero_point: int,
output_zero_point: int,
input_multiplier: int,
input_shift: int,
diff_min: int,
) -> torch.Tensor:
del diff_min # not used in reference path
if input.dtype != torch.int8:
raise TypeError(
f"cortex_m.softmax: expected int8 input tensor, got {input.dtype}"
)
if output_zero_point != CMSIS_SOFTMAX_ZERO_POINT:
raise ValueError(
f"cortex_m.softmax: expected output_zero_point {CMSIS_SOFTMAX_ZERO_POINT}, got {output_zero_point}"
)

real_multiplier = float(input_multiplier) / float(1 << 31)
real_multiplier = math.ldexp(real_multiplier, input_shift)
input_scale = real_multiplier / float(1 << (31 - SOFTMAX_INPUT_INTEGER_BITS))
if input_scale <= 0:
raise ValueError(
f"cortex_m.softmax: derived non-positive input scale {input_scale}"
)

input_fp = (input.to(torch.int32) - int(input_zero_point)).float() * input_scale
probs = torch.softmax(input_fp, dim=dim)
quantized = torch.round(probs / CMSIS_SOFTMAX_SCALE) + int(output_zero_point)
return quantized.clamp(-128, 127).to(torch.int8)


# ===================================================================
# TRANSPOSE OPERATION DEFINITION
# ===================================================================
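
For context, the reference implementation above can be exercised eagerly once the `cortex_m` operator library is registered. The sketch below is assumption-laden: the import path, the `math.frexp`-based stand-in for `quantize_multiplier_aot`, and the scale/zero-point values are all illustrative rather than production values.

```python
import math
import torch

# Assumed module path that registers the cortex_m ops (adjust to your setup).
import executorch.backends.cortex_m.ops.operators  # noqa: F401

# Illustrative per-tensor quantization parameters for the int8 input.
input_scale = 0.05
input_zero_point = -3

# CMSIS folds the scale into Q31 with 5 integer bits (SOFTMAX_INPUT_INTEGER_BITS).
real_multiplier = min(input_scale * (1 << (31 - 5)), float((1 << 31) - 1))

# Rough stand-in for quantize_multiplier_aot: real_multiplier ≈ multiplier / 2**31 * 2**shift.
mantissa, input_shift = math.frexp(real_multiplier)
input_multiplier = int(round(mantissa * (1 << 31)))

x = torch.randint(-128, 127, (2, 8), dtype=torch.int8)
# diff_min is ignored by the reference path, so 0 is passed here.
y = torch.ops.cortex_m.softmax(
    x, -1, input_zero_point, -128, input_multiplier, input_shift, 0
)
print(y.dtype, y.shape)  # torch.int8, torch.Size([2, 8])
```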
6 changes: 6 additions & 0 deletions backends/cortex_m/ops/operators.yaml
@@ -47,6 +47,12 @@
- arg_meta: null
kernel_name: cortex_m::quantized_linear_out

- func: cortex_m::softmax.out(Tensor input, int dim, int input_zero_point, int output_zero_point, int input_multiplier, int input_shift, int diff_min, *, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
kernel_name: cortex_m::softmax_out

- func: cortex_m::transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
87 changes: 86 additions & 1 deletion backends/cortex_m/passes/quantized_op_fusion_pass.py
@@ -5,6 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict

import torch
@@ -13,6 +14,10 @@
quantize_multiplier_aot,
SHIFT_INT8,
)
from executorch.backends.cortex_m.quantizer.quantization_configs import (
CMSIS_SOFTMAX_SCALE,
CMSIS_SOFTMAX_ZERO_POINT,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -31,6 +36,8 @@ class QuantizedOpFusionPass(ExportPass):
Supports multiple binary operations with backward compatibility for add.
"""

_SOFTMAX_INPUT_INTEGER_BITS = 5

def _get_add_replacement(self, args, meta):
if (
meta.data.get("input_qparams", {}) == {}
@@ -101,6 +108,81 @@ def _get_mul_replacement(self, args, meta):

return exir_ops.edge.cortex_m.quantized_mul.default, args

def _compute_softmax_params(self, input_scale: float) -> tuple[int, int, int]:
"""
Convert the incoming per-tensor input scale into the CMSIS fixed-point
parameters expected by `arm_softmax_s8`.

1. Clamp the real multiplier to the Q31 range using the fixed number of
input integer bits mandated by CMSIS.
2. Feed that multiplier through `quantize_multiplier_aot` to get the
(multiplier, shift) pair arm_softmax_s8 expects.
3. Derive `diff_min`, the CMSIS threshold below which an input's
difference from the row maximum contributes nothing to the result,
using the same integer-bit budget and shift value.
"""
real_multiplier = min(
input_scale * (1 << (31 - self._SOFTMAX_INPUT_INTEGER_BITS)),
float((1 << 31) - 1),
)
input_multiplier, input_shift = quantize_multiplier_aot(real_multiplier)
diff_min_term = (
((1 << self._SOFTMAX_INPUT_INTEGER_BITS) - 1)
* math.ldexp(1.0, 31 - self._SOFTMAX_INPUT_INTEGER_BITS)
/ math.ldexp(1.0, input_shift)
)
diff_min = -int(math.floor(diff_min_term))
return int(input_multiplier), int(input_shift), diff_min

def _get_softmax_replacement(self, args, meta):
if (
meta.data.get("input_qparams", {}) == {}
or meta.data.get("output_qparams", {}) == {}
):
return exir_ops.edge.aten._softmax.default, args

input_qparams = meta["input_qparams"][0]
output_qparams = meta["output_qparams"][0]

half_to_float = args[2] if len(args) > 2 else False
if half_to_float:
return exir_ops.edge.aten._softmax.default, args

input_multiplier, input_shift, diff_min = self._compute_softmax_params(
float(input_qparams.scale)
)

output_scale_attr = getattr(output_qparams, "scale", None)
output_zp_attr = getattr(output_qparams, "zp", None)
if output_scale_attr is None or output_zp_attr is None:
raise AssertionError("Softmax requires output quantization parameters.")

output_scale_val = float(output_scale_attr)
output_zp_val = int(output_zp_attr)
if not math.isclose(
output_scale_val, CMSIS_SOFTMAX_SCALE, rel_tol=0.0, abs_tol=1e-12
):
raise AssertionError(
"Softmax output scale must match CMSIS (1/256). "
f"Got {output_scale_val}."
)
if output_zp_val != CMSIS_SOFTMAX_ZERO_POINT:
raise AssertionError(
"Softmax output zero-point must match CMSIS (-128). "
f"Got {output_zp_val}."
)

new_args = (
args[0],
args[1],
int(input_qparams.zp),
output_zp_val,
input_multiplier,
input_shift,
diff_min,
)

return exir_ops.edge.cortex_m.softmax.default, new_args

def _get_minimum_replacement(self, args, meta):
if args[0].data.dtype != torch.int8:
return exir_ops.edge.aten.minimum.default, args
@@ -135,6 +217,8 @@ def call_operator(
op, args = self._get_add_replacement(args, meta)
case exir_ops.edge.aten.mul.Tensor:
op, args = self._get_mul_replacement(args, meta)
case exir_ops.edge.aten._softmax.default:
op, args = self._get_softmax_replacement(args, meta)
case exir_ops.edge.aten.minimum.default:
op, args = self._get_minimum_replacement(args, meta)
case exir_ops.edge.aten.maximum.default:
@@ -144,4 +228,5 @@
case _:
pass

return super().call_operator(op, args, {}, meta)
result = super().call_operator(op, args, {}, meta)
return result
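
As a rough sanity check on `_compute_softmax_params`, the sketch below walks the docstring's three steps for an assumed input scale; `math.frexp` approximates `quantize_multiplier_aot`, so the resulting integers are illustrative rather than the production values.

```python
import math

INPUT_INTEGER_BITS = 5
input_scale = 0.1  # assumed per-tensor scale, for illustration only

# 1. Fold the real scale into the Q31 range with 5 integer bits.
real_multiplier = min(
    input_scale * (1 << (31 - INPUT_INTEGER_BITS)), float((1 << 31) - 1)
)

# 2. Approximate quantize_multiplier_aot: real_multiplier ≈ multiplier / 2**31 * 2**shift.
mantissa, shift = math.frexp(real_multiplier)
multiplier = int(round(mantissa * (1 << 31)))

# 3. diff_min mirrors the pass: the most negative input difference that still
#    contributes to the softmax sum, expressed in the shifted integer domain.
diff_min_term = (
    ((1 << INPUT_INTEGER_BITS) - 1)
    * math.ldexp(1.0, 31 - INPUT_INTEGER_BITS)
    / math.ldexp(1.0, shift)
)
diff_min = -int(math.floor(diff_min_term))

print(multiplier, shift, diff_min)
```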