
Commit ce39bbe

Add softmax support for int8 in Cortex M (dim=-1)
- integrate CMSIS softmax into Cortex-M backend
- add fusion pass/tests for quantized softmax
- lint cleanup passes
- Resolved merge conflicts

Change-Id: I0ec19f011069fa1482e2de2ab62b9e7d7f56b2a8
Signed-off-by: Xingguo Li <xingguo.li@arm.com>
1 parent 2441917 commit ce39bbe

File tree

10 files changed: +500 -4 lines changed
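For reference, the quantization contract this commit targets is fixed by CMSIS-NN: softmax runs over the last dimension of an int8 tensor, and the output is always quantized with scale 1/256 and zero point -128. A minimal float-reference sketch of that contract (illustrative only, not part of the diff; it mirrors the CMSIS_SOFTMAX_SCALE / CMSIS_SOFTMAX_ZERO_POINT constants used below):

import torch

CMSIS_SOFTMAX_SCALE = 1.0 / 256.0
CMSIS_SOFTMAX_ZERO_POINT = -128

def reference_int8_softmax(x_int8, input_scale, input_zp):
    # Dequantize, run float softmax over the last dim, requantize to the
    # fixed CMSIS output parameters.
    x_fp = (x_int8.to(torch.int32) - input_zp).float() * input_scale
    probs = torch.softmax(x_fp, dim=-1)
    q = torch.round(probs / CMSIS_SOFTMAX_SCALE) + CMSIS_SOFTMAX_ZERO_POINT
    return q.clamp(-128, 127).to(torch.int8)

x = torch.randint(-128, 128, (2, 8), dtype=torch.int8)
print(reference_int8_softmax(x, input_scale=0.1, input_zp=3))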

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ set(_cortex_m_kernels__srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_maximum.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_softmax.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp
 )

backends/cortex_m/ops/op_softmax.cpp (new file)

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2025 Arm Limited and/or its affiliates.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "cortex_m_ops_common.h"
+
+#include <cmath>
+#include <cstdint>
+#include <limits>
+
+// Include CMSIS-NN headers with C linkage
+extern "C" {
+#include "arm_nnfunctions.h"
+}
+
+namespace cortex_m {
+namespace native {
+
+namespace {
+
+constexpr int32_t kCmsisSoftmaxZeroPoint = -128;
+
+inline bool is_int8_tensor(const Tensor& tensor) {
+  return tensor.scalar_type() == ScalarType::Char;
+}
+
+inline bool is_last_dim(const Tensor& tensor, int64_t dim) {
+  const auto rank = tensor.dim();
+  const int64_t positive_dim = dim >= 0 ? dim : dim + rank;
+  return positive_dim == static_cast<int64_t>(rank - 1);
+}
+
+inline int64_t normalize_dim(const Tensor& tensor, int64_t dim) {
+  const auto rank = tensor.dim();
+  const int64_t positive_dim = dim >= 0 ? dim : dim + rank;
+  return positive_dim;
+}
+
+} // namespace
+
+Tensor& softmax_out(
+    KernelRuntimeContext& context,
+    const Tensor& input,
+    int64_t dim,
+    int64_t input_zero_point,
+    int64_t output_zero_point,
+    int64_t input_multiplier,
+    int64_t input_shift,
+    int64_t diff_min,
+    Tensor& out) {
+  if (!is_int8_tensor(input) || !is_int8_tensor(out)) {
+    ET_LOG(
+        Error,
+        "softmax_out: only int8 tensors are supported (input=%d, out=%d)",
+        static_cast<int>(input.scalar_type()),
+        static_cast<int>(out.scalar_type()));
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  if (!is_last_dim(input, dim)) {
+    ET_LOG(
+        Error,
+        "softmax_out: only last-dimension softmax is supported (dim=%lld, rank=%zu)",
+        static_cast<long long>(dim),
+        static_cast<size_t>(input.dim()));
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  const int32_t input_zp_val = static_cast<int32_t>(input_zero_point);
+  const int32_t output_zp_val = static_cast<int32_t>(output_zero_point);
+  (void)input_zp_val; // Zero-point difference cancels out during subtraction.
+
+  validate_single_quant_params(
+      Scalar(input_zp_val),
+      Scalar(input_multiplier),
+      Scalar(input_shift),
+      "softmax input");
+
+  const auto positive_dim = normalize_dim(input, dim);
+  const int64_t row_size64 = input.size(positive_dim);
+  if (row_size64 <= 0 || row_size64 > std::numeric_limits<int32_t>::max()) {
+    ET_LOG(
+        Error,
+        "softmax_out: row size must fit in int32 (row_size=%lld)",
+        static_cast<long long>(row_size64));
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  const int32_t row_size = static_cast<int32_t>(row_size64);
+  const int64_t num_rows64 = input.numel() / row_size64;
+  if (num_rows64 <= 0 || num_rows64 > std::numeric_limits<int32_t>::max()) {
+    ET_LOG(
+        Error,
+        "softmax_out: num_rows must fit in int32 (num_rows=%lld)",
+        static_cast<long long>(num_rows64));
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+  const int32_t num_rows = static_cast<int32_t>(num_rows64);
+
+  const int8_t* input_data = input.const_data_ptr<int8_t>();
+  int8_t* output_data = out.mutable_data_ptr<int8_t>();
+
+  if (num_rows <= 0 || row_size <= 0) {
+    ET_LOG(
+        Error,
+        "softmax_out: invalid args (dim=%ld, rows=%d, row_size=%d)",
+        static_cast<long>(dim),
+        num_rows,
+        row_size);
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  const int32_t input_multiplier_val = static_cast<int32_t>(input_multiplier);
+  const int32_t input_shift_val = static_cast<int32_t>(input_shift);
+  const int32_t diff_min_val = static_cast<int32_t>(diff_min);
+
+  if (output_zp_val != kCmsisSoftmaxZeroPoint) {
+    ET_LOG(
+        Error,
+        "softmax_out: expected output zero_point=%d (got zero_point=%d)",
+        kCmsisSoftmaxZeroPoint,
+        output_zp_val);
+    context.fail(Error::InvalidArgument);
+    return out;
+  }
+
+  arm_softmax_s8(
+      input_data,
+      num_rows,
+      row_size,
+      input_multiplier_val,
+      input_shift_val,
+      diff_min_val,
+      output_data);
+
+  return out;
+}
+
+} // namespace native
+} // namespace cortex_m
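The kernel hands CMSIS-NN a flat (num_rows, row_size) view: the softmax axis must be the innermost dimension, and every outer dimension is folded into rows. A small Python sketch of that shape bookkeeping (illustrative only, mirroring the checks above):

from math import prod

def rows_and_row_size(shape, dim):
    # Mirrors softmax_out: normalize dim, require it to be the last axis,
    # then fold all remaining axes into the row count.
    rank = len(shape)
    dim = dim if dim >= 0 else dim + rank
    assert dim == rank - 1, "only last-dimension softmax is supported"
    row_size = shape[dim]
    num_rows = prod(shape) // row_size
    return num_rows, row_size

print(rows_and_row_size((2, 3, 16), dim=-1))  # (6, 16): arm_softmax_s8 sees 6 rows of 16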

backends/cortex_m/ops/operators.py

Lines changed: 66 additions & 0 deletions
@@ -5,6 +5,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import math
 from math import prod
 from typing import Sequence

@@ -14,6 +15,10 @@
     requantize_cmsis,
     SHIFT_INT8,
 )
+from executorch.backends.cortex_m.quantizer.quantization_configs import (
+    CMSIS_SOFTMAX_SCALE,
+    CMSIS_SOFTMAX_ZERO_POINT,
+)
 from executorch.exir.dialects._ops import ops as exir_ops

 # To provide the implementation of the operators
@@ -23,6 +28,8 @@
 # New operator library with a custom namespace to allow fusion etc.
 lib = Library("cortex_m", "DEF")

+SOFTMAX_INPUT_INTEGER_BITS = 5
+
 ###
 # dequantize_per_tensor
 ###
@@ -394,6 +401,65 @@ def quantized_linear_impl(
     return output


+# ===================================================================
+# SOFTMAX OPERATION DEFINITION
+# ===================================================================
+
+lib.define(
+    "softmax(Tensor input, int dim, int input_zero_point, int output_zero_point, int input_multiplier, int input_shift, int diff_min) -> Tensor"
+)
+lib.define(
+    "softmax.out(Tensor input, int dim, int input_zero_point, int output_zero_point, int input_multiplier, int input_shift, int diff_min, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@register_fake("cortex_m::softmax")
+def softmax_meta(
+    input: torch.Tensor,
+    dim: int,
+    input_zero_point: int,
+    output_zero_point: int,
+    input_multiplier: int,
+    input_shift: int,
+    diff_min: int,
+) -> torch.Tensor:
+    return torch.empty_like(input, dtype=torch.int8)
+
+
+@impl(lib, "softmax", "CompositeExplicitAutograd")
+def softmax_impl(
+    input: torch.Tensor,
+    dim: int,
+    input_zero_point: int,
+    output_zero_point: int,
+    input_multiplier: int,
+    input_shift: int,
+    diff_min: int,
+) -> torch.Tensor:
+    del diff_min  # not used in reference path
+    if input.dtype != torch.int8:
+        raise TypeError(
+            f"cortex_m.softmax: expected int8 input tensor, got {input.dtype}"
+        )
+    if output_zero_point != CMSIS_SOFTMAX_ZERO_POINT:
+        raise ValueError(
+            f"cortex_m.softmax: expected output_zero_point {CMSIS_SOFTMAX_ZERO_POINT}, got {output_zero_point}"
+        )
+
+    real_multiplier = float(input_multiplier) / float(1 << 31)
+    real_multiplier = math.ldexp(real_multiplier, input_shift)
+    input_scale = real_multiplier / float(1 << (31 - SOFTMAX_INPUT_INTEGER_BITS))
+    if input_scale <= 0:
+        raise ValueError(
+            f"cortex_m.softmax: derived non-positive input scale {input_scale}"
+        )
+
+    input_fp = (input.to(torch.int32) - int(input_zero_point)).float() * input_scale
+    probs = torch.softmax(input_fp, dim=dim)
+    quantized = torch.round(probs / CMSIS_SOFTMAX_SCALE) + int(output_zero_point)
+    return quantized.clamp(-128, 127).to(torch.int8)
+
+
 # ===================================================================
 # TRANSPOSE OPERATION DEFINITION
 # ===================================================================
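With these definitions imported, the custom op can be called eagerly through torch.ops.cortex_m.softmax, which is convenient for unit-testing the reference path. A usage sketch with made-up quantization parameters (in practice the fusion pass derives multiplier/shift from the observed input scale, and the module path executorch.backends.cortex_m.ops.operators is assumed from the file layout):

import torch
import executorch.backends.cortex_m.ops.operators  # noqa: F401  registers cortex_m::softmax

x = torch.randint(-128, 128, (4, 10), dtype=torch.int8)
y = torch.ops.cortex_m.softmax(
    x,
    -1,        # dim: last dimension only
    0,         # input_zero_point (illustrative)
    -128,      # output_zero_point: must be the CMSIS value
    1 << 30,   # input_multiplier: hypothetical Q31 value
    0,         # input_shift: hypothetical
    -1000,     # diff_min: ignored by the reference implementation
)
print(y.dtype, y.shape)  # torch.int8 torch.Size([4, 10])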

backends/cortex_m/ops/operators.yaml

Lines changed: 6 additions & 0 deletions
@@ -47,6 +47,12 @@
   - arg_meta: null
     kernel_name: cortex_m::quantized_linear_out

+- func: cortex_m::softmax.out(Tensor input, int dim, int input_zero_point, int output_zero_point, int input_multiplier, int input_shift, int diff_min, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::softmax_out
+
 - func: cortex_m::transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:

backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 86 additions & 1 deletion
@@ -5,6 +5,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import math
 from typing import Dict

 import torch
@@ -13,6 +14,10 @@
     quantize_multiplier_aot,
     SHIFT_INT8,
 )
+from executorch.backends.cortex_m.quantizer.quantization_configs import (
+    CMSIS_SOFTMAX_SCALE,
+    CMSIS_SOFTMAX_ZERO_POINT,
+)

 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -31,6 +36,8 @@ class QuantizedOpFusionPass(ExportPass):
     Supports multiple binary operations with backward compatibility for add.
     """

+    _SOFTMAX_INPUT_INTEGER_BITS = 5
+
     def _get_add_replacement(self, args, meta):
         if (
             meta.data.get("input_qparams", {}) == {}
@@ -101,6 +108,81 @@ def _get_mul_replacement(self, args, meta):

         return exir_ops.edge.cortex_m.quantized_mul.default, args

+    def _compute_softmax_params(self, input_scale: float) -> tuple[int, int, int]:
+        """
+        Convert the incoming per-tensor input scale into the CMSIS fixed-point
+        parameters expected by `arm_softmax_s8`.
+
+        1. Clamp the real multiplier to the Q31 range using the fixed number of
+           input integer bits mandated by CMSIS.
+        2. Feed that multiplier through `quantize_multiplier_aot` to get the
+           (multiplier, shift) pair arm_softmax_s8 expects.
+        3. Derive `diff_min`, the CMSIS threshold for early bailout when
+           differences saturate, using the same multiplier/shift values.
+        """
+        real_multiplier = min(
+            input_scale * (1 << (31 - self._SOFTMAX_INPUT_INTEGER_BITS)),
+            float((1 << 31) - 1),
+        )
+        input_multiplier, input_shift = quantize_multiplier_aot(real_multiplier)
+        diff_min_term = (
+            ((1 << self._SOFTMAX_INPUT_INTEGER_BITS) - 1)
+            * math.ldexp(1.0, 31 - self._SOFTMAX_INPUT_INTEGER_BITS)
+            / math.ldexp(1.0, input_shift)
+        )
+        diff_min = -int(math.floor(diff_min_term))
+        return int(input_multiplier), int(input_shift), diff_min
+
+    def _get_softmax_replacement(self, args, meta):
+        if (
+            meta.data.get("input_qparams", {}) == {}
+            or meta.data.get("output_qparams", {}) == {}
+        ):
+            return exir_ops.edge.aten._softmax.default, args
+
+        input_qparams = meta["input_qparams"][0]
+        output_qparams = meta["output_qparams"][0]
+
+        half_to_float = args[2] if len(args) > 2 else False
+        if half_to_float:
+            return exir_ops.edge.aten._softmax.default, args
+
+        input_multiplier, input_shift, diff_min = self._compute_softmax_params(
+            float(input_qparams.scale)
+        )
+
+        output_scale_attr = getattr(output_qparams, "scale", None)
+        output_zp_attr = getattr(output_qparams, "zp", None)
+        if output_scale_attr is None or output_zp_attr is None:
+            raise AssertionError("Softmax requires output quantization parameters.")
+
+        output_scale_val = float(output_scale_attr)
+        output_zp_val = int(output_zp_attr)
+        if not math.isclose(
+            output_scale_val, CMSIS_SOFTMAX_SCALE, rel_tol=0.0, abs_tol=1e-12
+        ):
+            raise AssertionError(
+                "Softmax output scale must match CMSIS (1/256). "
+                f"Got {output_scale_val}."
+            )
+        if output_zp_val != CMSIS_SOFTMAX_ZERO_POINT:
+            raise AssertionError(
+                "Softmax output zero-point must match CMSIS (-128). "
+                f"Got {output_zp_val}."
+            )
+
+        new_args = (
+            args[0],
+            args[1],
+            int(input_qparams.zp),
+            output_zp_val,
+            input_multiplier,
+            input_shift,
+            diff_min,
+        )
+
+        return exir_ops.edge.cortex_m.softmax.default, new_args
+
     def _get_minimum_replacement(self, args, meta):
         if args[0].data.dtype != torch.int8:
             return exir_ops.edge.aten.minimum.default, args
@@ -135,6 +217,8 @@ def call_operator(
                 op, args = self._get_add_replacement(args, meta)
             case exir_ops.edge.aten.mul.Tensor:
                 op, args = self._get_mul_replacement(args, meta)
+            case exir_ops.edge.aten._softmax.default:
+                op, args = self._get_softmax_replacement(args, meta)
             case exir_ops.edge.aten.minimum.default:
                 op, args = self._get_minimum_replacement(args, meta)
             case exir_ops.edge.aten.maximum.default:
@@ -144,4 +228,5 @@ def call_operator(
             case _:
                 pass

-        return super().call_operator(op, args, {}, meta)
+        result = super().call_operator(op, args, {}, meta)
+        return result
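To see what _compute_softmax_params produces for a concrete input scale, here is a standalone sketch; quantize_multiplier() below is a simplified stand-in for quantize_multiplier_aot (whose implementation is not part of this diff), approximating a positive real multiplier as a Q31 mantissa plus a shift:

import math

INTEGER_BITS = 5  # CMSIS softmax input integer bits

def quantize_multiplier(real_multiplier):
    # Simplified stand-in: real_multiplier ~= (q31 / 2**31) * 2**shift
    mantissa, shift = math.frexp(real_multiplier)
    q31 = int(round(mantissa * (1 << 31)))
    if q31 == 1 << 31:  # rounding pushed the mantissa up to 1.0
        q31 //= 2
        shift += 1
    return q31, shift

def softmax_params(input_scale):
    real_multiplier = min(input_scale * (1 << (31 - INTEGER_BITS)), float((1 << 31) - 1))
    multiplier, shift = quantize_multiplier(real_multiplier)
    diff_min_term = (
        ((1 << INTEGER_BITS) - 1)
        * math.ldexp(1.0, 31 - INTEGER_BITS)
        / math.ldexp(1.0, shift)
    )
    return multiplier, shift, -int(math.floor(diff_min_term))

print(softmax_params(0.05))  # e.g. a per-tensor input scale observed by the quantizer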
