[CUDA] Add sliding-window support to XQA decode #29177
Pull request overview
This PR extends the CUDA GroupQueryAttention contrib op’s non-quantized XQA decode path (FP16/BF16) to support sliding-window/local attention by threading local_window_size through the XQA launch stack, while keeping quantized XQA (INT8/FP8 KV cache) restricted to global attention. It also adds parity tests and a profiling script mode to validate/compare kernel selection and performance.
Changes:
- Enable non-quantized XQA decode when `local_window_size > 0`, and plumb `local_window_size` through the XQA loader and generated launch wrapper (guarded by `SLIDING_WINDOW`).
- Keep INT8/FP8 XQA global-only via an `is_global_attention` guard in dispatch.
- Add new sliding-window parity tests plus a `--compare-xqa` profiling mode and GPT-OSS preset; introduce/relocate CUDA GQA documentation.
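The dispatch rule described in the list above can be modeled in a few lines. This is only an illustrative sketch, not the code in `group_query_attention.cc`: the function `xqa_eligible` and its `kv_quantized` argument are hypothetical names, the real dispatch checks other conditions (head size, group size, decode shape) that are omitted here, and rejecting values other than `-1` or a positive window is an assumption based on the zero-window rejection test mentioned later in the conversation.

```python
def is_global_attention(local_window_size: int) -> bool:
    """Global attention is signalled by the sentinel value -1."""
    return local_window_size == -1


def xqa_eligible(local_window_size: int, kv_quantized: bool) -> bool:
    """Hypothetical model of the window-related part of XQA dispatch.

    Assumption: only -1 (global) or a positive window is valid input.
    """
    if local_window_size != -1 and local_window_size <= 0:
        raise ValueError("local_window_size must be -1 or a positive integer")
    if kv_quantized:
        # INT8/FP8 KV cache: XQA stays restricted to global attention.
        return is_global_attention(local_window_size)
    # FP16/BF16 KV cache: both global and sliding-window attention qualify.
    return True
```

Under this model a quantized KV cache with `local_window_size=128` falls back to another kernel, while the same window with an FP16/BF16 cache remains a candidate for XQA.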
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Allow non-quantized XQA decode selection for sliding-window attention; keep quantized XQA global-only. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Pass parameters.local_window_size into the XQA launch (ExtremeDecoding). |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h | Extend XQA launch API with local_window_size. |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu | Thread local_window_size through fp16 dispatcher/instantiations. |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh | Enable SLIDING_WINDOW and pass local_window_size down to generated group kernels (fp16). |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu | Update explicit instantiation signature for head_size=64 (fp16). |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu | Update explicit instantiation signature for head_size=128 (fp16). |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu | Update explicit instantiation signature for head_size=256 (fp16). |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu | Thread local_window_size through bf16 dispatcher/instantiations. |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh | Enable SLIDING_WINDOW and pass local_window_size down to generated group kernels (bf16). |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu | Update explicit instantiation signature for head_size=64 (bf16). |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu | Update explicit instantiation signature for head_size=128 (bf16). |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu | Update explicit instantiation signature for head_size=256 (bf16). |
| onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh | Map ORT local_window_size to XQA slidingWinSize under #if SLIDING_WINDOW. |
| onnxruntime/test/python/transformers/test_gqa.py | Add sliding-window XQA parity test matrix (FP16/BF16, head_size 64/128, groups 4/8, sink on/off). |
| onnxruntime/test/python/transformers/profile_gqa.sh | Add --gpt-oss preset and --compare-xqa (profiles ORT_ENABLE_XQA=0 vs 1). |
| docs/contrib_ops/gqa.md | Remove/move prior top-level GQA doc (per #29173). |
| docs/contrib_ops/cuda/gqa.md | Add CUDA-specific GQA documentation (currently contains statements inconsistent with this PR’s new XQA sliding-window support). |
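The parity test matrix listed for `test_gqa.py` is a Cartesian product of a few axes, which might be generated roughly as follows. This is a sketch under stated assumptions, not the PR's `gqa_xqa_sliding_window_test_cases()` generator: the dictionary keys and the two `PAST_VS_WINDOW` labels are hypothetical, and the real generator may be structured differently.

```python
from itertools import product

DTYPES = ("float16", "bfloat16")
HEAD_SIZES = (64, 128)
GROUP_SIZES = (4, 8)  # query heads per KV head
HEAD_SINK = (False, True)
# Assumption: two past-length/window relationships; the labels are illustrative.
PAST_VS_WINDOW = ("past_shorter_than_window", "past_longer_than_window")


def sliding_window_cases():
    """Yield one config dict per combination of the axes above."""
    for dtype, head_size, group, relation, sink in product(
        DTYPES, HEAD_SIZES, GROUP_SIZES, PAST_VS_WINDOW, HEAD_SINK
    ):
        yield {
            "dtype": dtype,
            "head_size": head_size,
            "group_size": group,
            "past_vs_window": relation,
            "head_sink": sink,
        }


cases = list(sliding_window_cases())
```

With two past/window relationships the product has 32 entries, which is consistent with the 32 cases reported in the PR's testing note, although that match does not confirm how the real matrix is built.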
Multi-reviewer synthesis — PR #29177
    cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
    }

    #if SLIDING_WINDOW
Will this be removed once sliding window is supported for the quantized path?
Thanks for the thorough multi-agent synthesis. Addressed in Major
Minor
Readability
All affected GQA sliding-window parity tests (non-quantized + quantized, including the new boundary cases) and the zero-window rejection regression test pass locally.
Description
The XQA decode kernel previously fell back to FlashDecode whenever a local
(sliding) attention window was configured, so GPT-OSS, Mistral, and Gemma2 style
models could not use the faster XQA path on their sliding-window layers. This PR
wires `local_window_size` through the fp16/bf16 XQA kernels so they serve both
global and sliding-window attention, and adds parity tests that confirm the new
path is exercised.
Summary of Changes
Sliding-window XQA kernel
- `onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc`: Remove the `local_window_size == -1` gate for the XQA path; keep INT8/FP8 variants global-only via a new `is_global_attention` guard.
- `onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu`: Pass `parameters.local_window_size` into `ExtremeDecoding`.
- `onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh`: Map `local_window_size` (`-1` → `max_seq_len`, else the value) to XQA `slidingWinSize`, guarded by `#if SLIDING_WINDOW`.
- `onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h`, `xqa_loader_fp16*.{cu,cuh}`, `xqa_loader_bf16*.{cu,cuh}`: Thread the `local_window_size` parameter through the launch path; enable `#define SLIDING_WINDOW 1` in the fp16/bf16 impl headers.

Global attention (`local_window_size == -1`) maps to a window `>= max_seq_len`, so
the kernel's runtime masking guard is never taken. The result is numerically
identical to the previous global-only behavior, with zero added overhead.
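The window mapping and its effect on a single decode step can be sketched in Python. This is a minimal model, not the generated CUDA wrapper: `effective_window` and `visible_key_range` are hypothetical helper names, and the sketch assumes the window counts the current token, since the exact off-by-one convention is not stated in this PR description.

```python
def effective_window(local_window_size: int, max_seq_len: int) -> int:
    """Map the ORT attribute to a sliding-window size: -1 means 'no window'."""
    return max_seq_len if local_window_size == -1 else local_window_size


def visible_key_range(total_seq_len: int, window: int) -> range:
    """Key positions visible to the decode query at position total_seq_len - 1.

    Assumption: the window includes the current token, so at most
    `window` keys are attended.
    """
    start = max(0, total_seq_len - window)
    return range(start, total_seq_len)


# Global attention: the window covers the whole cache, so nothing is masked.
global_keys = visible_key_range(1000, effective_window(-1, 4096))
# Sliding window of 128: only the most recent 128 positions stay visible.
local_keys = visible_key_range(1000, effective_window(128, 4096))
```

Because `total_seq_len` never exceeds `max_seq_len`, the global case always starts at position 0, which is why the masking branch is never taken on that path.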
Tests and profiling
- `onnxruntime/test/python/transformers/test_gqa.py`: new `TestXQASlidingWindowParity` class and `gqa_xqa_sliding_window_test_cases()` generator (fp16/bf16 × head_size {64, 128} × group {4, 8} × past/window relationships × with/without head_sink), forcing `ORT_ENABLE_XQA=1` and checking parity against the reference.
- `onnxruntime/test/python/transformers/profile_gqa.sh`: add a `--gpt-oss` preset and a `--compare-xqa` mode that profiles XQA vs FlashDecode for the same shape.
Documentation
- `docs/contrib_ops/cuda/gqa.md` (new) replaces `docs/contrib_ops/gqa.md`, documenting the CUDA GroupQueryAttention backends and dispatch.
Testing
- `cd onnxruntime/test/python/transformers && PYTHONPATH=<build_dir> python test_gqa.py TestXQASlidingWindowParity`: all 32 cases pass on H200 (SM90).
- Kernel selection verified with `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1` (`SdpaKernel=XQA`) and an `nsys` trace showing `H64::grp4_fp16::kernel_mha` launches instead of `flash_fwd_splitkv_kernel`.
Motivation and Context
GPT-OSS-20B has 12 sliding-window layers (`local_window_size=128`, head_sink, fp16,
64 q / 8 kv heads, head_size 64). On H200 single-token decode the XQA kernel is
~2.2× faster than FlashDecode on these shapes, so enabling XQA for the
sliding-window layers improves end-to-end decode latency.