Skip to content

[CUDA] Add sliding-window support to XQA decode#29177

Open
tianleiwu wants to merge 16 commits into
mainfrom
tlwu/update_xqa
Open

[CUDA] Add sliding-window support to XQA decode#29177
tianleiwu wants to merge 16 commits into
mainfrom
tlwu/update_xqa

Conversation

@tianleiwu

@tianleiwu tianleiwu commented Jun 20, 2026

Copy link
Copy Markdown
Contributor

Description

The XQA decode kernel previously fell back to FlashDecode whenever a local
(sliding) attention window was configured, so GPT-OSS / Mistral / Gemma2 style
models could not use the faster XQA path on their sliding-window layers. This PR
wires local_window_size through the fp16/bf16 XQA kernels so they
serve both global and sliding-window attention, and adds parity tests that confirm
the new path is exercised.

Summary of Changes

Sliding-window XQA kernel

File Change
onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc Drop the local_window_size == -1 gate for XQA path; keep INT8/FP8 variants global-only via a new is_global_attention guard.
onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu Pass parameters.local_window_size into ExtremeDecoding.
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh Map ORT local_window_size (-1max_seq_len, else the value) to XQA slidingWinSize, guarded by #if SLIDING_WINDOW.
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h, xqa_loader_fp16*.{cu,cuh}, xqa_loader_bf16*.{cu,cuh} Thread a new local_window_size parameter through the launch path; enable #define SLIDING_WINDOW 1 in the fp16/bf16 impl headers.

Global attention (local_window_size == -1) maps to a window >= max_seq_len, so
the kernel's runtime masking guard is never taken — numerically identical to the
previous global-only behavior with zero added overhead.

Tests and profiling

  • onnxruntime/test/python/transformers/test_gqa.py: new TestXQASlidingWindowParity class and gqa_xqa_sliding_window_test_cases() generator (fp16/bf16 × head_size {64, 128} × group {4, 8} × past/window relationships × with/without head_sink), forcing ORT_ENABLE_XQA=1 and checking parity against the reference.
  • onnxruntime/test/python/transformers/profile_gqa.sh: add a --gpt-oss preset and a --compare-xqa mode that profiles XQA vs FlashDecode for the same shape.

Documentation

  • docs/contrib_ops/cuda/gqa.md (new) replaces docs/contrib_ops/gqa.md, documenting the CUDA GroupQueryAttention backends and dispatch.

Testing

  • cd onnxruntime/test/python/transformers && PYTHONPATH=<build_dir> python test_gqa.py TestXQASlidingWindowParity — all 32 cases pass on H200 (SM90).
  • Kernel selection verified via ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1 (SdpaKernel=XQA) and an nsys trace showing H64::grp4_fp16::kernel_mha launches instead of flash_fwd_splitkv_kernel.

Motivation and Context

GPT-OSS-20B has 12 sliding-window layers (local_window_size=128, head_sink, fp16,
64 q / 8 kv heads, head_size 64). On H200 single-token decode the XQA kernel is
~2.2× faster than FlashDecode on these shapes, so enabling XQA for the
sliding-window layers improves end-to-end decode latency.

Checklist

  • Tests added/updated
  • Documentation updated
  • No breaking changes (global-only behavior preserved; quantized paths unchanged)
  • CI passes

@tianleiwu tianleiwu changed the title Add sliding-window support to non-quantized XQA decode [CUDA] Add sliding-window support to non-quantized XQA decode Jun 20, 2026
@tianleiwu tianleiwu requested a review from Copilot June 20, 2026 17:05

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends the CUDA GroupQueryAttention contrib op’s non-quantized XQA decode path (FP16/BF16) to support sliding-window/local attention by threading local_window_size through the XQA launch stack, while keeping quantized XQA (INT8/FP8 KV cache) restricted to global attention. It also adds parity tests and a profiling script mode to validate/compare kernel selection and performance.

Changes:

  • Enable non-quantized XQA decode when local_window_size > 0, and plumb local_window_size through the XQA loader and generated launch wrapper (guarded by SLIDING_WINDOW).
  • Keep INT8/FP8 XQA global-only via an is_global_attention guard in dispatch.
  • Add new sliding-window parity tests plus a --compare-xqa profiling mode and GPT-OSS preset; introduce/relocate CUDA GQA documentation.

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc Allow non-quantized XQA decode selection for sliding-window attention; keep quantized XQA global-only.
onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu Pass parameters.local_window_size into the XQA launch (ExtremeDecoding).
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h Extend XQA launch API with local_window_size.
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu Thread local_window_size through fp16 dispatcher/instantiations.
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh Enable SLIDING_WINDOW and pass local_window_size down to generated group kernels (fp16).
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu Update explicit instantiation signature for head_size=64 (fp16).
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu Update explicit instantiation signature for head_size=128 (fp16).
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu Update explicit instantiation signature for head_size=256 (fp16).
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu Thread local_window_size through bf16 dispatcher/instantiations.
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh Enable SLIDING_WINDOW and pass local_window_size down to generated group kernels (bf16).
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu Update explicit instantiation signature for head_size=64 (bf16).
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu Update explicit instantiation signature for head_size=128 (bf16).
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu Update explicit instantiation signature for head_size=256 (bf16).
onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh Map ORT local_window_size to XQA slidingWinSize under #if SLIDING_WINDOW.
onnxruntime/test/python/transformers/test_gqa.py Add sliding-window XQA parity test matrix (FP16/BF16, head_size 64/128, groups 4/8, sink on/off).
onnxruntime/test/python/transformers/profile_gqa.sh Add --gpt-oss preset and --compare-xqa (profiles ORT_ENABLE_XQA=0 vs 1).
docs/contrib_ops/gqa.md Remove/move prior top-level GQA doc (per #29173).
docs/contrib_ops/cuda/gqa.md Add CUDA-specific GQA documentation (currently contains statements inconsistent with this PR’s new XQA sliding-window support).

Comment thread onnxruntime/test/python/transformers/test_gqa.py Outdated
Comment thread docs/contrib_ops/cuda/gqa.md
Comment thread docs/contrib_ops/cuda/gqa.md Outdated
Comment thread docs/contrib_ops/cuda/gqa.md
Comment thread docs/contrib_ops/cuda/gqa.md Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.

Comment thread onnxruntime/test/python/transformers/profile_gqa.sh
Comment thread onnxruntime/test/python/transformers/test_gqa.py Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 18 out of 18 changed files in this pull request and generated no new comments.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 18 out of 18 changed files in this pull request and generated 1 comment.

Comment thread onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh
@titaiwangms

Copy link
Copy Markdown
Contributor

Multi-reviewer synthesis — PR #29177 [CUDA] Add sliding-window support to non-quantized XQA decode

Reviewed by 5 specialized agents (readability, code, critical/adversarial, deep-semantic, integration). Net: clean, well-isolated change. local_window_size is threaded consistently through every layer (group_query_attention.ccExtremeDecodingxqa_loader.h → all fp16/bf16 _64/_128/_256 instantiations → xqa_impl_gen.cuh), and the SLIDING_WINDOW capability is correctly compiled only into the non-quantized translation units while INT8/FP8 stay global-only via is_global_attention. The window mapping was verified by source derivation against the kernel and the test reference: XQA keeps [cacheSeqLen - slidingWinSize, cacheSeqLen-1] = exactly local_window_size tokens including the current decode token, matching ORT's convention and construct_local_maskno off-by-one. Note the deliberate asymmetry: XQA receives local_window_size directly, while Flash/unfused paths receive local_window_size - 1; this is correct, not a copy-paste slip, because the two kernels define their window param differently (both yield W tokens). Global invariance (-1 → max_seq_len) is bit-identical, structurally guaranteed by the runtime cacheSeqLen > slidingWinSize guard never firing. No Critical defects.

Note: the docs/contrib_ops/cuda/gqa.md change in this diff belongs to #29173 and will drop out once that merges — comments below treat it as out of scope.

Major

  1. Merge-order coordination with [CUDA] Enable CUDA GQA QK-Norm and XQA decode #29186 (same gating block). This PR rewrites the XQA gate in group_query_attention.cc (~line 405): it removes parameters.local_window_size == -1 from the outer condition and pushes it into a new is_global_attention guard on the quantized variants. PR [CUDA] Enable CUDA GQA QK-Norm and XQA decode #29186 edits this exact block to add the (!use_qk_norm || !is_inputs_quantized) condition. High risk of a text conflict or — worse — a silent logical merge that drops a condition from one PR. Whichever merges second must unify both: the sliding-window quantized guards and [CUDA] Enable CUDA GQA QK-Norm and XQA decode #29186's QK-Norm gate. Please coordinate.

  2. The two load-bearing safety claims are untested.

    • Global-unchanged (-1 → max_seq_len) is the core "zero behavior change" claim, but gqa_xqa_sliding_window_test_cases() never includes a local_window_size=-1 case, so the mapping path that's supposed to preserve global numerics isn't directly covered. Add explicit local_window_size=-1 parity cases (fp16/bf16, ±head_sink).
    • Quantized stays global-only: no test sets quantized KV + local_window_size>0 to assert it falls back off XQA. Add one that checks the kernel debug signal is not XQA and parity holds.

Minor

  • local_window_size is validated only after narrowing to int. group_query_attention.cc does static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)) and then ORT_ENFORCE(== -1 || > 0). An out-of-range attribute (e.g. 2^32 + 128) can wrap to a valid-looking small window and pass validation, silently running a different window than the model specified. Validate the int64_t before casting: lw == -1 || (lw > 0 && lw <= INT_MAX).
  • Sliding-window still sizes split work by cache capacity, not the effective window. For max_seq_len >> local_window_size (e.g. 128k cache, 128-token window), computeNbSubSeqPerSeqMHA(... maxSeqLen) can enable multi-block mode and launch split CTAs across the whole capacity; the kernel skips leading tiles so results stay correct, but much of the expected perf/memory benefit of local attention may be lost. Consider effective_seq_len = min(max_seq_len, local_window_size) for scratch sizing / split heuristic (keep actual capacity for addressing). Perf-only, needs profiling to confirm — non-blocking: [needs-run: XQA sliding-window over-launches vs an effective-window split heuristic; repro=profile_gqa.sh --fp16 --compare-xqa --head-sink --head-size 64 --num-heads 64 --kv-num-heads 8 --local-window-size 128 --past-sequence-length 131071 --sequence-length 1; expect=extra multi-block/scratch overhead vs effective-window variant; cost=expensive]
  • Test matrix gaps: sliding cases cover only head_size ∈ {64,128} though SLIDING_WINDOW=1 is also compiled into the H256 fp16/bf16 kernels (untested); only window 128 is exercised — add a small boundary window (1 or 2) for inclusive/exclusive checks; and add a case pinned at the guard boundary cacheSeqLen == slidingWinSize (e.g. past=127, win=128) to lock down > vs >=.
  • Launch default arg weakens the compile-time thread-through guarantee. xqa_impl_gen.cuh Launch(..., const int local_window_size = -1) — a future non-quant caller that forgets the argument silently gets global attention instead of a compile error. All current callers pass it; consider dropping the default on the non-quant template so the compiler enforces the invariant.

Readability

  • xqa_impl_gen.cuh: the two #if SLIDING_WINDOW guards branch on a macro that is never mentioned in this file (it's #defined to 1 only in the fp16/bf16 impl headers, undefined in the quantized ones). Add a comment at the block stating where the macro originates and that the param is dead when 0.
  • The local_window_size = -1 default on the Launch template is a back-compat shim for the un-updated quantized callers, not a meaningful "global default" — document that, or update the quantized callers to pass -1 explicitly and drop the default.
  • local_window_size sits after scale in the loader signatures but is appended last in the innermost Launch template — inconsistent position adds friction when tracing the value down.
  • CaptureStdout (test) uses fd-level dup2 rather than redirect_stdout; add a one-line docstring noting it must capture C++ output written to OS fd 1 (which redirect_stdout can't intercept).

Praise

Excellent boundary discipline: isolating #define SLIDING_WINDOW 1 to the non-quantized impl headers means quantized TUs inherit SLIDING_WINDOW=0 and the -1 default, naturally ignoring the feature — no runtime branch in the quantized path. Contract changes are fully wired through every explicit template instantiation. The sliding_win_size mapping comment in xqa_impl_gen.cuh (passthrough for >0, sentinel substitution for -1, and the zero-overhead correctness argument in six lines) is exemplary.

🤖 Synthesized from a 5-agent review pipeline (readability · code · critical · deep · integration). Findings were verified against the diff/source where load-bearing.

cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
}

#if SLIDING_WINDOW

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this be removed once sliding window is supported for the quantized path?

@tianleiwu tianleiwu changed the title [CUDA] Add sliding-window support to non-quantized XQA decode [CUDA] Add sliding-window support to XQA decode Jun 23, 2026
@tianleiwu tianleiwu requested a review from Copilot June 23, 2026 04:00

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 23 out of 23 changed files in this pull request and generated 7 comments.

Comment thread docs/contrib_ops/cuda/gqa.md Outdated
Comment thread docs/contrib_ops/cuda/gqa.md Outdated
Comment thread docs/contrib_ops/cuda/gqa.md Outdated
Comment thread docs/contrib_ops/cuda/gqa.md Outdated
Comment thread onnxruntime/test/python/transformers/test_gqa.py
Comment thread onnxruntime/test/python/transformers/test_gqa.py
Comment thread onnxruntime/test/python/transformers/test_gqa.py
@tianleiwu

Copy link
Copy Markdown
Contributor Author

Thanks for the thorough multi-agent synthesis. Addressed in ae18718 (on top of the quantized-enablement commit). Mapping each item:

Major

  1. Merge-order coordination with [CUDA] Enable CUDA GQA QK-Norm and XQA decode #29186 — resolved by the merge itself: [CUDA] Enable CUDA GQA QK-Norm and XQA decode #29186 is now in main and this branch has merged main, so the gate is unified (xqa_qk_norm_ok plus the sliding-window quantized guards live in the same block).
  2. Two load-bearing safety claims untested
    • Global-unchanged (-1 → max_seq_len): now exercised by every existing global XQA parity test. SLIDING_WINDOW=1 is compiled into all impl headers, so the -1 → max_seq_len mapping path runs on the global decode tests too (not just sliding cases).
    • Quantized stays global-only: this constraint was removed — the follow-up commit enables sliding window on the INT8/FP8 paths, covered by the new test_xqa_quantized_sliding_window_parity cases, which also assert SdpaKernel=XQA.

Minor

  • Validate local_window_size before narrowing to int — done. The int64_t attribute is now validated (-1, or 0 < lw <= INT_MAX) before the cast, so an out-of-range value can no longer wrap to a valid-looking small window.
  • Launch default arg weakens the compile-time thread-through — done. Dropped the = -1 default on the non-quant Launch template; all callers already pass local_window_size explicitly, so a future caller that forgets it now fails to compile instead of silently getting global attention.
  • Guard-boundary test — added a cacheSeqLen == slidingWinSize case (past=127, win=128) to both the non-quantized and quantized generators to lock down > vs >=.
  • Effective-window split heuristic — left as-is for now; it's perf-only and needs profiling, so tracking it as a non-blocking follow-up rather than changing scratch/split sizing in this PR.

Readability

  • Added a comment in xqa_impl_gen.cuh noting where SLIDING_WINDOW is defined (the fp16/bf16/int8/fp8 impl headers; default 0 in defines.h) and that the param is dead when 0.
  • Added a CaptureStdout docstring explaining the fd-level dup2 redirection is required to capture native C++ output on OS fd 1 (which redirect_stdout can't intercept).

All affected GQA sliding-window parity tests (non-quantized + quantized, including the new boundary cases) and the zero-window rejection regression test pass locally.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants