[CPU] Enable pre-packed weights sharing for MatMulNBits#29163
[CPU] Enable pre-packed weights sharing for MatMulNBits#29163derdeljan-msft wants to merge 3 commits into
Conversation
tianleiwu
left a comment
There was a problem hiding this comment.
Review summary
Thanks for enabling cross-session pre-packed weight sharing for MatMulNBits — the motivation (prefill + decode sessions sharing one in-memory copy of the packed weights) is clear and the share_all_prepacked_cpu_initializers opt-in plus content-addressing is a reasonable approach. The test matrix (4-bit/8-bit, fp32/fp16, symmetric/asymmetric, +/- bias, multiple block sizes and accuracy levels, plus an AddInitializer path and a negative control) is thorough and the new shared test helper keeps the two test files DRY.
Main concern — the cache key is computed from a partially-packed buffer:
The per-B cache key is produced by GenerateKeyForPrepackedWeightsMap() immediately after PrePack(input_idx == B) returns, but the scales and zero_points PrePack calls subsequently mutate that same buffer in place (the MlasQNBitGemmPackQuantBData(..., packed_b_.get(), ...) calls in the scales/zero_points branches). At hash time the zero-point region is still the zeroed placeholder and blksum was computed with no zero point (zp passed as nullptr during B packing). So the hash does not reflect the final packed bytes — specifically it does not capture zero_points.
Consequence under share_all: two CPU MatMulNBits initializers with byte-identical quantized B and identical scales but different zero_points would collide on the same key; the second one adopts the first's already-finalized buffer and, because packed_b_is_shared_ becomes true, skips packing its own zero points — silently producing a wrong result. For the intended prefill/decode same-model case the weights are identical so this never triggers, but share_all makes every initializer eligible and widens the collision surface beyond the old AddInitializer-only path. Could you confirm the B-only hash uniquely determines the fully-finalized buffer (i.e. that no post-hash packing step can differ between two initializers that hash equal), or compute the key after all packing for the node completes / fold zero_points into the hashed content? Inline detail below.
Minor notes:
- The
std::memset(packed_b_.get(), 0, packed_b_size_)runs for every CompInt8PrePackeven when noOrtPrepackedWeightsContaineris configured (sharing disabled). The zero-fill is only needed for hash stability, so it's wasted (one-time, session-init) work on the non-sharing path; consider gating it onprepacked_weights != nullptr. - On the ARM64/HQNBIT paths a
nullptrplaceholder is pushed forscales, andPrePackedWeights::GetHash()skips null buffers — so thescalescontainer key isop_type + hash-of-nothing, identical for everyMatMulNBitsnode. It's benign (no real buffer is shared), but it does incrementused_shared_pre_packed_weights_counter_for unrelated nodes, which could be mistaken for real scale sharing.
Verdict: COMMENT — primarily to confirm the hashing assumption above before merge.
MatMulNBits' CompInt8 (accuracy_level 4) packing is staged and stateful: PrePack(B) packs the quantized weights and accumulates a partial block sum into the buffer, then PrePack(scales)/PrePack(zero_points) consume that state to finalize it. MLAS requires each step to run exactly once per buffer (see SQ8BitGemmPackQuantBDataAndBlkSum). Cross-session pre-packed weight sharing broke this contract: the second session adopts the buffer the first session already finalized and then re-runs PrePack(scales)/PrePack(zero_points) on it, finalizing a second time over already-folded data. That corrupts the block-sum correction and produces wrong results. It reproduces on Linux ARM64, where ArmNeonIsQuantActivationsUnsigned selects the stateful correction path, and is latent in the AVX2/AVX512 packers that use the same design. Track the buffer each instance packs, and in UseSharedPrePackedBuffers detect when the buffer handed back came from another session (it differs from the one this instance packed) and skip the staged scale/zero-point re-pack. The first session and the non-sharing path adopt their own buffer and are unchanged; only the redundant re-pack in later sessions is removed. All changes are in PrePack/UseSharedPrePackedBuffers, so inference and the single-session path are unaffected.
6da9ed9 to
afb93d0
Compare
Description
Enable pre-packed weights sharing for
MatMulNBitsoperator on CPU. Weight sharing has two modes:AddInitializer)session.share_all_prepacked_cpu_initializersoption is configured. The reason for this option is because the reference to the original weight files is lost during weights prepacking for MatMulNBits. Entries are matched based on the weights hash.Motivation and Context
For executing ASG SLMs on CPU - there are two sessions, one for prefill stage and for decode stage (due to different shapes and session options). With this change, storing the weights in memory twice is avoided. The first sessions pre-packs the weights which the second session can reuse.
Confirmed memory reduction through the WPA memory traces.