
[CUDA] Enable CUDA GQA QK-Norm and XQA decode#29186

Open
tianleiwu wants to merge 10 commits into main from tlwu/rmsnorm_gqa


tianleiwu commented Jun 20, 2026

Contributor

Description

Adds CUDA support for GroupQueryAttention QK-Norm by applying per-head Q/K RMSNorm before RoPE in the fused preprocess path. It also enables the pre-norm graph fusion for CUDA and allows non-quantized QK-Norm decode to use XQA, restoring the fast global decode path for GPT-OSS/Qwen-style shapes while keeping quantized-cache QK-Norm on the existing fallback path until scale handling is validated.

Summary of Changes

CUDA GroupQueryAttention

  • Threads q_norm_weight / k_norm_weight and qk_norm_epsilon through CUDA GQA data/parameters.
  • Applies FP32 per-head RMSNorm to Q/K in UnpackRoPEAppend before RoPE and KV append.
  • Adds shared-KV Q-only normalization support.
  • Enables non-quantized QK-Norm decode to route through XQA after the fused preprocess normalizes Q/K.
  • Keeps quantized-cache QK-Norm decode gated off XQA pending normalized-K scale validation.
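The ordering described in the list above (per-head RMSNorm first, then RoPE) can be sketched in plain Python. This is a minimal reference under stated assumptions, not the CUDA kernel: the helper names `rms_norm_head`, `rope_head`, and `preprocess_head` are hypothetical, the rotate-half RoPE layout and the `base=10000.0` frequency are assumed conventions, and only the `1e-6` epsilon default comes from this PR's discussion.

```python
import math

def rms_norm_head(x, weight, epsilon=1e-6):
    """RMSNorm over one head vector: x * rsqrt(mean(x^2) + epsilon) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def rope_head(x, position, base=10000.0):
    """Rotate-half RoPE over one head vector (assumed layout, for illustration)."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i in range(half):
        theta = position * base ** (-i / half)
        c, s = math.cos(theta), math.sin(theta)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out

def preprocess_head(x, weight, position, epsilon=1e-6):
    """Normalize first, then rotate: the order the fused preprocess path uses."""
    return rope_head(rms_norm_head(x, weight, epsilon), position)

# Hypothetical 4-element head with unit norm weights.
q_head = [3.0, 4.0, 0.0, 0.0]
q_weight = [1.0, 1.0, 1.0, 1.0]
q_ready = preprocess_head(q_head, q_weight, position=7)
```

Because RoPE is a rotation, it leaves the vector's length unchanged, so normalizing before rotating gives the attention kernel inputs whose per-head scale is set entirely by the norm weights.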

Fusion and Schemas

  • Enables GroupQueryAttentionPreNormFusion for CUDA and native WebGPU.
  • Updates contrib operator schema text and generated ContribOperators.md to document CUDA/native WebGPU QK-Norm support.
  • Updates CPU/JSEP rejection text for unsupported providers.

Tests, Docs, and Profiling

  • Adds CUDA optimizer coverage for the pre-norm fusion.
  • Adds Python GQA QK-Norm parity coverage, including explicit FP16/BF16 XQA decode tests.
  • Extends GQA profiling helpers with QK-Norm options and documents CUDA GQA behavior in docs/contrib_ops/cuda/gqa.md.

Testing

  • Built: ninja onnxruntime_providers_cuda onnxruntime_test_all in build/cu130/Release.
  • Ran: ./onnxruntime_test_all --gtest_filter="GraphTransformationTests.GroupQueryAttentionPreNormFusion*" (11 passed, 2 WebGPU skips).
  • Ran: python -m pytest test_gqa.py::TestGQAQKNorm::test_gqa_qk_norm_past_xqa test_gqa.py::TestGQAQKNorm::test_gqa_qk_norm_past_xqa_bf16 -q (2 passed).
  • Ran: python -m pytest test_gqa.py -k "QKNorm" -q (38 passed).
  • Ran: git diff --check.
  • Verified routing with ORT_ENABLE_XQA=1 ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1: FP16 and BF16 QK-Norm decode report SdpaKernel=XQA.
  • Profiled GPT-OSS-like packed FP16 shape (B=1,S=1,past=2048,N=64,Nkv=8,H=64,head_sink,QK-Norm) with nsys: H64::grp8_fp16::kernel_mha averaged ~8.21 us and UnpackRoPEAppend<half, half, 16, 64> averaged ~2.94 us.

Checklist

  • Tests added/updated
  • Documentation updated
  • No breaking changes
  • CI passes

@tianleiwu tianleiwu changed the title Enable CUDA GQA QK-Norm and XQA decode [CUDA] Enable CUDA GQA QK-Norm and XQA decode Jun 20, 2026

Copilot AI left a comment

Contributor


Pull request overview

This PR expands com.microsoft::GroupQueryAttention on the CUDA EP to support fused per-head Q/K RMSNorm (QK-Norm) in the preprocess path (before RoPE), and restores the fast XQA decode route for non-quantized QK-Norm decode shapes. It also enables the pre-norm fusion pass for CUDA (previously WebGPU-only), updates operator/schema docs, and adds test/profiling coverage for the new routing and parity behavior.

Changes:

  • Add CUDA QK-Norm plumbing and kernels (fused in UnpackRoPEAppend, plus a standalone Q-only RMSNorm path for shared-KV decode).
  • Enable GroupQueryAttentionPreNormFusion for CUDA and add optimizer + Python parity tests (incl. explicit XQA decode checks for FP16/BF16).
  • Update profiling helpers and move/extend CUDA GQA documentation.

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| onnxruntime/test/python/transformers/test_gqa.py | Adds QK-Norm config knobs, reference RMSNorm, QK-Norm parity tests, and XQA decode parity coverage. |
| onnxruntime/test/python/transformers/profile_gqa.sh | Extends CLI to toggle QK-Norm and improves nsys handling + compare mode. |
| onnxruntime/test/python/transformers/profile_gqa.py | Threads QK-Norm args through config and NVTX ranges for profiling. |
| onnxruntime/test/python/transformers/gqa_test_helper.py | Adds QK-Norm inputs/attrs to the helper model/config and random feed generation. |
| onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc | Expands fusion tests to cover CUDA-compatible fusion registration and CUDA EP assignment. |
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Updates CPU/CUDA contract tests for QK-Norm weight inputs. |
| onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h | Updates fusion docs to reflect CUDA+WebGPU support. |
| onnxruntime/core/optimizer/graph_transformer_utils.cc | Registers the pre-norm fusion transformer for CUDA + WebGPU. |
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Updates schema text to document CUDA+native WebGPU honoring QK-Norm weights. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.h | Adds CUDA kernel member for qk_norm_epsilon_. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Accepts/validates QK-Norm weights, threads epsilon/flags, adjusts XQA and flash-decode routing gates. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh | Implements fused per-head RMSNorm in UnpackRoPEAppend and threads weights/epsilon through launch chain. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h | Updates buffer sizing requirements when QK-Norm requires a materialized Q scratch buffer. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Adds standalone per-head RMSNorm kernel for Q-only shared-KV decode and integrates QK-Norm into PrepareQKV/preprocess calls. |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Adds QK-Norm weight pointers + epsilon to GroupQueryAttentionData. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Updates rejection text to reflect CUDA+WebGPU support. |
| onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | Extends GroupQueryAttentionParameters with use_qk_norm + qk_norm_epsilon. |
| docs/ContribOperators.md | Updates generated schema docs to list CUDA+native WebGPU support for QK-Norm. |
| docs/contrib_ops/gqa.md | Removes the old (top-level) GQA doc in favor of a CUDA-specific doc path. |
| docs/contrib_ops/cuda/gqa.md | Adds CUDA-specific GQA documentation including QK-Norm behavior, dispatch rules, profiling, and testing. |

Comment thread onnxruntime/test/python/transformers/test_gqa.py Outdated
Comment thread onnxruntime/test/python/transformers/test_gqa.py Outdated
Comment thread docs/contrib_ops/cuda/gqa.md
Comment thread docs/contrib_ops/cuda/gqa.md Outdated
Comment thread docs/contrib_ops/cuda/gqa.md Outdated
Comment thread docs/contrib_ops/cuda/gqa.md

Copilot AI left a comment

Contributor


Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated 3 comments.

Comment thread onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh Outdated
Comment thread onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc Outdated
Comment thread onnxruntime/test/python/transformers/test_gqa.py Outdated
- group_query_attention_qkv.cuh: reduce QK-Norm sum once in tid==0 and broadcast inv_rms via shared memory to avoid redundant O(blockDim.x^2) shared reads.
- pre_norm_fusion_test.cc: comment now says CUDA+WebGPU fusion path (test runs both).
- test_gqa.py: TestGQAQKNorm now gated on has_cuda_device(80) with an accurate skip message instead of the misleading Flash Attention check.

yuslepukhin commented Jun 22, 2026

Member

Testing only in Release builds is insufficient because Release runs with debug diagnostics disabled.

Non-blockers. Missing negative tests:

  • No test for qk_norm_epsilon <= 0 (could produce NaN via rsqrtf). The attribute is read with GetAttrOrDefault("qk_norm_epsilon", 1e-6f) — if a model explicitly sets a negative epsilon, no validation rejects it.
  • No test for mismatched q_norm_weight/k_norm_weight shapes (for example, one is (head_size,) and the other is (head_size+1,)), though this is covered implicitly by the shape validation returning an error.
  • No linked GitHub issue for tracking/discoverability.
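The epsilon concern in the first bullet can be reproduced outside the kernel. The sketch below is an illustration only: `rsqrt_ieee` is a hypothetical, simplified emulation of IEEE-style reciprocal square root written for this example (Python's own `math.sqrt` raises on negative input instead of returning NaN), and `inv_rms` is a hypothetical helper that mirrors the shape of `rsqrt(sumsq/head_size + epsilon)`.

```python
import math

def rsqrt_ieee(x):
    """Simplified IEEE-style rsqrt: NaN for negative or NaN input, +inf for zero."""
    if math.isnan(x) or x < 0.0:
        return math.nan
    if x == 0.0:
        return math.inf
    return 1.0 / math.sqrt(x)

def inv_rms(row, epsilon):
    """rsqrt(mean(x^2) + epsilon) with no validation of epsilon."""
    mean_sq = sum(v * v for v in row) / len(row)
    return rsqrt_ieee(mean_sq + epsilon)

zero_row = [0.0, 0.0, 0.0, 0.0]
small_row = [1e-3, 0.0, 0.0, 0.0]

# epsilon == 0 on an all-zero row: 0 * inf
zero_eps = [v * inv_rms(zero_row, 0.0) for v in zero_row]
# epsilon < 0 large enough to make the radicand negative
neg_eps = [v * inv_rms(small_row, -1e-3) for v in small_row]
# a positive epsilon keeps the all-zero row finite
good_eps = [v * inv_rms(zero_row, 1e-6) for v in zero_row]
```

A negative test along these lines would feed the operator an explicit non-positive epsilon and assert either a rejection at kernel construction or, at minimum, the absence of NaN in the output.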

@titaiwangms

Contributor

Multi-reviewer synthesis — PR #29186 [CUDA] Enable CUDA GQA QK-Norm and XQA decode

Reviewed by 5 specialized agents (readability, code, critical/adversarial, deep-semantic, integration). Net: solid, faithful implementation — the CUDA RMSNorm math matches the in-repo Python reference and the schema/WebGPU semantics, the multi-file invariants (schema attr, doc, attention_parameters.h, attention_data.h, kernel) all agree, and the XQA gating correctly keeps quantized-cache QK-Norm off the XQA path. No Critical defects. A couple of Major hardening/correctness gaps are worth addressing before merge.

Major

  1. qk_norm_epsilon is accepted without finite/positive validation. group_query_attention.cc reads qk_norm_epsilon_ = info.GetAttrOrDefault<float>("qk_norm_epsilon", 1e-6f) with no guard, then it flows directly into rsqrtf(sumsq/head_size + epsilon) in both the fused (group_query_attention_qkv.cuh) and standalone (group_query_attention_impl.cu) kernels. Adversarial/malformed model attributes fail open: epsilon == 0 on an all-zero Q/K row gives rsqrt(0)=inf → 0*inf=NaN; epsilon < 0 can make the radicand negative (NaN); NaN/Inf propagates. In GroupQueryAttentionPreNormFusion, fabs(q_eps - k_eps) > tol does not reject a NaN epsilon (NaN compares false), so it fuses silently. Suggest validating std::isfinite(qk_norm_epsilon_) && qk_norm_epsilon_ > 0 at kernel init (and rejecting non-finite epsilons in the fusion before fabs). (code + critical reviewers agree; confirmed no validation exists.)

  2. DML EP silently ignores q_norm_weight/k_norm_weight. The schema text this PR updates now mandates that EPs "must reject the node when this input is set." CPU and JSEP fail-fast; DmlOperatorGroupQueryAttention.cpp:34 only checks GetInputCount() >= 1 and never inspects inputs 14/15, so a model carrying QK-Norm weights run on (or falling back to) DML would execute without normalization and return wrong results, no warning. The new pre-norm fusion only targets {kCuda, kWebGpu}, so it won't create such a node for DML — this is a latent/pre-existing gap — but since this PR formalizes the contract, adding an explicit reject guard in the DML constructor is the right closure. (integration reviewer; verified.)
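The NaN pitfall from item 1 is easy to demonstrate. The following sketch assumes a tolerance comparison shaped like `fabs(q_eps - k_eps) > tol`; the function names and the `tol` value are hypothetical and chosen only for illustration, not taken from the fusion source.

```python
import math

def eps_mismatch_naive(q_eps, k_eps, tol=1e-12):
    """Mirrors the shape of `fabs(q_eps - k_eps) > tol`; NaN compares false."""
    return math.fabs(q_eps - k_eps) > tol

def is_valid_epsilon(eps):
    """The guard suggested in item 1: finite and strictly positive."""
    return math.isfinite(eps) and eps > 0.0

def can_fuse(q_eps, k_eps, tol=1e-12):
    """Reject invalid epsilons before the tolerance comparison."""
    if not (is_valid_epsilon(q_eps) and is_valid_epsilon(k_eps)):
        return False
    return not eps_mismatch_naive(q_eps, k_eps, tol)

# A NaN epsilon is not reported as a mismatch, so a naive check would fuse.
naive_reports_mismatch = eps_mismatch_naive(math.nan, 1e-6)
guarded_allows_fusion = can_fuse(math.nan, 1e-6)
```

Checking validity first means the tolerance comparison only ever sees finite, positive values, so its result can be trusted.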

Minor

  • JSEP rejection message is now stale. js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts:338 still says the prologue is "implemented only on the native WebGPU EP." CUDA now implements it too — update to "CUDA and native WebGPU EPs" to match the CPU EP text this PR already fixed.
  • Shared-KV standalone RMSNorm kernel is unexercised by parity tests. PerHeadRMSNormBSNHKernel / LaunchPerHeadRMSNorm only run on the kv_sequence_length == 0 (shared-KV, Q-only) path, but every config in gqa_qk_norm_test_cases uses kv_sequence_length ∈ {1,64}. This kernel uses a different (tree vs linear) reduction than UnpackRoPEAppend and ships without numerical coverage. Add a parity config with kv_sequence_length=0 + has_qk_norm=True.
  • New CUDA input-contract guards are untested. No negative tests for: only one of q/k norm present, wrong norm-weight shape, wrong dtype. The op test was converted to an accept-success smoke test; add OpTester negatives asserting the exact error strings.
  • Quantized + QK-Norm → off-XQA gating is implemented but not directly asserted. Add a decode test with has_qk_norm=True + quantized KV + ORT_ENABLE_XQA=1 and assert the kernel is not XQA (debug signal / parity), so the gate can't silently regress.
  • Readability: the refactor folded the named is_xqa_smooth_softmax_supported booleans into one compound if, making the gate harder to parse — consider restoring named xqa_smooth_softmax_ok / xqa_qk_norm_ok. The explicit parameters.use_qk_norm = false; reset is redundant (struct already default-inits false) and invites doubt.
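The named-boolean suggestion in the readability bullet, together with the quantized-cache gate from the bullet before it, might look like the sketch below. It is a deliberately reduced model: `can_use_xqa` and its parameters are hypothetical, the smooth-softmax condition is passed in as an opaque flag because its real definition is not given in this thread, and the actual gate in group_query_attention.cc has further conditions that are not modeled here.

```python
def can_use_xqa(xqa_enabled, xqa_smooth_softmax_ok, use_qk_norm, kv_cache_quantized):
    """Reduced model of the XQA routing gate with one named boolean per concern."""
    # QK-Norm may use XQA only when the KV cache is not quantized.
    xqa_qk_norm_ok = (not use_qk_norm) or (not kv_cache_quantized)
    return xqa_enabled and xqa_smooth_softmax_ok and xqa_qk_norm_ok

# The regression test proposed above would pin this combination to "not XQA".
quantized_qk_norm_routes_to_xqa = can_use_xqa(
    xqa_enabled=True,
    xqa_smooth_softmax_ok=True,
    use_qk_norm=True,
    kv_cache_quantized=True,
)
```

Keeping each concern in its own named variable lets a reader, and a test, point at the exact clause that turned XQA off.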

Open question

  • Cross-EP precision convention differs. WebGPU (rotary_embedding.cc:159/180) casts inv_rms to element type T before the x*inv_rms*w multiply (fp16/bf16 multiply), whereas CUDA and the Python reference do the entire multiply in fp32. Both are valid RMSNorm and agree within fp16 tolerance, but they are not bit-identical across EPs. Please confirm that's acceptable and that the gqa.md §3 prose ("reduced in FP32 … cast back to T") isn't read as implying byte-parity between EPs.
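The two conventions in the open question can be compared numerically with a small emulation. This is a sketch under assumptions: fp16 and fp32 rounding are emulated with `struct` round-trips, the "fp32" multiply is approximated in Python double precision, and the function names are hypothetical labels for the two styles rather than code from either EP.

```python
import math
import struct

def to_f16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def to_f32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def inv_rms_f32(x, epsilon):
    mean_sq = sum(v * v for v in x) / len(x)
    return to_f32(1.0 / math.sqrt(mean_sq + epsilon))

def norm_multiply_wide(x, w, epsilon=1e-6):
    """Multiply in wide precision, cast to fp16 once at the end."""
    r = inv_rms_f32(x, epsilon)
    return [to_f16(v * r * wi) for v, wi in zip(x, w)]

def norm_multiply_f16(x, w, epsilon=1e-6):
    """Cast inv_rms to fp16 first, then round after every multiply."""
    r = to_f16(inv_rms_f32(x, epsilon))
    return [to_f16(to_f16(v * r) * wi) for v, wi in zip(x, w)]

# Hypothetical head values and weights, all exactly representable in fp16.
x = [0.5, -1.25, 2.0, 0.75]
w = [1.0, 0.5, 1.5, 2.0]
wide = norm_multiply_wide(x, w)
narrow = norm_multiply_f16(x, w)
max_rel_diff = max(abs(a - b) / max(abs(a), abs(b)) for a, b in zip(wide, narrow))
```

The narrow path rounds three times where the wide path rounds once, so the results can differ by a few fp16 units in the last place while still agreeing within a typical fp16 parity tolerance.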

Nits

  • Fractional // 1.5 QK-Norm step number breaks the kernel's 1/2/3 phase numbering.
  • Inconsistent shared-mem buffer names (s_sumsq vs s_qk_reduce) for the same concept.
  • 3-way nested ternary for norm_weight selection.
  • SetOutputTolerance(1e6f) smoke-test value looks like a 1e-6f typo and needs a comment stating intent.
  • The idx % len(packed_opts) alternation trick in the test generator deserves a one-line comment.

Praise

fp32-accumulate + single-final-cast UnpackRoPEAppend faithfully reproduces the reference RMSNorm rather than a lossy fp16 reduction; correct norm-before-RoPE ordering; deliberate Flash-Decode disable under QK-Norm so the prologue is never bypassed; XQA quantized-QK-Norm fail-safe; clean reuse of GroupQueryAttentionPreNormFusion (just added kCuda to the compatible set); flawless threading of qk_norm_epsilon/q_norm_weight/k_norm_weight across the float/half/bf16 paths. Documentation (gqa.md, the LaTeX RMSNorm formula) is excellent.

🤖 Synthesized from a 5-agent review pipeline (readability · code · critical · deep · integration). Findings were verified against the diff/source where load-bearing.


5 participants