[CUDA] Fuse MoE router bias into MatMulNBits GEMV #29170
Open
tianleiwu wants to merge 10 commits into
Pull request overview
This PR adds an exact-shape CUDA fast path for the GPT-OSS router MatMulNBits GEMV (M=1, N=32, K=2880, int4, block_size=32, no zero-points) and extends it to optionally apply the [32] bias in-kernel, plus a graph rewrite to fuse the corresponding Add into MatMulNBits under tight shape/EP gating and env-var opt-outs.
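The exact-shape conditions listed above can be sketched as a single predicate. This is a minimal illustration, not the PR's code: the `MatMulNBitsCall` struct, its field names, and `IsExactRouterGemvShape` are hypothetical; only the numeric values come from the overview.

```cpp
#include <cstdint>

// Hypothetical description of one MatMulNBits call; field names are illustrative.
struct MatMulNBitsCall {
  int64_t m;             // rows of the activation A
  int64_t n;             // output features
  int64_t k;             // input features
  int64_t bits;          // weight bit width
  int64_t block_size;    // quantization block size along K
  bool has_zero_points;  // true if a zero_points input is present
};

// Returns true only for the exact router GEMV shape named in the overview:
// M=1, N=32, K=2880, int4, block_size=32, no zero points.
inline bool IsExactRouterGemvShape(const MatMulNBitsCall& c) {
  return c.m == 1 && c.n == 32 && c.k == 2880 && c.bits == 4 &&
         c.block_size == 32 && !c.has_zero_points;
}
```

Any call that fails this predicate would stay on the existing generic path.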
Changes:
- Adds a CUDA router GEMV specialization (`MatMulFloatInt4RouterKernel`) and threads an optional `bias_data` pointer through the `TryMatMulNBits`/`TryMatMul4Bits` dispatch.
- Updates `MatMulNBitsFusion` + transformer registration to enable CUDA-side bias fusion only for the exact router projection shape and when not opted out via env vars.
- Adds/extends unit tests and experiment documentation to cover specialization dispatch, fallback, and bias-fusion opt-out behavior.
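A rough sketch of how the two env-var opt-outs could be read. The variable names are the ones given in the PR description; the helper names and the rule that only the exact value `1` counts as set are assumptions, and the PR may parse these variables differently.

```cpp
#include <cstdlib>
#include <cstring>

// Assumption: a flag counts as set only when its value is exactly "1".
inline bool IsEnvFlagSet(const char* name) {
  const char* value = std::getenv(name);
  return value != nullptr && std::strcmp(value, "1") == 0;
}

// Hypothetical helper: the exact-shape router GEMV kernel is used unless opted out.
inline bool RouterGemvSpecializationEnabled() {
  return !IsEnvFlagSet("ORT_DISABLE_QMOE_ROUTER_GEMV_SPECIALIZATION");
}

// Hypothetical helper: the MatMulNBits + Add graph rewrite runs unless opted out.
inline bool RouterBiasFusionEnabled() {
  return !IsEnvFlagSet("ORT_DISABLE_QMOE_ROUTER_BIAS_FUSION");
}
```

Keeping the two switches separate lets the graph rewrite be turned off while the specialized kernel stays active.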
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/test/optimizer/graph_transform_test.cc | Extends transformer tests to cover CUDA EP + exact router shape, and adds opt-out / unsupported-shape cases. |
| onnxruntime/test/contrib_ops/matmul_4bits_test.cc | Adds CUDA provider coverage for the exact router shape (specialized vs fallback) and biased GEMV. |
| onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc | Updates a MatMulNBits CUDA dispatch call site for the new bias pointer parameter. |
| onnxruntime/core/optimizer/matmul_nbits_fusion.cc | Adds CUDA-only gating for bias fusion and env-var opt-outs for the exact GPT-OSS router shape. |
| onnxruntime/core/optimizer/graph_transformer_utils.cc | Enables MatMulNBitsFusion for CUDA EP by expanding compatible EPs. |
| onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh | Extends CUDA TryMatMul* helpers to accept an optional bias pointer; disallows bias for int8. |
| onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc | Wires bias input through to CUDA TryMatMulNBits and tightens the unsupported-bias error path. |
| onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu | Implements the exact-shape router GEMV kernel and bias-in-kernel support; dispatches under env-var control. |
| docs/contrib_ops/cuda/qmoe_gemv_experiments.md | Records benchmarking/validation details for the specialization and bias fusion. |
The CUDA MatMulNBits kernel only accepts a fused bias on the exact router GEMV fast path (M==1). The fusion gate previously checked only N/K/bits/block_size attributes, so a MatMulNBits+Add with matching attributes but M>1 (or dynamic M) could be fused and then throw at runtime. Gate CUDA bias fusion on A having a statically-known M==1, and add a regression test.
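The statically-known M==1 check could look roughly like this. It is a sketch under assumptions: the `Shape` alias, the use of `std::nullopt` for a dynamic dimension, and the rule that M is the product of all leading dimensions of A are illustrative and not taken from the commit.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// A tensor shape where std::nullopt marks a dynamic (symbolic) dimension.
using Shape = std::vector<std::optional<int64_t>>;

// Returns true only if A has rank >= 2 and every leading dimension is
// statically known and equal to 1, so that A flattens to [1, K].
inline bool HasStaticMEqualToOne(const Shape& a_shape) {
  if (a_shape.size() < 2) return false;
  for (std::size_t i = 0; i + 1 < a_shape.size(); ++i) {
    // A dynamic dimension might be 1 at runtime, but fusion must not rely on it.
    if (!a_shape[i].has_value() || *a_shape[i] != 1) return false;
  }
  return true;
}
```

Rejecting dynamic dimensions is the conservative choice here: the node is left unfused instead of risking the runtime throw described above.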
K and N became runtime const int64_t (router shape selection), so passing them to MlasBlockwiseQuantizedShape/BufferSizes (int params) now triggers -Werror,-Wshorten-64-to-32 on macOS/arm64 and Android. Cast to int.
…ests
- Qualify optional<std::string> as std::optional<std::string> in matmul_4bits_test.cc (no using namespace std there).
- Add explicit static_cast<int> for K/N passed to MlasBlockwiseQuantizedShape and MlasBlockwiseQuantizedBufferSizes to avoid int64->int narrowing warnings-as-errors.
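The commits above use a plain `static_cast<int>`, which is safe for the fixed router shape. For comparison, a stricter variant that refuses out-of-range values might look like the following; `CheckedNarrowToInt` is a hypothetical helper and is not part of the PR.

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical helper: converts int64_t to int and throws instead of
// silently truncating when the value does not fit.
inline int CheckedNarrowToInt(int64_t value) {
  if (value < std::numeric_limits<int>::min() ||
      value > std::numeric_limits<int>::max()) {
    throw std::out_of_range("value does not fit in int");
  }
  return static_cast<int>(value);
}
```

The explicit cast inside the helper also silences `-Wshorten-64-to-32`, since the narrowing is no longer implicit.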
…efactor unroll
- IsSupportedRouterGemvShape now matches only the GPT-OSS-20B router shape (M=1, N=32, K=2880), matching the PR's stated exact-shape scope; remove untested Qwen3/Gemma shapes.
- LaunchMatMulNBitsBiasAdd caps the grid at the CUDA gridDim.x limit and the kernel uses a grid-stride loop, avoiding silent truncation when casting to unsigned int.
- Replace the RouterUnRollReduction macro with a templated __device__ RouterUnrollReduction helper.
- Mark router dispatch dim3 blocks/threads const.
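The grid cap plus grid-stride loop can be illustrated with a host-side C++ emulation. This is not the PR's kernel: `CappedGridSize` and `BiasAddGridStride` are hypothetical names, the row-major `out[rows * n]` layout is an assumption, and the loop over `tid` stands in for CUDA threads running in parallel.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// gridDim.x limit documented by CUDA for compute capability 3.0 and later.
constexpr int64_t kMaxGridDimX = 2147483647;

// Blocks to launch: enough to cover `count` elements, clamped to [1, max_grid]
// so the cast to unsigned int can never truncate.
inline unsigned int CappedGridSize(int64_t count, int64_t threads_per_block,
                                   int64_t max_grid = kMaxGridDimX) {
  const int64_t needed = (count + threads_per_block - 1) / threads_per_block;
  return static_cast<unsigned int>(
      std::min(std::max<int64_t>(needed, 1), max_grid));
}

// CPU emulation of a grid-stride bias add: out is row-major [rows, n] and
// bias has n entries. Each emulated thread handles elements tid, tid + stride, ...
inline void BiasAddGridStride(std::vector<float>& out,
                              const std::vector<float>& bias,
                              unsigned int grid, unsigned int threads) {
  const int64_t total = static_cast<int64_t>(out.size());
  const int64_t n = static_cast<int64_t>(bias.size());
  if (n == 0) return;
  const int64_t stride = static_cast<int64_t>(grid) * threads;
  // Threads with tid >= total have no work, so the emulation skips them.
  for (int64_t tid = 0; tid < std::min(stride, total); ++tid) {
    for (int64_t i = tid; i < total; i += stride) {
      out[i] += bias[i % n];
    }
  }
}
```

Because every thread strides by the full launch width, all elements are covered exactly once even when the capped grid is smaller than the element count.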
Description
Fuse the MoE router `MatMulNBits + Add([32] bias)` pattern into the CUDA `MatMulNBits` router GEMV path.

This PR keeps the public surface conservative:

- `M=1`, `N=32`, `K=2880`, 4-bit weights, `block_size=32`, no zero points;
- … `MatMulNBits` shapes continue to use the existing generic path;
- `ORT_DISABLE_QMOE_ROUTER_GEMV_SPECIALIZATION=1` disables the exact router GEMV specialization;
- `ORT_DISABLE_QMOE_ROUTER_BIAS_FUSION=1` disables only the graph rewrite that folds the router bias into `MatMulNBits`.

Motivation and Context
GPT-OSS-20B decode runs a tiny router projection before each QMoE node. The router projection is an exact-shape int4 `MatMulNBits`, followed by a `[32]` bias add before `QMoE` consumes the router logits.

The existing generic int4 GEMV works, but this router shape is small enough that specializing it reduces router GEMV overhead. Once that specialization is active, folding the `[32]` bias into the same kernel removes the remaining router-side `Add` launch without changing the QMoE op contract.

Key Changes
- … `MatMulFloatInt4RouterKernel`.
- … `MatMulNBits` path to pass an optional bias pointer to the router specialization.
- … `MatMulNBitsFusion` to rewrite the exact GPT-OSS router `MatMulNBits + Add` chain into biased `MatMulNBits`.
- … `origin/main` WebGPU kernel-gated MatMulNBits fusion logic.

Validation
Completed locally on the clean PR branch:

- `lintrunner -a docs/contrib_ops/cuda/qmoe_gemv_experiments.md onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh onnxruntime/core/optimizer/graph_transformer_utils.cc onnxruntime/core/optimizer/matmul_nbits_fusion.cc onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc onnxruntime/test/contrib_ops/matmul_4bits_test.cc onnxruntime/test/optimizer/graph_transform_test.cc`
- `git diff --check`
- `git diff --cached --check`

Previously collected on the experiment branch before preparing this PR branch:

- … `+1.6%` to `+1.8%` throughput.
- … `Add` nodes and measured about `+0.2%` throughput after the router GEMV specialization.

Compiled C++ tests were not rerun from this new worktree because it does not have a configured build directory; CI should provide the full compiled validation matrix.
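For readers who want the semantics of the fused op spelled out, here is a scalar C++ reference for a biased int4 GEMV with M=1. It is a sketch under assumptions that are not taken from this PR: two 4-bit values per byte with the low nibble first, row-major `[N][K/2]` packing, one float scale per (column, block), and an implicit zero point of 8 because no zero points are supplied. `RouterGemvReference` is a hypothetical name.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference for y[n] = bias[n] + sum_k a[k] * scale[n][k / block_size] * (q[n][k] - 8).
// Assumes k is even and a multiple of block_size; an empty bias means no bias.
inline std::vector<float> RouterGemvReference(
    const std::vector<float>& a,           // activation, length k
    const std::vector<uint8_t>& packed_b,  // packed int4 weights, n_cols * k / 2 bytes
    const std::vector<float>& scales,      // n_cols * (k / block_size) scales
    const std::vector<float>& bias,        // length n_cols, or empty
    std::size_t n_cols, std::size_t k, std::size_t block_size) {
  const std::size_t blocks_per_col = k / block_size;
  std::vector<float> y(n_cols);
  for (std::size_t n = 0; n < n_cols; ++n) {
    float acc = 0.0f;
    for (std::size_t kk = 0; kk < k; ++kk) {
      const uint8_t byte = packed_b[(n * k + kk) / 2];
      // Assumed nibble order: even k in the low nibble, odd k in the high nibble.
      const int q = (kk % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
      const float scale = scales[n * blocks_per_col + kk / block_size];
      acc += a[kk] * scale * static_cast<float>(q - 8);
    }
    // Folding the bias in here is what replaces the separate Add node.
    y[n] = acc + (bias.empty() ? 0.0f : bias[n]);
  }
  return y;
}
```

For the router shape in this PR the call would use `n_cols = 32`, `k = 2880`, and `block_size = 32`; a reference like this is mainly useful for checking a specialized kernel against expected outputs.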