[CUDA] Fuse MoE router bias into MatMulNBits GEMV #29170

Open
tianleiwu wants to merge 10 commits into main from tlwu/20260619/qmoe_router_bias_gemv

@tianleiwu tianleiwu commented Jun 20, 2026

Description

Fuse the MoE router MatMulNBits + Add([32] bias) pattern into the CUDA MatMulNBits router GEMV path.

This PR keeps the public surface conservative:

  • no QMoE op schema change;
  • no new router/top-k QMoE inputs;
  • the optimized path is exact-shape gated to the GPT-OSS router projection: M=1, N=32, K=2880, 4-bit weights, block_size=32, no zero points;
  • all other MatMulNBits shapes continue to use the existing generic path;
  • ORT_DISABLE_QMOE_ROUTER_GEMV_SPECIALIZATION=1 disables the exact router GEMV specialization;
  • ORT_DISABLE_QMOE_ROUTER_BIAS_FUSION=1 disables only the graph rewrite that folds the router bias into MatMulNBits.
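
The gating rules in the list above can be sketched as a small host-side predicate. This is an illustrative sketch, not the PR's code: the struct `MatMulNBitsCall` and the helpers `EnvFlagIsOne`, `IsExactRouterShape`, and `UseRouterGemv` are hypothetical names, and treating only the exact string "1" as "set" is an assumption about how the environment variable is parsed.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>

// Hypothetical description of one MatMulNBits call; field names are illustrative.
struct MatMulNBitsCall {
  int64_t m;
  int64_t n;
  int64_t k;
  int64_t bits;
  int64_t block_size;
  bool has_zero_points;
};

// True when the named environment variable is set to exactly "1".
// (Assumption: the real parsing may accept other spellings.)
inline bool EnvFlagIsOne(const char* name) {
  assert(name != nullptr);
  const char* value = std::getenv(name);
  return value != nullptr && std::strcmp(value, "1") == 0;
}

// Exact-shape gate: M=1, N=32, K=2880, 4-bit weights, block_size=32, no zero points.
inline bool IsExactRouterShape(const MatMulNBitsCall& c) {
  return c.m == 1 && c.n == 32 && c.k == 2880 && c.bits == 4 &&
         c.block_size == 32 && !c.has_zero_points;
}

// The specialization is used only for the exact shape and only when not opted out.
inline bool UseRouterGemv(const MatMulNBitsCall& c) {
  return IsExactRouterShape(c) &&
         !EnvFlagIsOne("ORT_DISABLE_QMOE_ROUTER_GEMV_SPECIALIZATION");
}
```

Any call for which `UseRouterGemv` returns false would fall through to the existing generic path, which matches the "all other shapes" bullet.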

Motivation and Context

GPT-OSS-20B decode runs a tiny router projection before each QMoE node. The router projection is an exact-shape int4 MatMulNBits, followed by a [32] bias add before QMoE consumes the router logits.

The existing generic int4 GEMV works, but this router shape is small enough that specializing it reduces router GEMV overhead. Once that specialization is active, folding the [32] bias into the same kernel removes the remaining router-side Add launch without changing the QMoE op contract.
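
For reference, the arithmetic that the fused router kernel has to reproduce can be written as a plain CPU loop. This is a minimal sketch under stated assumptions, not the CUDA kernel: the function name `RouterGemvReference` is hypothetical, weights are assumed to be packed two per byte with the even-indexed element in the low nibble, and the implicit zero point is assumed to be 8 when no zero-point tensor is supplied.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Computes y[row] = bias[row] + sum_col a[col] * scale[row][col / block] * (q[row][col] - 8).
// Pass bias == nullptr to get the unfused MatMulNBits result.
inline std::vector<float> RouterGemvReference(const std::vector<float>& a,
                                              const std::vector<uint8_t>& packed,
                                              const std::vector<float>& scales,
                                              const float* bias,
                                              size_t n, size_t k, size_t block_size) {
  assert(block_size > 0 && block_size % 2 == 0 && k % block_size == 0);
  const size_t blocks = k / block_size;
  assert(a.size() == k);
  assert(packed.size() == n * k / 2);
  assert(scales.size() == n * blocks);

  std::vector<float> y(n);
  for (size_t row = 0; row < n; ++row) {
    float acc = 0.0f;
    for (size_t col = 0; col < k; ++col) {
      const uint8_t byte = packed[(row * k + col) / 2];
      // Assumed packing: even element in the low nibble, odd element in the high nibble.
      const int q = (col % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
      const float scale = scales[row * blocks + col / block_size];
      acc += a[col] * scale * static_cast<float>(q - 8);  // assumed default zero point of 8
    }
    // Bias fusion: the [n] bias is added in the same pass instead of a separate Add.
    y[row] = acc + (bias != nullptr ? bias[row] : 0.0f);
  }
  return y;
}
```

With `n = 32`, `k = 2880`, and `block_size = 32` this covers the router shape; calling it with and without `bias` gives a reference for comparing the fused and unfused paths.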

Key Changes

  • Adds an exact-shape CUDA router GEMV specialization in MatMulFloatInt4RouterKernel.
  • Extends the CUDA MatMulNBits path to pass an optional bias pointer to the router specialization.
  • Extends MatMulNBitsFusion to rewrite the exact GPT-OSS router MatMulNBits + Add chain into biased MatMulNBits.
  • Keeps the transformer registration compatible with the current origin/main WebGPU kernel-gated MatMulNBits fusion logic.
  • Adds graph transformer and MatMul4Bits coverage for the specialization, fallback, and bias-fusion opt-out behavior.
  • Records the router GEMV and router bias fusion measurements in the QMoE GEMV experiment log.
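
The MatMulNBits + Add rewrite described above can be illustrated on a toy graph. The sketch below does not use ONNX Runtime's graph API: `ToyNode`, `ConsumerCount`, and `FuseBiasAdd` are hypothetical, graph outputs are ignored, and placing the bias at input index 5 (after empty optional zero_points and g_idx slots) is an assumption about the op's input order. The real transformer also applies the shape, execution-provider, and opt-out gates, which are omitted here.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy node; not ONNX Runtime's Graph/Node API.
struct ToyNode {
  std::string op_type;
  std::vector<std::string> inputs;
  std::string output;
  bool removed;
};

// Counts live nodes that read `value`.
inline size_t ConsumerCount(const std::vector<ToyNode>& nodes, const std::string& value) {
  size_t count = 0;
  for (const ToyNode& node : nodes) {
    if (node.removed) continue;
    for (const std::string& in : node.inputs) {
      if (in == value) ++count;
    }
  }
  return count;
}

// Folds Add(MatMulNBits(a, b, scales), bias) into a biased MatMulNBits.
// Returns the number of Add nodes removed.
inline size_t FuseBiasAdd(std::vector<ToyNode>& nodes) {
  size_t fused = 0;
  for (ToyNode& mm : nodes) {
    if (mm.removed || mm.op_type != "MatMulNBits" || mm.inputs.size() != 3) continue;
    // The MatMulNBits output must feed only the Add, or the unbiased value is still needed.
    if (ConsumerCount(nodes, mm.output) != 1) continue;
    for (ToyNode& add : nodes) {
      if (add.removed || add.op_type != "Add" || add.inputs.size() != 2) continue;
      const bool lhs = add.inputs[0] == mm.output;
      const bool rhs = add.inputs[1] == mm.output;
      if (lhs == rhs) continue;  // this Add does not consume the MatMulNBits output
      mm.inputs.resize(5);       // empty optional slots (assumed: zero_points, g_idx)
      mm.inputs.push_back(add.inputs[lhs ? 1 : 0]);  // bias becomes the last input
      mm.output = add.output;    // downstream consumers now read the fused node
      add.removed = true;
      ++fused;
      break;
    }
  }
  return fused;
}
```

The single-consumer check is the part most worth keeping in any real version: if another node reads the unbiased logits, removing the Add would change that node's input.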

Validation

Completed locally on the clean PR branch:

  • lintrunner -a docs/contrib_ops/cuda/qmoe_gemv_experiments.md onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh onnxruntime/core/optimizer/graph_transformer_utils.cc onnxruntime/core/optimizer/matmul_nbits_fusion.cc onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc onnxruntime/test/contrib_ops/matmul_4bits_test.cc onnxruntime/test/optimizer/graph_transform_test.cc
  • git diff --check
  • git diff --cached --check

Previously collected on the experiment branch before preparing this PR branch:

  • Graph transformer tests for router GEMV/bias fusion passed.
  • MatMul4Bits provider coverage for router GEMV specialization/fallback passed.
  • Nsight confirmed the exact router specialization dispatches for GPT-OSS decode router projections.
  • CUDA-graph GPT-OSS decode A/B showed the router GEMV specialization at about +1.6% to +1.8% throughput.
  • Router bias fusion removed all 24 real GPT-OSS router bias Add nodes and measured about +0.2% throughput after the router GEMV specialization.

Compiled C++ tests were not rerun from this new worktree because it does not have a configured build directory; CI should provide the full compiled validation matrix.

@tianleiwu tianleiwu force-pushed the tlwu/20260619/qmoe_router_bias_gemv branch from 651acb5 to 2a2ce83 June 20, 2026 00:31
@tianleiwu tianleiwu requested a review from Copilot June 20, 2026 00:47

Copilot AI left a comment

Pull request overview

This PR adds an exact-shape CUDA fast path for the GPT-OSS router MatMulNBits GEMV (M=1, N=32, K=2880, int4, block_size=32, no zero-points) and extends it to optionally apply the [32] bias in-kernel, plus a graph rewrite to fuse the corresponding Add into MatMulNBits under tight shape/EP gating and env-var opt-outs.

Changes:

  • Adds a CUDA router GEMV specialization (MatMulFloatInt4RouterKernel) and threads an optional bias_data pointer through the TryMatMulNBits/TryMatMul4Bits dispatch.
  • Updates MatMulNBitsFusion + transformer registration to enable CUDA-side bias fusion only for the exact router projection shape and when not opted out via env vars.
  • Adds/extends unit tests and experiment documentation to cover specialization dispatch, fallback, and bias-fusion opt-out behavior.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| onnxruntime/test/optimizer/graph_transform_test.cc | Extends transformer tests to cover CUDA EP + exact router shape, and adds opt-out / unsupported-shape cases. |
| onnxruntime/test/contrib_ops/matmul_4bits_test.cc | Adds CUDA provider coverage for the exact router shape (specialized vs fallback) and biased GEMV. |
| onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc | Updates a MatMulNBits CUDA dispatch call site for the new bias pointer parameter. |
| onnxruntime/core/optimizer/matmul_nbits_fusion.cc | Adds CUDA-only gating for bias fusion and env-var opt-outs for the exact GPT-OSS router shape. |
| onnxruntime/core/optimizer/graph_transformer_utils.cc | Enables MatMulNBitsFusion for CUDA EP by expanding compatible EPs. |
| onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh | Extends CUDA TryMatMul* helpers to accept an optional bias pointer; disallows bias for int8. |
| onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc | Wires bias input through to CUDA TryMatMulNBits and tightens the unsupported-bias error path. |
| onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu | Implements the exact-shape router GEMV kernel and bias-in-kernel support; dispatches under env-var control. |
| docs/contrib_ops/cuda/qmoe_gemv_experiments.md | Records benchmarking/validation details for the specialization and bias fusion. |

Comment thread onnxruntime/core/optimizer/matmul_nbits_fusion.cc Outdated
The CUDA MatMulNBits kernel only accepts a fused bias on the exact router GEMV fast path (M==1). The fusion gate previously checked only the N/K/bits/block_size attributes, so a MatMulNBits+Add with matching attributes but M>1 (or dynamic M) could be fused and then throw at runtime. Gate CUDA bias fusion on A having a statically known M==1, and add a regression test.

K and N became runtime const int64_t (router shape selection), so passing them to MlasBlockwiseQuantizedShape/BufferSizes (int params) now triggers -Werror,-Wshorten-64-to-32 on macOS/arm64 and Android. Cast to int.
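
The "statically known M==1" gate from the first commit message can be sketched as a check over a shape whose dimensions may be unknown at graph-optimization time. This is a hedged illustration: `Dim` and `HasStaticMEqualOne` are hypothetical names, and treating M as the product of all leading dimensions of A (so that every leading dimension must be a known 1) is an assumption about how the real gate derives M.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// One dimension of A's shape; std::nullopt models a symbolic or dynamic dimension.
using Dim = std::optional<int64_t>;

// True only when A's last dimension is statically expected_k and every leading
// dimension is statically 1, so that M == 1 is provable before the graph runs.
inline bool HasStaticMEqualOne(const std::vector<Dim>& a_shape, int64_t expected_k) {
  assert(expected_k > 0);
  if (a_shape.size() < 2) return false;
  const Dim& k = a_shape.back();
  if (!k.has_value() || *k != expected_k) return false;
  for (size_t i = 0; i + 1 < a_shape.size(); ++i) {
    // A dynamic dimension might be 1 at runtime, but the fusion cannot rely on that.
    if (!a_shape[i].has_value() || *a_shape[i] != 1) return false;
  }
  return true;
}
```

Rejecting dynamic dimensions is deliberately conservative: a fused node that later sees M>1 has no Add to fall back on, which is the runtime failure the commit describes.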

Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.

Comment thread onnxruntime/test/contrib_ops/matmul_4bits_test.cc
Comment thread onnxruntime/test/contrib_ops/matmul_4bits_test.cc
Comment thread onnxruntime/test/optimizer/graph_transform_test.cc Outdated
Comment thread onnxruntime/test/optimizer/graph_transform_test.cc Outdated
Comment thread onnxruntime/test/optimizer/graph_transform_test.cc Outdated
Comment thread onnxruntime/test/optimizer/graph_transform_test.cc Outdated
Comment thread onnxruntime/test/optimizer/graph_transform_test.cc Outdated
Comment thread onnxruntime/test/optimizer/graph_transform_test.cc Outdated
…ests

- Qualify optional<std::string> as std::optional<std::string> in matmul_4bits_test.cc (no using namespace std there).
- Add explicit static_cast<int> for K/N passed to MlasBlockwiseQuantizedShape and MlasBlockwiseQuantizedBufferSizes to avoid int64->int narrowing warnings-as-errors.
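
As a side note on the narrowing fix, the commit uses a plain `static_cast<int>`. A checked variant is sketched below for comparison; `NarrowToInt` is a hypothetical helper, not something the PR adds, and it differs from the commit by asserting that the value fits before casting.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Converts int64_t to int, asserting in debug builds that no truncation occurs.
inline int NarrowToInt(int64_t value) {
  assert(value >= std::numeric_limits<int>::min());
  assert(value <= std::numeric_limits<int>::max());
  return static_cast<int>(value);  // explicit cast silences -Wshorten-64-to-32
}
```

For the test shapes in question (K=2880, N=32) the range check is trivially satisfied, so the plain cast used in the commit is equally safe there.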

Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

Comment thread onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc Outdated
@tianleiwu tianleiwu marked this pull request as ready for review June 20, 2026 23:50
@tianleiwu tianleiwu changed the title from "[CUDA] Fuse GPT-OSS router bias into MatMulNBits GEMV" to "[CUDA] Fuse MoE router bias into MatMulNBits GEMV" Jun 21, 2026
@tianleiwu tianleiwu requested a review from Copilot June 22, 2026 16:33

Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

Comment thread onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu
Comment thread onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu Outdated
…efactor unroll

- IsSupportedRouterGemvShape now matches only the GPT-OSS-20B router shape (M=1, N=32, K=2880), matching the PR's stated exact-shape scope; remove untested Qwen3/Gemma shapes.
- LaunchMatMulNBitsBiasAdd caps the grid at the CUDA gridDim.x limit and the kernel uses a grid-stride loop, avoiding silent truncation when casting to unsigned int.
- Replace the RouterUnRollReduction macro with a templated __device__ RouterUnrollReduction helper.
- Mark router dispatch dim3 blocks/threads const.
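
The grid cap plus grid-stride loop mentioned for LaunchMatMulNBitsBiasAdd can be emulated on the host to show why capping the grid does not drop elements. This is a CPU sketch, not the CUDA kernel: `CappedGridSize` and `BiasAddGridStride` are hypothetical names, the loops over `block` and `thread` stand in for `blockIdx.x` and `threadIdx.x`, and the grid limit is passed in as a parameter rather than queried from the device.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Blocks to launch: ceil(total / threads), clamped to max_grid_x so the later
// conversion to an unsigned 32-bit grid dimension cannot silently truncate.
inline uint32_t CappedGridSize(int64_t total, uint32_t threads, int64_t max_grid_x) {
  assert(total > 0 && threads > 0);
  assert(max_grid_x > 0 && max_grid_x <= static_cast<int64_t>(UINT32_MAX));
  const int64_t needed = (total + threads - 1) / threads;
  return static_cast<uint32_t>(std::min(needed, max_grid_x));
}

// Host emulation of a grid-stride bias add over a row-major [m, n] output,
// where bias has n elements and is broadcast across rows.
inline void BiasAddGridStride(std::vector<float>& out, const std::vector<float>& bias,
                              uint32_t grid, uint32_t threads) {
  assert(!bias.empty() && out.size() % bias.size() == 0);
  const int64_t total = static_cast<int64_t>(out.size());
  const int64_t n = static_cast<int64_t>(bias.size());
  const int64_t stride = static_cast<int64_t>(grid) * threads;
  for (uint32_t block = 0; block < grid; ++block) {
    for (uint32_t thread = 0; thread < threads; ++thread) {
      // Each emulated thread starts at its global index and advances by the
      // total thread count until it runs past the end of the output.
      for (int64_t i = static_cast<int64_t>(block) * threads + thread; i < total; i += stride) {
        out[static_cast<size_t>(i)] += bias[static_cast<size_t>(i % n)];
      }
    }
  }
}
```

When the capped grid is smaller than the number of blocks needed, each emulated thread simply handles more than one element, so every output is still updated exactly once.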

3 participants