Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/cortex_m/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ set(_cortex_m_kernels__srcs
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_avg_pool2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
Expand Down
86 changes: 86 additions & 0 deletions backends/cortex_m/ops/op_quantized_avg_pool2d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2025 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "cortex_m_ops_common.h"

extern "C" {
#include "arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

Tensor& quantized_avg_pool2d_out(
KernelRuntimeContext& context,
const Tensor& input,
const IntArrayRef kernel_size,
const IntArrayRef stride,
const IntArrayRef padding,
const Scalar& zero_point,
const Scalar& multiplier,
const Scalar& shift,
Tensor& out) {
if (input.dim() != 4 || out.dim() != 4) {
ET_LOG(Error, "quantized_avg_pool2d_out: tensors must be 4-D");
context.fail(Error::InvalidArgument);
return out;
}
int32_t batch = static_cast<int32_t>(input.size(0));
int32_t channels = static_cast<int32_t>(input.size(1));
int32_t input_h = static_cast<int32_t>(input.size(2));
int32_t input_w = static_cast<int32_t>(input.size(3));
int32_t kernel_h = static_cast<int32_t>(kernel_size[0]);
int32_t kernel_w = static_cast<int32_t>(kernel_size[1]);
int32_t stride_h = static_cast<int32_t>(stride[0]);
int32_t stride_w = static_cast<int32_t>(stride[1]);
int32_t pad_h = static_cast<int32_t>(padding[0]);
int32_t pad_w = static_cast<int32_t>(padding[1]);
int32_t output_h = static_cast<int32_t>(out.size(2));
int32_t output_w = static_cast<int32_t>(out.size(3));
const int32_t activation_min = std::numeric_limits<int8_t>::min();
const int32_t activation_max = std::numeric_limits<int8_t>::max();

const int8_t* input_data = input.const_data_ptr<int8_t>();
int8_t* output_data = out.mutable_data_ptr<int8_t>();

cmsis_nn_context cmsis_ctx;
cmsis_ctx.buf = nullptr;
cmsis_ctx.size = 0;
cmsis_nn_pool_params pool_params;
pool_params.stride.h = stride_h;
pool_params.stride.w = stride_w;
pool_params.padding.h = pad_h;
pool_params.padding.w = pad_w;
pool_params.activation.min = activation_min;
pool_params.activation.max = activation_max;

cmsis_nn_dims input_dims{batch, input_h, input_w, channels};
cmsis_nn_dims filter_dims{1, kernel_h, kernel_w, 1};
cmsis_nn_dims output_dims{batch, output_h, output_w, channels};

arm_cmsis_nn_status status = arm_avgpool_s8(
&cmsis_ctx,
&pool_params,
&input_dims,
input_data,
&filter_dims,
&output_dims,
output_data);
if (status != ARM_CMSIS_NN_SUCCESS) {
ET_LOG(
Error,
"quantized_avg_pool2d_out: arm_avgpool_s8 failed with status [%d]",
status);
context.fail(Error::Internal);
}
return out;
}

} // namespace native
} // namespace cortex_m
73 changes: 73 additions & 0 deletions backends/cortex_m/ops/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import torch
import torch.nn.functional as F
from executorch.backends.cortex_m.passes.passes_utils import (
dequantize_per_tensor_cmsis,
is_channel_broadcast,
quantize_per_tensor_cmsis,
requantize_cmsis,
SHIFT_INT8,
)
Expand Down Expand Up @@ -577,3 +579,74 @@ def quantized_conv2d_impl(
result = torch.clamp(result, activation_min, activation_max)

return result.to(torch.int8)


# ===================================================================
# QUANTIZED AVG_POOL2D OPERATION DEFINITION
# ===================================================================

lib.define(
"quantized_avg_pool2d("
"Tensor input, "
"int[] kernel_size, "
"int[] stride, "
"int[] padding, "
"Scalar zero_point, "
"Scalar multiplier, "
"Scalar shift"
") -> Tensor"
)
lib.define(
"quantized_avg_pool2d.out("
"Tensor input, "
"int[] kernel_size, "
"int[] stride, "
"int[] padding, "
"Scalar zero_point, "
"Scalar multiplier, "
"Scalar shift, "
"*, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::quantized_avg_pool2d")
def quantized_avg_pool2d_meta(
input: torch.Tensor,
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Sequence[int],
zero_point: int,
multiplier: int,
shift: int,
) -> torch.Tensor:
# Compute output shape as in PyTorch avg_pool2d

output = F.avg_pool2d(input, kernel_size, stride, padding)
return torch.empty_like(output, dtype=torch.int8)


@impl(lib, "quantized_avg_pool2d", "CompositeExplicitAutograd")
def quantized_avg_pool2d_impl(
input: torch.Tensor,
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Sequence[int],
zero_point: int,
multiplier: int,
shift: int,
) -> torch.Tensor:

dequant_input = dequantize_per_tensor_cmsis(input, zero_point, multiplier, shift)

# TODO: implement count_include_pad=True, ceil_mode=True.
result = F.avg_pool2d(
dequant_input,
kernel_size,
stride=stride,
padding=padding,
count_include_pad=False,
ceil_mode=False,
)
result = quantize_per_tensor_cmsis(result, zero_point, multiplier, shift)
output = torch.clamp(result, -128, 127)
return output.to(torch.int8)
6 changes: 6 additions & 0 deletions backends/cortex_m/ops/operators.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,9 @@
kernels:
- arg_meta: null
kernel_name: cortex_m::quantized_conv2d_out

- func: cortex_m::quantized_avg_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, Scalar zero_point, Scalar multiplier, Scalar shift, *, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
kernel_name: cortex_m::quantized_avg_pool2d_out
23 changes: 23 additions & 0 deletions backends/cortex_m/passes/quantized_op_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,27 @@ def _get_permute_replacement(self, args, meta):
args = (args[0], perms)
return exir_ops.edge.cortex_m.transpose.default, args

def _get_avg_pool2d_replacement(self, args, meta):
if (
meta.data.get("input_qparams", {}) == {}
or meta.data.get("output_qparams", {}) == {}
):
return exir_ops.edge.aten.avg_pool2d.default, args

# Extract values
scale = meta["input_qparams"][0].scale
zero_point = meta["input_qparams"][0].zp

output_mult, output_shift = quantize_multiplier_aot(scale)
args = (
*args[0:-2],
zero_point,
output_mult,
output_shift,
)

return exir_ops.edge.cortex_m.quantized_avg_pool2d.default, args

def call_operator(
self,
op: EdgeOpOverload,
Expand All @@ -141,6 +162,8 @@ def call_operator(
op, args = self._get_maximum_replacement(args, meta)
case exir_ops.edge.aten.permute_copy.default:
op, args = self._get_permute_replacement(args, meta)
case exir_ops.edge.aten.avg_pool2d.default:
op, args = self._get_avg_pool2d_replacement(args, meta)
case _:
pass

Expand Down
16 changes: 15 additions & 1 deletion backends/cortex_m/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.


from typing import Callable, List, Optional
from typing import Callable, cast, List, Optional

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
Expand Down Expand Up @@ -315,6 +315,7 @@ class SharedQspecQuantizer(Quantizer):
# Min/Max/Mean
torch.ops.aten.minimum.default,
torch.ops.aten.maximum.default,
torch.ops.aten.avg_pool2d.default,
# Data shuffling
torch.ops.aten.permute.default,
torch.ops.aten.permute_copy.default,
Expand Down Expand Up @@ -402,7 +403,20 @@ def _annotate_shared_cluster(self, root_node: Node) -> None:
mark_node_as_annotated(node, input_qspec_map, shared_qspec)

def annotate(self, model: GraphModule) -> None:
"""
Annotate shared quantization spec for supported ops, but skip avg_pool2d
when both ceil_mode and count_include_pad are True.
"""
for node in model.graph.nodes:
# TODO Skip avg_pool2d when ceil_mode=True or count_include_pad=True
# CMSIS-NN doesn't directly support this. But, it should be done.
if node.target is torch.ops.aten.avg_pool2d.default:
ceil_mode = cast(bool, node.args[4]) if len(node.args) > 4 else False
count_include_pad = (
cast(bool, node.args[5]) if len(node.args) > 5 else True
)
if ceil_mode or count_include_pad:
continue
if node.target in self.targets and not self._is_annotated(node):
self._annotate_shared_cluster(node)

Expand Down
99 changes: 99 additions & 0 deletions backends/cortex_m/test/ops/test_avg_pool2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.test.common import parametrize
from executorch.backends.cortex_m.test.tester import (
CortexMTester,
McuTestCase,
ramp_tensor,
)


class CortexMAvgPool2d(torch.nn.Module):
ops_before_transforms = {
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1,
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
}

ops_after_transforms = {
"executorch_exir_dialects_edge__ops_cortex_m_quantized_avg_pool2d_default": 1,
"executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
"executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
}

def __init__(
self, kernel_size, stride, padding=0, ceil_mode=False, count_include_pad=False
):
super().__init__()
self.pool = torch.nn.AvgPool2d(
kernel_size,
stride,
padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
)

def forward(self, x): # noqa: D102
return self.pool(x)


# Prepare test cases: simple 2x2 pool on 4x4, and 3x3 stride 1 on 3x3
test_cases = {
"avgpool_2x2": McuTestCase(
CortexMAvgPool2d(kernel_size=2, stride=2), (ramp_tensor(0, 15, (1, 1, 4, 4)),)
),
"avgpool_3x3_s1": McuTestCase(
CortexMAvgPool2d(kernel_size=3, stride=1, padding=1),
(ramp_tensor(0, 8, (1, 1, 3, 3)),),
),
# additional pooling configurations: padding, stride, ceil_mode, count_include_pad
"avgpool_2x2_pad1": McuTestCase(
CortexMAvgPool2d(kernel_size=2, stride=2, padding=1),
(ramp_tensor(0, 24, (1, 1, 5, 5)),),
),
"avgpool_3x3_s2_pad1": McuTestCase(
CortexMAvgPool2d(kernel_size=3, stride=2, padding=1),
(ramp_tensor(0, 15, (1, 1, 4, 4)),),
),
}

test_cases_fp = {
"avgpool_3x3_s2_pad1_ceil": McuTestCase(
CortexMAvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True),
(ramp_tensor(0, 15, (1, 1, 4, 4)),),
),
"avgpool_3x3_s2_pad1_countinc": McuTestCase(
CortexMAvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=True),
(ramp_tensor(0, 15, (1, 1, 4, 4)),),
),
}


@parametrize("test_case", test_cases)
def test_dialect_avg_pool2d(test_case):
tester = CortexMTester(test_case.model, test_case.example_inputs)
tester.test_dialect(
test_case.model.ops_before_transforms,
test_case.model.ops_after_transforms,
qtol=1,
)


@parametrize("test_case", test_cases_fp)
def test_dialect_avg_pool2d_fp(test_case):
tester = CortexMTester(test_case.model, test_case.example_inputs)
tester.test_dialect(
{"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1},
{"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1},
qtol=1,
)


@parametrize("test_case", test_cases)
def test_implementation_avg_pool2d(test_case):
tester = CortexMTester(test_case.model, test_case.example_inputs)
tester.test_implementation(qtol=1)
3 changes: 2 additions & 1 deletion backends/cortex_m/test/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def __init__(self):
torch.ops.aten.hardsigmoid_.default,
torch.ops.aten.hardswish.default,
torch.ops.aten.hardswish_.default,
]
],
_check_ir_validity=False,
)
super().__init__(config)

Expand Down
Loading