[CUDA] Enable CUDA GQA QK-Norm and XQA decode #29186
Pull request overview
This PR expands com.microsoft::GroupQueryAttention on the CUDA EP to support fused per-head Q/K RMSNorm (QK-Norm) in the preprocess path (before RoPE), and restores the fast XQA decode route for non-quantized QK-Norm decode shapes. It also enables the pre-norm fusion pass for CUDA (previously WebGPU-only), updates operator/schema docs, and adds test/profiling coverage for the new routing and parity behavior.
Changes:
- Add CUDA QK-Norm plumbing and kernels (fused in UnpackRoPEAppend, plus a standalone Q-only RMSNorm path for shared-KV decode).
- Enable GroupQueryAttentionPreNormFusion for CUDA and add optimizer + Python parity tests (incl. explicit XQA decode checks for FP16/BF16).
- Update profiling helpers and move/extend CUDA GQA documentation.
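For readers unfamiliar with QK-Norm, the per-head RMSNorm that this PR applies to Q and K before RoPE can be sketched in plain Python. This is a reference-style illustration only, not the CUDA kernel: the function names, the `epsilon` default of `1e-6`, and the assumption that a single weight vector of length `head_size` is shared across all heads are hypothetical choices made for the example.

```python
import math


def rms_norm_head(head, weight, epsilon=1e-6):
    """Normalize one head vector by its root mean square, then scale by weight."""
    mean_sq = sum(v * v for v in head) / len(head)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * w for v, w in zip(head, weight)]


def per_head_rms_norm(heads, weight, epsilon=1e-6):
    """Apply RMSNorm independently to every head (assumed shared weight vector)."""
    return [rms_norm_head(head, weight, epsilon) for head in heads]


# Two toy heads with head_size = 2 and unit weights (illustrative values only).
q_heads = [[3.0, 4.0], [1.0, 1.0]]
unit_weight = [1.0, 1.0]
q_normed = per_head_rms_norm(q_heads, unit_weight)
```

With unit weights, each normalized head has a root mean square of approximately 1, which is the property a parity test would check before RoPE is applied.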
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| onnxruntime/test/python/transformers/test_gqa.py | Adds QK-Norm config knobs, reference RMSNorm, QK-Norm parity tests, and XQA decode parity coverage. |
| onnxruntime/test/python/transformers/profile_gqa.sh | Extends CLI to toggle QK-Norm and improves nsys handling + compare mode. |
| onnxruntime/test/python/transformers/profile_gqa.py | Threads QK-Norm args through config and NVTX ranges for profiling. |
| onnxruntime/test/python/transformers/gqa_test_helper.py | Adds QK-Norm inputs/attrs to the helper model/config and random feed generation. |
| onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc | Expands fusion tests to cover CUDA-compatible fusion registration and CUDA EP assignment. |
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Updates CPU/CUDA contract tests for QK-Norm weight inputs. |
| onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h | Updates fusion docs to reflect CUDA+WebGPU support. |
| onnxruntime/core/optimizer/graph_transformer_utils.cc | Registers the pre-norm fusion transformer for CUDA + WebGPU. |
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Updates schema text to document CUDA+native WebGPU honoring QK-Norm weights. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.h | Adds CUDA kernel member for qk_norm_epsilon_. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Accepts/validates QK-Norm weights, threads epsilon/flags, adjusts XQA and flash-decode routing gates. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh | Implements fused per-head RMSNorm in UnpackRoPEAppend and threads weights/epsilon through launch chain. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h | Updates buffer sizing requirements when QK-Norm requires a materialized Q scratch buffer. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Adds standalone per-head RMSNorm kernel for Q-only shared-KV decode and integrates QK-Norm into PrepareQKV/preprocess calls. |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Adds QK-Norm weight pointers + epsilon to GroupQueryAttentionData. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Updates rejection text to reflect CUDA+WebGPU support. |
| onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | Extends GroupQueryAttentionParameters with use_qk_norm + qk_norm_epsilon. |
| docs/ContribOperators.md | Updates generated schema docs to list CUDA+native WebGPU support for QK-Norm. |
| docs/contrib_ops/gqa.md | Removes the old (top-level) GQA doc in favor of a CUDA-specific doc path. |
| docs/contrib_ops/cuda/gqa.md | Adds CUDA-specific GQA documentation including QK-Norm behavior, dispatch rules, profiling, and testing. |
- group_query_attention_qkv.cuh: reduce QK-Norm sum once in tid==0 and broadcast inv_rms via shared memory to avoid redundant O(blockDim.x^2) shared reads.
- pre_norm_fusion_test.cc: comment now says CUDA+WebGPU fusion path (test runs both).
- test_gqa.py: TestGQAQKNorm now gated on has_cuda_device(80) with an accurate skip message instead of the misleading Flash Attention check.
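The shared-read saving from the first item can be illustrated with a small host-side model. This is a sketch under stated assumptions, not the kernel code: the function names are hypothetical, `partials` stands in for the per-thread sums of squares held in shared memory, and the loops simulate threads sequentially. In a real kernel a block-level barrier such as `__syncthreads()` would be needed between the write of `inv_rms` and the reads by the other threads.

```python
import math


def inv_rms_redundant(partials, head_size, epsilon):
    """Every simulated thread re-sums all partials: n threads x n shared reads."""
    reads = 0
    results = []
    for _tid in range(len(partials)):
        total = 0.0
        for p in partials:
            total += p
            reads += 1
        results.append(1.0 / math.sqrt(total / head_size + epsilon))
    return results, reads


def inv_rms_broadcast(partials, head_size, epsilon):
    """Thread 0 sums once, stores inv_rms, and every thread reads that one value."""
    reads = 0
    total = 0.0
    for p in partials:  # work done by tid == 0 only
        total += p
        reads += 1
    shared_inv_rms = 1.0 / math.sqrt(total / head_size + epsilon)
    results = []
    for _tid in range(len(partials)):  # each thread reads the broadcast value
        results.append(shared_inv_rms)
        reads += 1
    return results, reads


# 16 simulated threads, each holding a partial sum of squares of 1.0.
partials = [1.0] * 16
slow_values, slow_reads = inv_rms_redundant(partials, head_size=64, epsilon=1e-6)
fast_values, fast_reads = inv_rms_broadcast(partials, head_size=64, epsilon=1e-6)
```

Both variants produce the same `inv_rms` per thread; only the number of shared-memory reads differs (quadratic versus linear in the number of threads).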
Testing in Release builds is insufficient because they run with debug diagnostics disabled. Non-blockers. Missing negative tests:
Multi-reviewer synthesis — PR #29186
Description
Adds CUDA support for GroupQueryAttention QK-Norm by applying per-head Q/K RMSNorm before RoPE in the fused preprocess path. It also enables the pre-norm graph fusion for CUDA and allows non-quantized QK-Norm decode to use XQA, restoring the fast global decode path for GPT-OSS/Qwen-style shapes while keeping quantized-cache QK-Norm on the existing fallback path until scale handling is validated.
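The decode routing described above can be summarized as a small predicate. This is a deliberately simplified sketch: `qk_norm_decode_route` and its parameters are hypothetical names, it covers only QK-Norm decode, and the actual gates in group_query_attention.cc are assumed to check further conditions that are not modeled here.

```python
def qk_norm_decode_route(kv_cache_quantized: bool, xqa_enabled: bool = True) -> str:
    """Pick a decode path for a QK-Norm decode step (simplified, hypothetical gate).

    Non-quantized QK-Norm decode may take the XQA path; a quantized KV cache
    stays on the existing fallback path until scale handling is validated.
    """
    if xqa_enabled and not kv_cache_quantized:
        return "XQA"
    return "fallback"


# FP16/BF16 cache with XQA enabled versus a quantized cache.
route_fp16 = qk_norm_decode_route(kv_cache_quantized=False)
route_quantized = qk_norm_decode_route(kv_cache_quantized=True)
```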
Summary of Changes
CUDA GroupQueryAttention
Fusion and Schemas
Tests, Docs, and Profiling
Testing
- ninja onnxruntime_providers_cuda onnxruntime_test_all in build/cu130/Release.
- ./onnxruntime_test_all --gtest_filter="GraphTransformationTests.GroupQueryAttentionPreNormFusion*" (11 passed, 2 WebGPU skips).
- python -m pytest test_gqa.py::TestGQAQKNorm::test_gqa_qk_norm_past_xqa test_gqa.py::TestGQAQKNorm::test_gqa_qk_norm_past_xqa_bf16 -q (2 passed).
- python -m pytest test_gqa.py -k "QKNorm" -q (38 passed).
- git diff --check.
- ORT_ENABLE_XQA=1 ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1: FP16 and BF16 QK-Norm decode report SdpaKernel=XQA.
- Profiled (B=1,S=1,past=2048,N=64,Nkv=8,H=64,head_sink,QK-Norm) with nsys: H64::grp8_fp16::kernel_mha averaged ~8.21 us and UnpackRoPEAppend<half, half, 16, 64> averaged ~2.94 us.
Checklist