Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,23 @@ static const char* const kOrtSessionOptionsModelExternalInitializersFileFolderPa
static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers =
"session.save_external_prepacked_constant_initializers";

// Enables cross-session sharing of MatMulNBits pre-packed weights via an OrtPrepackedWeightsContainer,
// content-addressed by hash(packed_bytes) so weights with auto-generated names still deduplicate. This
// covers MatMulNBits weights synthesized at session-creation time (e.g. by the DQ + MatMul -> MatMulNBits
// fusion), whose names are not known ahead of time and so cannot be registered via OrtApi::AddInitializer.
//
// Scoped to MatMulNBits: content-addressed sharing is only safe when a kernel's packed bytes fully
// determine its Compute result, which MatMulNBits satisfies. Other CPU kernels are unaffected.
//
// Requires the session to be created via OrtApi::CreateSessionWithPrepackedWeightsContainer, with this
// option set consistently across all sharing sessions using the same container.
//
// - "0": Default. Only AddInitializer-registered initializers share pre-packed weights cross-session.
// - "1": Also share MatMulNBits pre-packed weights cross-session via the container.
// Sample usage: sess_options.add_session_config_entry(kOrtSessionOptionsShareMatMulNBitsPrepackedWeights, "1")
static const char* const kOrtSessionOptionsShareMatMulNBitsPrepackedWeights =
"session.share_matmulnbits_prepacked_weights";

// Use this config when you want to collect memory stats for each node in the graph.
// The file format is a CSV file with the following columns:
// The file will be created if it does not exist, and will be overwritten if it does.
Expand Down
201 changes: 155 additions & 46 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

Large diffs are not rendered by default.

25 changes: 19 additions & 6 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,15 @@ static std::string GenerateKeyForPrepackedWeightsMap(const std::string& op_type,
Status SessionState::PrepackConstantInitializedTensors(
InlinedHashMap<std::string, size_t>& constant_initializers_use_count,
const std::unordered_map<std::string, const OrtValue*>& initializers_to_share_map) {
auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map](
// When set, MatMulNBits pre-packed weights are content-addressed into the shared
// OrtPrepackedWeightsContainer for cross-session sharing. Needed for fusion-synthesized weights (e.g.
// DQ + MatMul -> MatMulNBits) whose auto-generated names can't be pre-registered via AddInitializer.
const bool share_matmulnbits_prepacked_weights =
sess_options_.config_options.GetConfigOrDefault(
kOrtSessionOptionsShareMatMulNBitsPrepackedWeights, "0") == "1";

auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map,
share_matmulnbits_prepacked_weights](
bool should_cache_prepacked_weights_for_shared_initializers) -> Status {
for (auto& node : GetGraphViewer().Nodes()) {
if (sess_options_.IsLoadCancellationFlagSet()) {
Expand Down Expand Up @@ -498,8 +506,15 @@ Status SessionState::PrepackConstantInitializedTensors(
auto iter = initializers_to_share_map.find(input_name);
bool is_shared_initializer = (iter != initializers_to_share_map.end());

// Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now
if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers &&
// CPU EP only. By default only AddInitializer-registered initializers (is_shared_initializer)
// participate; share_matmulnbits_prepacked_weights also enrolls MatMulNBits weights,
// deduplicated content-addressed via hash(packed_bytes). Enrollment is restricted to
// MatMulNBits because content-addressed sharing is only safe when packed bytes fully
// determine Compute (which MatMulNBits satisfies); this also keeps the BUG CHECK below valid.
const bool enroll_matmulnbits_initializer =
share_matmulnbits_prepacked_weights && node.OpType() == "MatMulNBits";
if ((is_shared_initializer || enroll_matmulnbits_initializer) &&
should_cache_prepacked_weights_for_shared_initializers &&
node.GetExecutionProviderType() == kCpuExecutionProvider) {
// caching of pre-packed weights' turned ON

Expand Down Expand Up @@ -615,11 +630,9 @@ Status SessionState::PrepackConstantInitializedTensors(
is_packed,
&weights_to_be_filled_in));

// Some kernels (matmul_nbits and non-CPU related kernels) do not share their pre-packed results
// Some kernels (non-CPU related kernels) do not share their pre-packed results
// even though they set is_packed = true so we leave it up to them.
// We can change their behavior if we wish do so in a separate PR
// XXX: Interestingly enough, matmul_nbits does accept shared pre-packs, but does not
// produce them.
if (is_packed && !weights_to_be_filled_in.buffers_.empty()) {
const auto& op_type = node.OpType();
const std::string prepacked_weights_container_key = GenerateKeyForPrepackedWeightsMap(
Expand Down
115 changes: 115 additions & 0 deletions onnxruntime/test/contrib_ops/matmul_2bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#ifndef ORT_MINIMAL_BUILD

#include <filesystem>
#include <optional>

#include "gtest/gtest.h"
Expand All @@ -26,6 +27,9 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/ort_env.h"
#include "core/util/qmath.h"
#include "core/graph/model.h"
#include "test/util/include/inference_session_wrapper.h"
#include "test/util/include/test/test_environment.h"
#include "core/providers/webgpu/webgpu_provider_options.h"
#ifdef USE_WEBGPU
#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h"
Expand Down Expand Up @@ -461,6 +465,117 @@ TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256_Bias) {
TestMatMul2BitsLutGemm<float>(32, 256, 256, 32, /*has_zero_point=*/true, /*has_bias=*/true);
}

// Regression test for the LUT GEMM pre-pack + prepacked-save path. A 2-bit MatMulNBits node pre-packed
// via the LUT path must record its packed B buffer exactly once. A prior bug appended packed_b_ twice
// on the LUT path (inside the LUT branch and again in the shared append at the end of the B block), so
// the second entry was a moved-from/null buffer paired with a non-zero packed_b_size_. The pre-packed
// content hash skips null buffers, so cross-session sharing appeared to work, but saving pre-packed
// initializers iterates every recorded buffer and writes buffer_sizes_[i] bytes from buffers_[i].get(),
// dereferencing the null pointer when mlas.use_lut_gemm=1. This drives mlas.use_lut_gemm=1 together with
// session.save_external_prepacked_constant_initializers=1 and a non-empty optimized_model_filepath, and
// asserts that initialization (which performs the save) and a subsequent run both succeed.
TEST(MatMulNBitsLutGemm, Float32_2Bits_PrepackSaveDoesNotCrash) {
constexpr int64_t M = 1, N = 128, K = 128, block_size = 32;
if (!MlasIsLutGemmAvailable(static_cast<size_t>(N), static_cast<size_t>(K), 2, static_cast<size_t>(block_size))) {
GTEST_SKIP() << "LUT GEMM not available on this platform";
}

// Quantize random weights into valid 2-bit MatMulNBits B/scales/zero_points initializers.
RandomValueGenerator random{1234};
std::vector<float> b_fp32(random.Gaussian<float>(AsSpan({K, N}), 0.0f, 0.25f));

int q_rows = 0, q_cols = 0;
MlasBlockwiseQuantizedShape<float, QBits>(static_cast<int>(block_size), /*columnwise*/ true,
static_cast<int>(K), static_cast<int>(N), q_rows, q_cols);
size_t q_data_size_in_bytes = 0, q_scale_size = 0, q_zp_size_in_bytes = 0;
MlasBlockwiseQuantizedBufferSizes<QBits>(static_cast<int>(block_size), /*columnwise*/ true,
static_cast<int>(K), static_cast<int>(N),
q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes);

std::vector<uint8_t> b_data(q_data_size_in_bytes);
std::vector<float> scales(q_scale_size);
std::vector<uint8_t> zp(q_zp_size_in_bytes);

auto& ortenv = **ort_env.get();
onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool();
MlasQuantizeBlockwise<float, QBits>(b_data.data(), scales.data(), zp.data(), b_fp32.data(),
static_cast<int32_t>(block_size), /*columnwise*/ true,
static_cast<int32_t>(K), static_cast<int32_t>(N),
static_cast<int32_t>(N), tp);

// Single-node MatMulNBits model: A is a runtime input; B/scales/zero_points are constant initializers
// (so they are pre-packed at session initialization).
const int64_t k_blocks = (K + block_size - 1) / block_size;
const std::unordered_map<std::string, int> domain_to_version{{"", 21}, {kMSDomain, 1}};
Model model("matmul_2bits_lut_prepack_save", false, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version,
std::vector<ONNX_NAMESPACE::FunctionProto>(), DefaultLoggingManager().DefaultLogger());
Graph& graph = model.MainGraph();
ModelTestBuilder builder(graph);

ONNX_NAMESPACE::TypeProto float_2d;
float_2d.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType<float>());
float_2d.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(M);
float_2d.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(K);
NodeArg* A = &graph.GetOrCreateNodeArg("A", &float_2d);
NodeArg* Y = &graph.GetOrCreateNodeArg("Y", nullptr);

NodeArg* B = builder.MakeInitializer<uint8_t>(
{static_cast<int64_t>(q_cols), k_blocks, static_cast<int64_t>(q_rows) / k_blocks}, b_data);
NodeArg* scales_arg = builder.MakeInitializer<float>({N, static_cast<int64_t>(q_scale_size) / N}, scales);
NodeArg* zero_points =
builder.MakeInitializer<uint8_t>({N, static_cast<int64_t>(q_zp_size_in_bytes) / N}, zp);

Node& node = builder.AddNode("MatMulNBits", {A, B, scales_arg, zero_points}, {Y}, kMSDomain);
node.AddAttribute("K", K);
node.AddAttribute("N", N);
node.AddAttribute("block_size", block_size);
node.AddAttribute("bits", static_cast<int64_t>(QBits));
node.AddAttribute("accuracy_level", static_cast<int64_t>(0));

graph.SetOutputs(std::vector<const NodeArg*>{Y});
ASSERT_STATUS_OK(graph.Resolve());

std::string model_bytes;
ASSERT_TRUE(model.ToProto().SerializeToString(&model_bytes));

// Save the optimized model + pre-packed initializers into a unique temp dir. Writing the prepacked
// initializers is the path that dereferenced the duplicate null buffer before the fix.
namespace fs = std::filesystem;
const fs::path tmp_dir = fs::temp_directory_path() / "ort_matmul2bits_lut_prepack_save_test";
std::error_code ec;
fs::remove_all(tmp_dir, ec);
ASSERT_TRUE(fs::create_directories(tmp_dir, ec)) << ec.message();
const fs::path optimized_model_path = tmp_dir / "optimized.onnx";

SessionOptions so;
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsMlasLutGemm, "1"));
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1"));
so.optimized_model_filepath = optimized_model_path.native();

std::vector<OrtValue> fetches;
{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_bytes.data(), static_cast<int>(model_bytes.size())));
// Initialization performs the LUT pre-pack and writes the optimized model with external
// pre-packed initializers. Before the fix this dereferenced the duplicate null packed buffer.
ASSERT_STATUS_OK(session.Initialize());

auto cpu_allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0];
std::vector<float> a_data = random.Gaussian<float>(AsSpan({M, K}), 0.0f, 0.25f);
OrtValue a_value;
CreateMLValue<float>(cpu_allocator, AsSpan({M, K}), a_data, &a_value);
NameMLValMap feeds{{"A", a_value}};

ASSERT_STATUS_OK(session.Run(RunOptions{}, feeds, std::vector<std::string>{"Y"}, &fetches));
}

ASSERT_EQ(fetches.size(), static_cast<size_t>(1));
EXPECT_TRUE(fs::exists(optimized_model_path));

fs::remove_all(tmp_dir, ec);
}

// Float zero point tests — directed QAD scenario (zp=1.5)
void RunTest2BitsFloatZP(int64_t M, int64_t N, int64_t K, int64_t block_size, float zp_value) {
RandomValueGenerator random{1234};
Expand Down
Loading
Loading