From 5a6196f1d294e5c17315f50018e69caddfc3ef8a Mon Sep 17 00:00:00 2001 From: Gopalakrishnan Nallasamy Date: Tue, 16 Jun 2026 18:38:57 -0700 Subject: [PATCH] Avoid small MatMul batch parameter heap allocations --- onnxruntime/core/providers/cpu/math/matmul.cc | 75 ++++++++++++------- 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 6fa3e0d9a4827..7a4dbed2f4b68 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/math/matmul.h" +#include "core/common/inlined_containers.h" #include "core/providers/cpu/math/gemm_matmul_common.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math.h" @@ -329,39 +330,57 @@ Status MatMul::Compute(OpKernelContext* ctx) const { } if (can_use_fastmath_sbgemm) { - std::vector data(max_len); - for (size_t i = 0; i < max_len; i++) { - data[i].BIsfp32 = !(bool(packed_b_)); - data[i].AIsfp32 = true; - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = data[i].BIsfp32 ? b_data + helper.RightOffsets()[i] : (float*)packed_b_.get(); - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].Bias = nullptr; - data[i].OutputProcessor = nullptr; - data[i].BIsPacked = static_cast(packed_b_); + auto gemm_batch = [&](auto& data) { + for (size_t i = 0; i < max_len; i++) { + data[i].BIsfp32 = !(bool(packed_b_)); + data[i].AIsfp32 = true; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsfp32 ? b_data + helper.RightOffsets()[i] : (float*)packed_b_.get(); + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].Bias = nullptr; + data[i].OutputProcessor = nullptr; + data[i].BIsPacked = static_cast(packed_b_); + } + MlasSBGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, max_len, data.data(), thread_pool, &mlas_backend_kernel_selector_config_); + }; + + if (max_len <= 2) { + InlinedVector data(max_len); + gemm_batch(data); + } else { + std::vector data(max_len); + gemm_batch(data); } - MlasSBGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, - M, N, K, max_len, data.data(), thread_pool, &mlas_backend_kernel_selector_config_); } else #endif { - std::vector data(max_len); - for (size_t i = 0; i < max_len; i++) { - data[i].BIsPacked = bool(packed_b_); - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].alpha = alpha_attr_; - data[i].beta = 0.0f; + auto gemm_batch = [&](auto& data) { + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = bool(packed_b_); + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = alpha_attr_; + data[i].beta = 0.0f; + } + MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, data.data(), max_len, thread_pool, &mlas_backend_kernel_selector_config_); + }; + + if (max_len <= 2) { + InlinedVector data(max_len); + gemm_batch(data); + } else { + std::vector data(max_len); + gemm_batch(data); } - MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, - M, N, K, data.data(), max_len, thread_pool, &mlas_backend_kernel_selector_config_); } return Status::OK(); }