From 6ebf8638736b48cf3edbc181d28f4b58ccc7d8a1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 18:49:00 -0700 Subject: [PATCH 1/9] update doc for cuda contrib op GQA --- docs/contrib_ops/cuda/gqa.md | 448 +++++++++++++++++++++++++++++++++++ docs/contrib_ops/gqa.md | 173 -------------- 2 files changed, 448 insertions(+), 173 deletions(-) create mode 100644 docs/contrib_ops/cuda/gqa.md delete mode 100644 docs/contrib_ops/gqa.md diff --git a/docs/contrib_ops/cuda/gqa.md b/docs/contrib_ops/cuda/gqa.md new file mode 100644 index 0000000000000..9c7c1a5313859 --- /dev/null +++ b/docs/contrib_ops/cuda/gqa.md @@ -0,0 +1,448 @@ +# GroupQueryAttention — Operator Documentation + +This document describes the `com.microsoft::GroupQueryAttention` (GQA) contrib operator: its schema, +the CUDA kernel backends and how one is selected, and the attention-sink (`head_sink`) decode path +that is accelerated by the XQA kernel. + +For CPU-specific implementation details (including the quantized KV-cache flash path), see +[cpu/gqa.md](cpu/gqa.md). + +--- + +## Table of Contents + +1. [Overview](#1-overview) +2. [Operator Schema](#2-operator-schema) +3. [Input Formats](#3-input-formats) +4. [KV Cache and Quantization](#4-kv-cache-and-quantization) +5. [Attention Sink (`head_sink`) and Smooth Softmax](#5-attention-sink-head_sink-and-smooth-softmax) +6. [CUDA Kernel Backends and Dispatch](#6-cuda-kernel-backends-and-dispatch) +7. [XQA Decode Path](#7-xqa-decode-path) +8. [XQA `head_sink` PrePack](#8-xqa-head_sink-prepack) +9. [Selecting a Kernel: Provider Option and Environment Variables](#9-selecting-a-kernel-provider-option-and-environment-variables) +10. [Profiling and Benchmarking](#10-profiling-and-benchmarking) +11. [Fast Build Options](#11-fast-build-options) +12. [Testing](#12-testing) +13. [Future Work and Known Limitations](#13-future-work-and-known-limitations) + +--- + +## 1. Overview + +GroupQueryAttention implements causal grouped-query attention with KV-cache (past/present) support. +Grouped-query attention uses fewer key/value heads than query heads: each KV head is shared by a +group of `num_heads / kv_num_heads` query heads. The operator also supports: + +- Rotary positional embeddings (RoPE) +- Past/present KV cache with optional in-place (shared) buffer +- Quantized KV cache (int4 / int8 / float8e4m3fn) to reduce memory footprint +- Optional attention bias and local (sliding) window attention +- Smooth softmax, including a per-head attention sink (`head_sink`) + +The operator schema is defined in +[onnxruntime/core/graph/contrib_ops/bert_defs.cc](../../onnxruntime/core/graph/contrib_ops/bert_defs.cc). +The CUDA kernel is implemented in +[onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc) +and [group_query_attention_impl.cu](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu). + +## 2. Operator Schema + +Selected attributes: + +| Attribute | Description | +|-----------|-------------| +| `num_heads` | Number of query heads. | +| `kv_num_heads` | Number of key/value heads. `num_heads % kv_num_heads == 0`. | +| `scale` | Softmax scale. Defaults to `1/sqrt(head_size)`. | +| `softcap` | Optional logit soft-capping value. `0` disables it. | +| `local_window_size` | Left window size for local attention. `-1` means global attention. | +| `do_rotary` / `rotary_interleaved` | Enable RoPE and select interleaved vs. half-rotary layout. | +| `smooth_softmax` | Add a smooth factor to the softmax denominator. | +| `k_quant_type` / `v_quant_type` | KV cache quantization mode: `NONE`, `PER_TENSOR`, or `PER_CHANNEL`. | +| `kv_cache_bit_width` | Bit width of the quantized KV cache (`8` or `4`). | + +Selected inputs (see the schema for the full list and shapes): + +| Index | Name | Notes | +|-------|------|-------| +| 0 | `query` | `(batch, seq, hidden)`, or packed QKV. | +| 1, 2 | `key`, `value` | Optional when QKV is packed into `query`. | +| 3, 4 | `past_key`, `past_value` | BNSH cache. Shares the buffer with `present_*` when in-place. | +| 5 | `seqlens_k` | `total_sequence_lengths - 1` per batch entry. | +| 6 | `total_sequence_length` | Scalar used to distinguish prompt vs. decode. | +| 7, 8 | `cos_cache`, `sin_cache` | RoPE caches. | +| 11 | `head_sink` | `(num_heads,)` per-head attention sink (see §5). | +| 12, 13 | `k_scale`, `v_scale` | FP32 dequant scales for the quantized KV cache. | + +Outputs are `output`, `present_key`, `present_value`, and optional `output_qk`. + +## 3. Input Formats + +GQA accepts query/key/value in two layouts. The layout is inferred from whether `key` (input 1) +is present. + +### Unpacked Q, K, V (`Q_K_V_BSNH`) + +`key` and `value` are both provided: + +| Tensor | Shape | +|--------|-------| +| `query` | `(batch_size, sequence_length, num_heads * head_size)` | +| `key` | `(batch_size, sequence_length, kv_num_heads * head_size)` | +| `value` | `(batch_size, sequence_length, kv_num_heads * head_size)` | + +### Packed QKV (`QKV_BS3NH`) + +`key` and `value` are omitted (null) and Q, K, V are concatenated along the last dimension of +`query`: + +| Tensor | Shape | +|--------|-------| +| `query` | `(batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size)` | + +`head_size` is derived as `hidden_size / (num_heads + 2 * kv_num_heads)`. + +### KV cache layout + +`past_key` / `past_value` / `present_key` / `present_value` always use BNSH: +`(batch_size, kv_num_heads, cache_sequence_length, head_size)`. For a 4-bit quantized cache the +last dimension is `(head_size + 1) / 2` because two nibbles are packed per byte. + +### Constraints + +- `num_heads % kv_num_heads == 0` (each KV head is shared by `num_heads / kv_num_heads` query heads). +- `head_size == v_head_size` (Q and V share the head size). +- Q and K/V must have the same `sequence_length` (cross-attention is not supported). The exception + is the shared-buffer decode case where `kv_sequence_length == 0` (no new K/V to append — the past + buffer already holds the full KV cache). +- RoPE, packed-QKV unpacking, and KV-head expansion are handled internally (`PrepareQKV`) before the + selected backend runs, so every backend sees a consistent layout. + +## 4. KV Cache and Quantization + +### Layout and shared buffer + +The past/present KV cache uses BNSH layout +`(batch_size, kv_num_heads, cache_sequence_length, head_size)`. When `past_present_share_buffer` +holds (the past and present tensors alias the same memory), the cache length is the maximum +sequence length and new keys/values are appended in place. This shared-buffer mode is required by +the XQA decode path and by the Flash-Decoding fast path. + +### Quantized KV cache + +To reduce the KV-cache memory footprint, the cache may be stored quantized while `query` stays +FP16/BF16. Quantization is **symmetric** and configured by three attributes: + +| Attribute | Values | +|-----------|--------| +| `k_quant_type` / `v_quant_type` | `NONE`, `PER_TENSOR`, `PER_CHANNEL` | +| `kv_cache_bit_width` | `8` (INT8 / FP8) or `4` (INT4) | + +Supported storage types (`T_CACHE`) and their formula: + +| Type | Range | Quantize | +|------|-------|----------| +| INT8 | `[-128, 127]` | `q = clamp(round(x / scale), -128, 127)` | +| INT4 | `[-8, 7]`, two nibbles packed per byte | `q = clamp(round(x / scale), -8, 7)` | +| FP8 E4M3 | `[-448, 448]` | `q = clamp(x / scale, -448, 448)` (SM89+/Ada or SM90+) | + +- `k_scale` / `v_scale` (inputs 12, 13) are **always FP32**. For `PER_TENSOR` they are scalars; for + `PER_CHANNEL` they have shape `(kv_num_heads, 1, head_size)`. +- New keys/values are quantized as they are appended to the present cache; the attention kernel + dequantizes on the fly while computing scores. +- Registered type combinations are `T ∈ {float16, bfloat16}` × `T_CACHE ∈ {same as T, int8, FP8E4M3, uint8 (int4)}`. + +### How quantized decode is served + +The quantized KV-cache path is handled by the **XQA** decode kernel (see §7). XQA requires +`PER_TENSOR` scaling with `k_scale` and `v_scale` pointing to the **same** FP32 tensor, +`head_size ∈ {64, 128, 256}`, and a query/KV group size in `{4, 8, 16, 32}`. FP8 additionally +requires SM89+ (Ada) or SM90+. + +INT8 cache kernels are always built; FP8 (`onnxruntime_USE_FP8_KV_CACHE`, default ON) and INT4 +(`onnxruntime_USE_INT4_KV_CACHE`, default OFF) are gated by build options (see §11). + +## 5. Attention Sink (`head_sink`) and Smooth Softmax + +An attention sink adds a learned per-head bias term to the softmax denominator. With sink value `s_h` +for head `h`, the attention weights over `T` cached positions become: + +$$ +\text{softmax}_i = \frac{e^{x_i - m}}{e^{s_h - m} + \sum_{j} e^{x_j - m}}, \quad m = \max\left(s_h, \max_j x_j\right) +$$ + +This is equivalent to appending a single extra logit `s_h` (whose value contributes nothing to the +output, only to normalization). GPT-OSS style models use this to let a head attend to "nothing". + +In the kernel, providing the `head_sink` input is treated as smooth softmax: +`parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr`. The `head_sink` tensor is +1D of shape `(num_heads,)` and matches the operator's floating-point type (`float16` or `bfloat16` on +the XQA path). + +## 6. CUDA Kernel Backends and Dispatch + +The CUDA EP can route a GQA node to one of five backends. They are evaluated in a fixed priority +order and the first eligible backend wins: + +**XQA → cuDNN SDPA → Flash Attention → Memory Efficient Attention (MEA) → Unfused** + +| Priority | Backend | Selected when (summary) | +|----------|---------|-------------------------| +| 1 | **XQA** | Single-token global decode (`seq_len == 1`), shared KV buffer. Fastest decode path; the only backend that serves a quantized KV cache. | +| 2 | **cuDNN SDPA** | Non-quantized FP16/BF16 causal attention. Auto-preferred on SM≥90 (Hopper/Blackwell). | +| 3 | **Flash Attention** | General FP16/BF16 prompt and decode, including local window, softcap, and packed QKV. | +| 4 | **Memory Efficient Attention (MEA)** | Fallback for FP16/FP32 (and BF16 on SM80+). | +| 5 | **Unfused** | Last-resort fallback (e.g. `head_size > 256`). Any head size, GQA, sliding window, softcap. | + +The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is +enabled (see §10). + +### 6.1 XQA + +Checked first. Used only for single-token global decode under the conditions detailed in §7. When +XQA is selected, no other backend is considered. + +### 6.2 cuDNN SDPA + +Eligible when **all** of the following hold: + +- not already selected for XQA; +- KV cache is **not** quantized (`T_CACHE == T`); +- `softcap == 0`, no smooth softmax, and no `head_sink`; +- no local (sliding) window (`local_window_size == -1`); +- past/present KV in BNSH (`Q_K_V_BNSH`); +- cuDNN SDPA is enabled — either explicitly (`ORT_ENABLE_CUDNN_FLASH_ATTENTION=1` or the cuDNN bit of + `sdpa_kernel`), or auto-preferred on SM≥90 when no kernel is explicitly pinned; +- cuDNN ≥ 9.3 (stable) and `is_supported` returns true for the shape. + +### 6.3 Flash Attention + +Eligible when: + +- not XQA and not cuDNN SDPA; +- FP16/BF16 (`sizeof(T) == 2`) and Flash is enabled (not `ORT_DISABLE_FLASH_ATTENTION`, not disabled + via `sdpa_kernel`, and built with `USE_FLASH_ATTENTION`); +- `flash::is_supported` is true for `head_size` / `num_heads` / `kv_num_heads`. + +Flash supports local window, softcap, RoPE, and packed QKV. For decode it additionally uses a +**Flash-Decoding** split-KV fast path (`seq_len == 1`, shared buffer, non-quantized), unless +`ORT_DISABLE_FLASH_DECODE=1`. + +### 6.4 Memory Efficient Attention (MEA) + +Fallback when XQA, cuDNN SDPA, and Flash are all ineligible: + +- MEA enabled (not `ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION`, built with `USE_MEMORY_EFFICIENT_ATTENTION`); +- `has_memory_efficient_attention(sm, is_fp16, is_bf16, head_size)` is true — FP16/FP32 broadly, + BF16 on SM80+. + +When the query/KV head counts differ, the KV heads are expanded to `num_heads` into a scratch buffer. + +### 6.5 Unfused + +Last-resort path, activated when XQA / cuDNN / Flash / MEA are all ineligible **and**: + +- KV cache is not quantized; +- no smooth softmax and no `head_sink`; +- past/present KV in BNSH. + +It supports any `head_size` (FP32 QK accumulation), GQA, sliding window, and softcap — for example +`head_size > 256` with past KV. The unfused (math) path can never be turned off and is always +available as a fallback. + +## 7. XQA Decode Path + +XQA (a highly optimized cross/decode attention kernel) is used only when **all** of the following hold: + +1. Compute capability SM 8.0+ (Ampere or newer). +2. Decoding phase (not the first prompt) with `sequence_length == 1`. +3. `kv_sequence_length > 0` (there is a new K/V to append). +4. Past and present KV cache share the same buffer. +5. No softcap. +6. Standard softmax, **or** smooth softmax expressed via a `head_sink` tensor (non-quantized KV cache). +7. No local (sliding) window attention — global attention only. +8. Supported `head_size` (64, 128, or 256) and group size. + +`head_sink` (attention sink) is supported on the non-quantized XQA path only. Quantized KV cache +(int8 / fp8) paths explicitly reject a non-null attention sink, so a GQA node with both `head_sink` +and a quantized cache falls back to Flash/Flash-Decoding. + +XQA selection defaults are: + +- **Quantized KV cache (int8 / fp8):** on by default. +- **Non-quantized with a `head_sink` input:** on by default (GPT-OSS style decode). +- **Non-quantized without `head_sink`:** opt-in via `ORT_ENABLE_XQA=1`. + +Setting `ORT_ENABLE_XQA=0` disables XQA for the non-quantized path regardless of `head_sink`. + +## 8. XQA `head_sink` PrePack + +XQA consumes the attention sink as an FP32 buffer, while the model stores `head_sink` as FP16/BF16. To +avoid converting on every decode step, `GroupQueryAttention::PrePack` converts a **constant-initializer** +`head_sink` once into a cached FP32 device buffer (`xqa_head_sink_`): + +- The cached buffer is reused for every launch when XQA is eligible. +- A dynamic / non-initializer `head_sink` is **not** prepacked; the kernel instead reserves a small FP32 + scratch buffer and converts the sink per launch (`xqa_head_sink_needs_conversion = true`). +- `PrePack` keeps `is_packed = false` so the original FP16/BF16 `head_sink` is still delivered to the + Flash/fallback paths when XQA is disabled or ineligible. + +## 9. Selecting a Kernel: Provider Option and Environment Variables + +### `sdpa_kernel` provider option + +The CUDA EP exposes a `sdpa_kernel` provider option (a bitmask defined by `AttentionBackend`) that +pins which fused attention backends are allowed. It applies to GroupQueryAttention, +MultiHeadAttention, and Attention nodes. + +| Bit value | Backend | +|-----------|---------| +| `0` | Default — selection follows heuristics / environment variables (auto-prefers cuDNN SDPA on SM≥90). | +| `1` | Flash Attention | +| `2` | Memory Efficient Attention | +| `8` | cuDNN SDPA | +| `16` | Unfused (math) — note the unfused fallback can never actually be turned off | + +Bits can be OR-ed together. Any positive value is treated as an **explicit** selection: only the +listed backends are enabled and the automatic cuDNN-on-SM≥90 preference is disabled. **XQA is not +part of this bitmask** — it is controlled separately by `ORT_ENABLE_XQA`. + +```python +import onnxruntime as ort + +sess = ort.InferenceSession( + "model.onnx", + providers=[("CUDAExecutionProvider", {"sdpa_kernel": "1"})], # 1 = Flash Attention only +) +``` + +### Environment variables + +| Variable | Effect | +|----------|--------| +| `ORT_ENABLE_XQA` | `1` enables the XQA decode path for the non-quantized KV cache; `0` disables XQA entirely (including the quantized and `head_sink` default-on paths). Unset: on for quantized / `head_sink`, off otherwise (see §7). | +| `ORT_DISABLE_FLASH_ATTENTION` | `1` disables Flash Attention. | +| `ORT_DISABLE_FLASH_DECODE` | `1` disables the Flash-Decoding split-KV optimization. | +| `ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION` | `1` disables Memory Efficient Attention. | +| `ORT_ENABLE_CUDNN_FLASH_ATTENTION` | `1` enables cuDNN SDPA; `0` disables it and also disables the SM≥90 auto-preference. | +| `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO` | `1` prints the selected backend (`SdpaKernel=...`) per node (see §10). | + +A positive `sdpa_kernel` value takes precedence over these environment defaults. Environment +variables are read once when the kernel is constructed. + +## 10. Profiling and Benchmarking + +### Verify which backend ran + +Set `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1`. For each GQA node the kernel prints a line such as: + +``` +Operator=GroupQueryAttention Node= DataType=fp16 SdpaKernel=XQA +``` + +`SdpaKernel` is one of `XQA`, `FLASH_ATTENTION`, `EFFICIENT_ATTENTION`, `CUDNN_FLASH_ATTENTION`, or +`MATH` (unfused). Use this to confirm that an env var / `sdpa_kernel` choice took effect. + +### Benchmark and profiling scripts + +Located in `onnxruntime/test/python/transformers/`: + +| Script | Purpose | +|--------|---------| +| [profile_gqa.py](../../onnxruntime/test/python/transformers/profile_gqa.py) | Profile GQA (incl. quantized KV cache) with NVTX markers; examples for Nsight Compute (`ncu`) and Nsight Systems (`nsys`). | +| [benchmark_gqa.py](../../onnxruntime/test/python/transformers/benchmark_gqa.py) | Triton-based throughput comparison across dense / local / packed-QKV and INT4/INT8/FP8 variants. | +| [benchmark_gqa_windows.py](../../onnxruntime/test/python/transformers/benchmark_gqa_windows.py) | GQA benchmark variant for Windows. | +| [benchmark_gqa_cpu_flash.py](../../onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py) | CPU flash-vs-naive GQA benchmark. | + +Example kernel-level and timeline profiling: + +```bash +cd onnxruntime/test/python/transformers + +# Kernel-level analysis with Nsight Compute +ncu --set full -o gqa_int8 python profile_gqa.py --mode int8 --warmup 5 --repeat 1 + +# Timeline with Nsight Systems, then parse kernel timings +nsys profile -o gqa_int8 --export=sqlite python profile_gqa.py --mode int8 --warmup 5 --repeat 10 +python parse_nsys.py gqa_int8.sqlite +``` + +ONNX Runtime's built-in profiler (`SessionOptions.enable_profiling = True`) also emits a JSON +timeline with per-node durations. + +## 11. Fast Build Options + +These CMake options speed up CUDA builds during development. Pass them through +`--cmake_extra_defines` (see the `ort-build` skill). + +| Option | Default | Effect | +|--------|---------|--------| +| `onnxruntime_QUICK_BUILD` | `OFF` | Builds only the `hdim128` FP16/BF16 Flash Attention kernels. Greatly reduces compile time, but **changes dispatch**: shapes with `head_size != 128` fall back to Memory Efficient Attention because Flash is no longer compiled for them. Do not use it to characterize Flash-vs-arch behavior. | +| `onnxruntime_USE_FP8_KV_CACHE` | `ON` | Builds the FP8 (E4M3) quantized KV-cache kernels (`-DUSE_FP8_KV_CACHE=1`). | +| `onnxruntime_USE_INT4_KV_CACHE` | `OFF` | Builds the INT4 quantized KV-cache kernels (`-DUSE_INT4_KV_CACHE=1`). A `kv_cache_bit_width == 4` node errors out if this is off. | + +Other ways to shorten the iteration loop: + +- Restrict GPU architectures with `CMAKE_CUDA_ARCHITECTURES` (e.g. + `--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80`) so kernels are not compiled for unused SMs. +- Build only the CUDA provider target: + `./build.sh --config Release --build --parallel --target onnxruntime_providers_cuda`. +- Skip `--update` when you only edited existing `.cc` / `.h` / `.cu` files. + +```bash +./build.sh --config Release --parallel --use_cuda \ + --cuda_home /usr/local/cuda --cudnn_home /usr/local/cuda \ + --cmake_extra_defines onnxruntime_QUICK_BUILD=ON onnxruntime_USE_INT4_KV_CACHE=ON +``` + +## 12. Testing + +CUDA parity tests live in +[onnxruntime/test/python/transformers/test_gqa.py](../../onnxruntime/test/python/transformers/test_gqa.py): + +- `TestXQAQuantizedParity` — XQA per-tensor int8 quantized decode parity. +- `TestXQAHeadSinkParity` — non-quantized XQA decode parity with a `head_sink` (attention sink) input. + +`TestXQAQuantizedParity` sets `ORT_ENABLE_XQA=1` to force the XQA path. `TestXQAHeadSinkParity` +instead clears `ORT_ENABLE_XQA` to validate that XQA is enabled by default when a `head_sink` input +is present. Both compare against a PyTorch reference (`attention_ref` with `smooth_softmax_ref`). + +## 13. Future Work and Known Limitations + +The following features are missing or limited in the CUDA GQA kernel and would broaden coverage of +popular LLMs. Listed roughly by impact. + +### High impact + +1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** The CUDA kernel rejects `q_norm_weight` / + `k_norm_weight`; only the WebGPU EP implements the fused prologue. Required by **Qwen3, + Gemma 2/3, OLMo2, SmolLM3**, etc., which otherwise run the normalization unfused. +2. **Sliding-window + attention-sink on the fused decode path.** XQA requires global attention + (`local_window_size == -1`), so **GPT-OSS / Mistral / Gemma 2** layers that combine a sliding + window with a `head_sink` fall back to Flash / Flash-Decoding instead of XQA. Unifying sliding + window and sink in XQA would close this gap. +3. **Softcap on the fastest kernels.** Logit soft-capping (**Gemma 2**) disables both XQA and cuDNN + SDPA, forcing the Flash / MEA / unfused paths. Adding softcap support to XQA and cuDNN would + recover decode throughput. +4. **Attention bias / ALiBi.** `attention_bias` is rejected outright. Needed for ALiBi-style models + and additive-mask use cases, though less commonly used in current popular decoder-only LLMs. + +### Medium impact + +5. **Quantized KV cache coverage.** Quantized decode is XQA-only and narrow: `PER_TENSOR` with + `k_scale == v_scale`, `head_size ∈ {64, 128, 256}`, group size `{4, 8, 16, 32}`. Gaps worth + filling: `PER_CHANNEL` serving, prompt-phase quantized attention, INT4 enabled by default, and + `head_sink` combined with a quantized cache (currently rejected). +6. **Paged KV cache / continuous batching.** GQA uses a contiguous shared buffer; there is a + separate `PagedAttention` op, but GQA itself has no paged-cache path. Paged KV is what + high-throughput serving (vLLM-style) needs. +7. **MLA (Multi-head Latent Attention).** **DeepSeek-V2/V3** use latent KV compression with a + `v_head_size` that differs from `head_size`; GQA assumes `head_size == v_head_size`. This needs a + distinct kernel/op rather than a GQA tweak. + +### Lower impact / niche + +8. **Returning attention weights (`output_qk`).** Never supported by the CUDA fused kernels. Only + relevant for interpretability or speculative-decode scoring. +9. **Cross-attention (different Q vs KV sequence lengths).** Rejected by the input checker. + Encoder-decoder / multimodal cross-attention is not covered by GQA. diff --git a/docs/contrib_ops/gqa.md b/docs/contrib_ops/gqa.md deleted file mode 100644 index 08596ff4b5dd9..0000000000000 --- a/docs/contrib_ops/gqa.md +++ /dev/null @@ -1,173 +0,0 @@ -# GroupQueryAttention — Operator Documentation - -This document describes the `com.microsoft::GroupQueryAttention` (GQA) contrib operator: its schema, -the CUDA kernel backends and how one is selected, and the attention-sink (`head_sink`) decode path -that is accelerated by the XQA kernel. - -For CPU-specific implementation details (including the quantized KV-cache flash path), see -[cpu/gqa.md](cpu/gqa.md). - ---- - -## Table of Contents - -1. [Overview](#1-overview) -2. [Operator Schema](#2-operator-schema) -3. [KV Cache and Quantization](#3-kv-cache-and-quantization) -4. [Attention Sink (`head_sink`) and Smooth Softmax](#4-attention-sink-head_sink-and-smooth-softmax) -5. [CUDA Kernel Backends and Dispatch](#5-cuda-kernel-backends-and-dispatch) -6. [XQA Decode Path](#6-xqa-decode-path) -7. [XQA `head_sink` PrePack](#7-xqa-head_sink-prepack) -8. [Environment Variables](#8-environment-variables) -9. [Testing](#9-testing) - ---- - -## 1. Overview - -GroupQueryAttention implements causal grouped-query attention with KV-cache (past/present) support. -Grouped-query attention uses fewer key/value heads than query heads: each KV head is shared by a -group of `num_heads / kv_num_heads` query heads. The operator also supports: - -- Rotary positional embeddings (RoPE) -- Past/present KV cache with optional in-place (shared) buffer -- Quantized KV cache (int4 / int8 / float8e4m3fn) to reduce memory footprint -- Optional attention bias and local (sliding) window attention -- Smooth softmax, including a per-head attention sink (`head_sink`) - -The operator schema is defined in -[onnxruntime/core/graph/contrib_ops/bert_defs.cc](../../onnxruntime/core/graph/contrib_ops/bert_defs.cc). -The CUDA kernel is implemented in -[onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc) -and [group_query_attention_impl.cu](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu). - -## 2. Operator Schema - -Selected attributes: - -| Attribute | Description | -|-----------|-------------| -| `num_heads` | Number of query heads. | -| `kv_num_heads` | Number of key/value heads. `num_heads % kv_num_heads == 0`. | -| `scale` | Softmax scale. Defaults to `1/sqrt(head_size)`. | -| `softcap` | Optional logit soft-capping value. `0` disables it. | -| `local_window_size` | Left window size for local attention. `-1` means global attention. | -| `do_rotary` / `rotary_interleaved` | Enable RoPE and select interleaved vs. half-rotary layout. | -| `smooth_softmax` | Add a smooth factor to the softmax denominator. | -| `k_quant_type` / `v_quant_type` | KV cache quantization mode: `NONE`, `PER_TENSOR`, or `PER_CHANNEL`. | -| `kv_cache_bit_width` | Bit width of the quantized KV cache (`8` or `4`). | - -Selected inputs (see the schema for the full list and shapes): - -| Index | Name | Notes | -|-------|------|-------| -| 0 | `query` | `(batch, seq, hidden)`, or packed QKV. | -| 1, 2 | `key`, `value` | Optional when QKV is packed into `query`. | -| 3, 4 | `past_key`, `past_value` | BNSH cache. Shares the buffer with `present_*` when in-place. | -| 5 | `seqlens_k` | `total_sequence_lengths - 1` per batch entry. | -| 6 | `total_sequence_length` | Scalar used to distinguish prompt vs. decode. | -| 7, 8 | `cos_cache`, `sin_cache` | RoPE caches. | -| 11 | `head_sink` | `(num_heads,)` per-head attention sink (see §4). | -| 12, 13 | `k_scale`, `v_scale` | FP32 dequant scales for the quantized KV cache. | - -Outputs are `output`, `present_key`, `present_value`, and optional `output_qk`. - -## 3. KV Cache and Quantization - -The past/present KV cache uses BNSH layout `(batch_size, kv_num_heads, cache_sequence_length, head_size)`. -When `past_present_share_buffer` holds (the past and present tensors alias the same memory), the cache -length is the maximum sequence length and new keys/values are appended in place. This shared-buffer mode -is required by the XQA decode path. - -When quantization is enabled, `k_quant_type` and `v_quant_type` select `PER_TENSOR` or `PER_CHANNEL` -scaling, and `kv_cache_bit_width` selects 8-bit or 4-bit storage. The `k_scale` / `v_scale` inputs are -always FP32. - -## 4. Attention Sink (`head_sink`) and Smooth Softmax - -An attention sink adds a learned per-head bias term to the softmax denominator. With sink value `s_h` -for head `h`, the attention weights over `T` cached positions become: - -$$ -\text{softmax}_i = \frac{e^{x_i - m}}{e^{s_h - m} + \sum_{j} e^{x_j - m}}, \quad m = \max\left(s_h, \max_j x_j\right) -$$ - -This is equivalent to appending a single extra logit `s_h` (whose value contributes nothing to the -output, only to normalization). GPT-OSS style models use this to let a head attend to "nothing". - -In the kernel, providing the `head_sink` input is treated as smooth softmax: -`parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr`. The `head_sink` tensor is -1D of shape `(num_heads,)` and matches the operator's floating-point type (`float16` or `bfloat16` on -the XQA path). - -## 5. CUDA Kernel Backends and Dispatch - -The CUDA EP can route a GQA node to several backends. At runtime it selects the first eligible one: - -| Backend | Typical use | -|---------|-------------| -| **XQA** | Single-token global decode (`seq_len == 1`), shared KV buffer. Fastest decode path. | -| **Flash Attention / Flash Decoding** | General prompt and decode, including local window and softcap. | -| **cuDNN SDPA** | Preferred on SM≥90 for non-quantized FP16/BF16 causal attention. | -| **Memory Efficient Attention** | Fallback for FP16/FP32 (and BF16 on SM80+). | -| **Unfused** | Last-resort fallback (e.g. `head_size > 256` with past KV). | - -The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is enabled. - -## 6. XQA Decode Path - -XQA (a highly optimized cross/decode attention kernel) is used only when **all** of the following hold: - -1. Compute capability SM 8.0+ (Ampere or newer). -2. Decoding phase (not the first prompt) with `sequence_length == 1`. -3. `kv_sequence_length > 0` (there is a new K/V to append). -4. Past and present KV cache share the same buffer. -5. No softcap. -6. Standard softmax, **or** smooth softmax expressed via a `head_sink` tensor (non-quantized KV cache). -7. No local (sliding) window attention — global attention only. -8. Supported `head_size` (64, 128, or 256) and group size. - -`head_sink` (attention sink) is supported on the non-quantized XQA path only. Quantized KV cache -(int8 / fp8) paths explicitly reject a non-null attention sink, so a GQA node with both `head_sink` -and a quantized cache falls back to Flash/Flash-Decoding. - -XQA selection defaults are: - -- **Quantized KV cache (int8 / fp8):** on by default. -- **Non-quantized with a `head_sink` input:** on by default (GPT-OSS style decode). -- **Non-quantized without `head_sink`:** opt-in via `ORT_ENABLE_XQA=1`. - -Setting `ORT_ENABLE_XQA=0` disables XQA for the non-quantized path regardless of `head_sink`. - -## 7. XQA `head_sink` PrePack - -XQA consumes the attention sink as an FP32 buffer, while the model stores `head_sink` as FP16/BF16. To -avoid converting on every decode step, `GroupQueryAttention::PrePack` converts a **constant-initializer** -`head_sink` once into a cached FP32 device buffer (`xqa_head_sink_`): - -- The cached buffer is reused for every launch when XQA is eligible. -- A dynamic / non-initializer `head_sink` is **not** prepacked; the kernel instead reserves a small FP32 - scratch buffer and converts the sink per launch (`xqa_head_sink_needs_conversion = true`). -- `PrePack` keeps `is_packed = false` so the original FP16/BF16 `head_sink` is still delivered to the - Flash/fallback paths when XQA is disabled or ineligible. - -## 8. Environment Variables - -| Variable | Effect | -|----------|--------| -| `ORT_ENABLE_XQA` | `1` enables the XQA decode path for the non-quantized KV cache (default off; default on for quantized). | -| `ORT_DISABLE_FLASH_DECODE` | `1` disables the Flash Decoding split-KV optimization. | - -These are read once when the kernel is constructed. - -## 9. Testing - -CUDA parity tests live in -[onnxruntime/test/python/transformers/test_gqa.py](../../onnxruntime/test/python/transformers/test_gqa.py): - -- `TestXQAQuantizedParity` — XQA per-tensor int8 quantized decode parity. -- `TestXQAHeadSinkParity` — non-quantized XQA decode parity with a `head_sink` (attention sink) input. - -`TestXQAQuantizedParity` sets `ORT_ENABLE_XQA=1` to force the XQA path. `TestXQAHeadSinkParity` -instead clears `ORT_ENABLE_XQA` to validate that XQA is enabled by default when a `head_sink` input -is present. Both compare against a PyTorch reference (`attention_ref` with `smooth_softmax_ref`). From caf4eeeefe8f8bfde1e41ba36b70583c8a9444c4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 20:19:08 -0700 Subject: [PATCH 2/9] Initial draft of rms norm and gqa --- docs/contrib_ops/cuda/gqa.md | 43 +++- .../cpu/bert/attention_parameters.h | 2 + .../contrib_ops/cuda/bert/attention_data.h | 6 + .../cuda/bert/group_query_attention.cc | 57 +++++- .../cuda/bert/group_query_attention.h | 1 + .../cuda/bert/group_query_attention_impl.cu | 81 +++++++- .../cuda/bert/group_query_attention_impl.h | 4 +- .../cuda/bert/group_query_attention_qkv.cuh | 67 ++++++- .../test/python/transformers/test_gqa.py | 184 +++++++++++++++++- 9 files changed, 414 insertions(+), 31 deletions(-) diff --git a/docs/contrib_ops/cuda/gqa.md b/docs/contrib_ops/cuda/gqa.md index 9c7c1a5313859..9aa9ab1cd464c 100644 --- a/docs/contrib_ops/cuda/gqa.md +++ b/docs/contrib_ops/cuda/gqa.md @@ -58,6 +58,7 @@ Selected attributes: | `local_window_size` | Left window size for local attention. `-1` means global attention. | | `do_rotary` / `rotary_interleaved` | Enable RoPE and select interleaved vs. half-rotary layout. | | `smooth_softmax` | Add a smooth factor to the softmax denominator. | +| `qk_norm_epsilon` | Epsilon for the fused per-head Q/K RMSNorm (QK-Norm) prologue. Defaults to `1e-6`. | | `k_quant_type` / `v_quant_type` | KV cache quantization mode: `NONE`, `PER_TENSOR`, or `PER_CHANNEL`. | | `kv_cache_bit_width` | Bit width of the quantized KV cache (`8` or `4`). | @@ -73,6 +74,7 @@ Selected inputs (see the schema for the full list and shapes): | 7, 8 | `cos_cache`, `sin_cache` | RoPE caches. | | 11 | `head_sink` | `(num_heads,)` per-head attention sink (see §5). | | 12, 13 | `k_scale`, `v_scale` | FP32 dequant scales for the quantized KV cache. | +| 14, 15 | `q_norm_weight`, `k_norm_weight` | `(head_size,)` per-head Q/K RMSNorm weights (QK-Norm, see §3). Both must be present together. | Outputs are `output`, `present_key`, `present_value`, and optional `output_qk`. @@ -118,6 +120,30 @@ last dimension is `(head_size + 1) / 2` because two nibbles are packed per byte. - RoPE, packed-QKV unpacking, and KV-head expansion are handled internally (`PrepareQKV`) before the selected backend runs, so every backend sees a consistent layout. +### Fused QK-Norm (per-head Q/K RMSNorm) + +When the optional `q_norm_weight` (input 14) and `k_norm_weight` (input 15) tensors are provided, the +CUDA kernel applies a fused per-head RMS normalization to Q and K **before** RoPE. This matches the +QK-Norm used by **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. For each head, over the `head_size` +channels: + +$$ +x_\text{norm}[c] = x[c] \cdot \frac{1}{\sqrt{\frac{1}{H}\sum_{j} x[j]^2 + \epsilon}} \cdot w[c] +$$ + +where `H = head_size`, `w` is the per-head weight vector (`q_norm_weight` for Q, `k_norm_weight` for +K), and `epsilon = qk_norm_epsilon` (default `1e-6`). The sum of squares is reduced in FP32 for +numerical stability and the result is cast back to the operator type `T`. + +- Both weights are 1D tensors of shape `(head_size,)`, share the operator's element type `T` + (`float16`/`bfloat16`), and are **shared across all heads**. They must be supplied together — + providing only one is rejected. +- The normalization is fused into the `PrepareQKV` prologue (`UnpackRoPEAppend` for the new-KV path, + or a standalone per-head RMSNorm kernel for the shared-buffer Q-only decode case), so it composes + with packed QKV, RoPE, KV-head expansion, and the quantized KV cache. +- Because the fused decode kernels (XQA and Flash-Decoding) do their own RoPE/append internally and + bypass `PrepareQKV`, they are disabled when QK-Norm is present (see §6). + ## 4. KV Cache and Quantization ### Layout and shared buffer @@ -197,6 +223,11 @@ order and the first eligible backend wins: The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is enabled (see §10). +> **QK-Norm interaction.** When `q_norm_weight` / `k_norm_weight` are present (see §3), the fused +> decode paths that bypass `PrepareQKV` — **XQA** and the **Flash-Decoding fast path** — are +> disabled so the QK-Norm prologue always runs. Such nodes therefore route to Flash Attention +> (or cuDNN SDPA / MEA / Unfused), all of which consume the normalized Q/K produced by `PrepareQKV`. + ### 6.1 XQA Checked first. Used only for single-token global decode under the conditions detailed in §7. When @@ -402,10 +433,14 @@ CUDA parity tests live in - `TestXQAQuantizedParity` — XQA per-tensor int8 quantized decode parity. - `TestXQAHeadSinkParity` — non-quantized XQA decode parity with a `head_sink` (attention sink) input. +- `TestGQAQKNorm` — fused per-head Q/K RMSNorm (QK-Norm) parity for prompt and decode (past), FP16 and + BF16, across packed/unpacked Q/K/V and with/without RoPE. `TestXQAQuantizedParity` sets `ORT_ENABLE_XQA=1` to force the XQA path. `TestXQAHeadSinkParity` instead clears `ORT_ENABLE_XQA` to validate that XQA is enabled by default when a `head_sink` input is present. Both compare against a PyTorch reference (`attention_ref` with `smooth_softmax_ref`). +`TestGQAQKNorm` applies the RMSNorm-before-RoPE reference to Q and K and compares against the CUDA +output. ## 13. Future Work and Known Limitations @@ -414,9 +449,11 @@ popular LLMs. Listed roughly by impact. ### High impact -1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** The CUDA kernel rejects `q_norm_weight` / - `k_norm_weight`; only the WebGPU EP implements the fused prologue. Required by **Qwen3, - Gemma 2/3, OLMo2, SmolLM3**, etc., which otherwise run the normalization unfused. +1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** *Implemented.* The CUDA kernel applies the + fused per-head RMSNorm to Q and K before RoPE when `q_norm_weight` / `k_norm_weight` are provided + (see §3), matching **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. Remaining limitation: QK-Norm + disables the fused decode kernels (XQA and Flash-Decoding) and routes through Flash / cuDNN / + MEA instead, so QK-Norm decode does not yet get the XQA fast path. 2. **Sliding-window + attention-sink on the fused decode path.** XQA requires global attention (`local_window_size == -1`), so **GPT-OSS / Mistral / Gemma 2** layers that combine a sliding window with a `head_sink` fall back to Flash / Flash-Decoding instead of XQA. Unifying sliding diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index 5b7624d11c6fd..11be4be1638df 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -93,6 +93,8 @@ struct GroupQueryAttentionParameters : AttentionParameters { bool is_first_prompt; // indicates whether this is first decoding step bool rotary_interleaved; bool use_smooth_softmax; + bool use_qk_norm = false; // per-head Q/K RMSNorm (QK-Norm) prologue before RoPE (inputs 14/15) + float qk_norm_epsilon = 1e-6f; // epsilon for the QK-Norm RMSNorm float softcap; AttentionQkvFormat past_kv_format; int zeros_count; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 74aeaf1285e8a..f6a7b788819f2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -157,6 +157,12 @@ struct GroupQueryAttentionData { const T* sin_cache = nullptr; const T* head_sink = nullptr; + // Optional per-head Q/K RMSNorm (QK-Norm) weights, shape (head_size,), shared across heads. + // Both are non-null together (validated in the op) and trigger the fused normalization before RoPE. + const T* q_norm_weight = nullptr; + const T* k_norm_weight = nullptr; + float qk_norm_epsilon = 1e-6f; + const float* k_scale = nullptr; const float* v_scale = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 60235024b9118..c9dbe1c73d29a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -105,6 +105,7 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); softcap_ = info.GetAttrOrDefault("softcap", 0.0f); use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; + qk_norm_epsilon_ = info.GetAttrOrDefault("qk_norm_epsilon", 1e-6f); k_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("k_quant_type", "NONE")); v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("v_quant_type", "NONE")); @@ -229,16 +230,17 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons const Tensor* k_scale = context->Input(12); const Tensor* v_scale = context->Input(13); - // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only - // GroupQueryAttentionPreNormFusion optimizer pass. The CUDA kernel does not implement - // the fused per-head Q/K RMS normalization prologue, so reject the node if either input - // is present rather than silently dropping the normalization. - if ((context->InputCount() > 14 && context->Input(14) != nullptr) || - (context->InputCount() > 15 && context->Input(15) != nullptr)) { + // q_norm_weight (input 14) / k_norm_weight (input 15) carry the per-head Q/K RMSNorm (QK-Norm) + // prologue weights, each of shape (head_size,) and shared across heads. They are populated by the + // GroupQueryAttentionPreNormFusion optimizer pass for Qwen3 / Gemma 2-3 / OLMo2 / SmolLM3 style + // models. Both must be present together; shape validation and wiring happen after CheckInputs, + // where head_size is known. + const Tensor* q_norm_weight = (context->InputCount() > 14) ? context->Input(14) : nullptr; + const Tensor* k_norm_weight = (context->InputCount() > 15) ? context->Input(15) : nullptr; + if ((q_norm_weight != nullptr) != (k_norm_weight != nullptr)) { return ORT_MAKE_STATUS( ONNXRUNTIME, INVALID_ARGUMENT, - "GroupQueryAttention (CUDA): q_norm_weight / k_norm_weight inputs are not supported. " - "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP."); + "GroupQueryAttention (CUDA): q_norm_weight and k_norm_weight must be provided together."); } if (k_quant_type_ != KVQuantizationType::NONE) { @@ -307,6 +309,28 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons head_sink, parameters)); + // Validate and enable the per-head Q/K RMSNorm (QK-Norm) prologue (inputs 14/15). Both weights + // must be 1D tensors of shape (head_size) with element type T (shared across all heads). + parameters.use_qk_norm = false; + if (q_norm_weight != nullptr) { + const auto& q_norm_shape = q_norm_weight->Shape(); + const auto& k_norm_shape = k_norm_weight->Shape(); + if (q_norm_shape.NumDimensions() != 1 || q_norm_shape[0] != parameters.head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "q_norm_weight must be a 1D tensor of shape (head_size=", parameters.head_size, ")"); + } + if (k_norm_shape.NumDimensions() != 1 || k_norm_shape[0] != parameters.head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "k_norm_weight must be a 1D tensor of shape (head_size=", parameters.head_size, ")"); + } + if (!q_norm_weight->IsDataType() || !k_norm_weight->IsDataType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "q_norm_weight/k_norm_weight type must match GroupQueryAttention input type T."); + } + parameters.use_qk_norm = true; + } + parameters.qk_norm_epsilon = qk_norm_epsilon_; + parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr; @@ -403,8 +427,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // An explicit ORT_ENABLE_XQA=0 (xqa_force_disabled_) still wins and turns XQA off entirely. // The dtype guard mirrors enable_xqa_ (XQA only supports fp16/bf16); ineligible cases fall back below. constexpr bool kIsFp16OrBf16 = std::is_same_v || std::is_same_v; + // The XQA decode kernel performs its own RoPE/append and bypasses PrepareQKV, so it cannot apply + // the fused QK-Norm prologue. Disable XQA when q/k norm weights are present and fall back to the + // Flash/cuDNN/MEA paths (all of which route through PrepareQKV's normalization). const bool xqa_enabled_for_run = - !xqa_force_disabled_ && (enable_xqa_ || (kIsFp16OrBf16 && use_xqa_attention_sinks)); + !xqa_force_disabled_ && !parameters.use_qk_norm && + (enable_xqa_ || (kIsFp16OrBf16 && use_xqa_attention_sinks)); if (xqa_enabled_for_run && (device_prop.major >= 8) && !parameters.is_first_prompt && @@ -530,7 +558,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.kv_num_heads); data.use_flash_attention = use_flash_attention; - data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_sequence_length > 0 && parameters.past_present_share_buffer && !is_inputs_quantized; + // The fast-decode path lets the flash kernel perform RoPE and KV-append internally, bypassing + // PrepareQKV (and therefore the fused QK-Norm prologue). Disable it when q/k norm weights are + // present so the regular FlashAttention path (which normalizes via PrepareQKV) is used instead. + data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_sequence_length > 0 && parameters.past_present_share_buffer && !is_inputs_quantized && !parameters.use_qk_norm; if (use_flash_attention) { // Allocate Flash specific buffers (Softmax LSE, Accum) @@ -713,6 +744,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.head_sink = reinterpret_cast(head_sink->Data()); } + if (parameters.use_qk_norm) { + data.q_norm_weight = reinterpret_cast(q_norm_weight->Data()); + data.k_norm_weight = reinterpret_cast(k_norm_weight->Data()); + data.qk_norm_epsilon = qk_norm_epsilon_; + } + #if DUMP_TENSOR_LEVEL > 0 DUMP_TENSOR_INIT(); // Dump Scales diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index d5b980bdca290..8a4b987066239 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -33,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel { bool do_rotary_; bool rotary_interleaved_; bool use_smooth_softmax_; + float qk_norm_epsilon_; // epsilon for the per-head Q/K RMSNorm (QK-Norm) prologue float scale_; float softcap_; bool disable_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 6a55f18bd939a..13b013b9069b2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -88,6 +88,63 @@ Status LaunchConvertHeadSinkToFloat( return CUDA_CALL(cudaGetLastError()); } +// Standalone per-head RMS normalization (QK-Norm prologue) for Q in BSNH layout. +// Each block handles one (b, s, head) vector and normalizes over head_size: +// out[c] = in[c] * rsqrt(mean(in^2) + epsilon) * weight[c] +// where weight has shape (head_size,) and is shared across heads. This is used by the shared-KV +// (kv_sequence_length == 0) path, which processes Q only; the new-KV path folds the equivalent +// normalization into UnpackRoPEAppend before RoPE. +template +__global__ void PerHeadRMSNormBSNHKernel( + T* output, const T* input, const T* weight, const int head_size, const float epsilon) { + const int s = blockIdx.x; + const int b = blockIdx.y; + const int n = blockIdx.z; + const int sequence_length = gridDim.x; + const int num_heads = gridDim.z; + const int i = threadIdx.x; + + const int64_t base = (((static_cast(b) * sequence_length + s) * num_heads) + n) * head_size; + + extern __shared__ float s_sumsq[]; + const float x = (i < head_size) ? static_cast(input[base + i]) : 0.0f; + s_sumsq[i] = x * x; + __syncthreads(); + + // Tree reduction over a power-of-two block size; padded entries (i >= head_size) are zero. + for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (i < stride) { + s_sumsq[i] += s_sumsq[i + stride]; + } + __syncthreads(); + } + + if (i < head_size) { + const float inv_rms = rsqrtf(s_sumsq[0] / static_cast(head_size) + epsilon); + output[base + i] = static_cast(x * inv_rms * static_cast(weight[i])); + } +} + +template +Status LaunchPerHeadRMSNorm( + cudaStream_t stream, T* output, const T* input, const T* weight, const float epsilon, + const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const int max_threads_per_block) { + // Round the thread count up to a power of two so the tree reduction is exact. + int tpb = 1; + while (tpb < head_size) { + tpb <<= 1; + } + if (tpb > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "head_size (", head_size, ") exceeds max threads for QK-Norm."); + } + const dim3 grid(sequence_length, batch_size, num_heads); + const dim3 block(tpb); + const size_t smem = static_cast(tpb) * sizeof(float); + PerHeadRMSNormBSNHKernel<<>>(output, input, weight, head_size, epsilon); + return CUDA_CALL(cudaGetLastError()); +} + // Internal helper to get Q, K, V pointers, handling packed input // // This function orchestrates the preparation of Q, K, and V tensors for attention kernels. @@ -116,7 +173,7 @@ Status PrepareQKV( T* q_out = reinterpret_cast(data.qkv_buffer); - if (!parameters.is_packed_qkv && !parameters.do_rotary) { + if (!parameters.is_packed_qkv && !parameters.do_rotary && !parameters.use_qk_norm) { q_out = nullptr; } @@ -153,6 +210,16 @@ Status PrepareQKV( // above has already populated the present buffer with the shared KV data. // In both cases, only Q processing (RoPE if configured) is needed here. if (kv_sequence_length == 0) { + // QK-Norm: normalize Q (BSNH) into q_out before RoPE. K is already normalized in the shared cache + // (it was normalized when first appended), so only Q needs processing on this path. + const T* q_rope_input = data.query; + if (parameters.use_qk_norm) { + ORT_RETURN_IF_ERROR((LaunchPerHeadRMSNorm( + stream, q_out, data.query, data.q_norm_weight, parameters.qk_norm_epsilon, + batch_size, sequence_length, num_heads, head_size, max_threads_per_block))); + // RoPE (if any) runs in-place on the normalized Q; the rotary kernel handles in-place safely. + q_rope_input = q_out; + } if (parameters.do_rotary && data.cos_cache != nullptr && data.sin_cache != nullptr) { // Apply RoPE to Q only using the standalone rotary embedding kernel. // Q is in BSNH format; the kernel writes rotated Q to q_out. @@ -161,7 +228,7 @@ Status PrepareQKV( const int pos_format = data.position_ids != nullptr ? 1 : 2; if constexpr (std::is_same::value) { ORT_RETURN_IF_ERROR((LaunchRotaryEmbeddingKernel( - stream, reinterpret_cast(q_out), reinterpret_cast(data.query), + stream, reinterpret_cast(q_out), reinterpret_cast(q_rope_input), data.position_ids, data.past_seq_lens, reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), batch_size, sequence_length, num_heads, head_size, parameters.rotary_dim, max_cache_length, @@ -169,7 +236,7 @@ Status PrepareQKV( max_threads_per_block, false /* is_input_bnsh_format: Q is BSNH */))); } else if constexpr (std::is_same::value) { ORT_RETURN_IF_ERROR((LaunchRotaryEmbeddingKernel( - stream, reinterpret_cast(q_out), reinterpret_cast(data.query), + stream, reinterpret_cast(q_out), reinterpret_cast(q_rope_input), data.position_ids, data.past_seq_lens, reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), batch_size, sequence_length, num_heads, head_size, parameters.rotary_dim, max_cache_length, @@ -177,7 +244,7 @@ Status PrepareQKV( max_threads_per_block, false /* is_input_bnsh_format: Q is BSNH */))); } } - // If do_rotary is false, Q is used directly from data.query (q_out == nullptr). + // If do_rotary is false and QK-Norm is off, Q is used directly from data.query (q_out == nullptr). // K/V present buffers already point to the shared past — no work needed. } else { ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( @@ -191,6 +258,7 @@ Status PrepareQKV( reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, is_cache_bnsh, parameters.k_quant_type, + data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon, stream, max_threads_per_block))); } @@ -668,6 +736,9 @@ Status ExtremeDecoding( parameters.rotary_interleaved, !past_bsnh, // is_cache_bnsh parameters.k_quant_type, + // XQA does not support the QK-Norm prologue and is disabled when q/k norm weights are present, + // so no normalization is applied on this path. + nullptr, nullptr, parameters.qk_norm_epsilon, stream, device_prop.maxThreadsPerBlock))); @@ -894,6 +965,7 @@ Status DequantizeFlashAttentionFallback( parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH), parameters.k_quant_type, + data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon, stream, device_prop.maxThreadsPerBlock))); // Step 2: Dequantize Entire Cache @@ -978,6 +1050,7 @@ Status FlashAttentionAndQuantizeKV( parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, false, // BSNH for scratch KVQuantizationType::NONE, + data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon, stream, max_threads_per_block))); // 2. Run Float Flash Attention diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 348dc0832d3ba..a38156c7fbffe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -127,7 +127,9 @@ struct GQABufferRequirements { } // Unfused fallback: needs Q buffer for rotary embedding output. - if (req.qkv_buffer_bytes == 0 && (params.do_rotary || params.is_packed_qkv)) { + // QK-Norm also requires a materialized Q buffer to hold the normalized (and optionally rotated) Q, + // even when rotary is disabled and the input is not packed. + if (req.qkv_buffer_bytes == 0 && (params.do_rotary || params.is_packed_qkv || params.use_qk_norm)) { req.qkv_buffer_bytes = elem_size * q_elements; } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 5fb36e094482b..3825f37ce1220 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -61,7 +61,12 @@ __global__ void UnpackRoPEAppend( const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, - const bool per_channel) { + const bool per_channel, + // QK-Norm (per-head Q/K RMSNorm) weights of shape (head_size,), shared across heads. + // nullptr disables normalization for the corresponding head type; V heads are never normalized. + const T* q_norm_weight, + const T* k_norm_weight, + const float qk_norm_epsilon) { using LoadT = float4; constexpr int elements_per_thread = sizeof(LoadT) / sizeof(T); @@ -84,6 +89,9 @@ __global__ void UnpackRoPEAppend( const int sequence_length = gridDim.x; // Number of new tokens in this launch __shared__ T shared_head[MAX_HEAD_SIZE]; + // Per-block reduction buffer for the QK-Norm sum-of-squares. One block handles one (b, s, head), + // and blockDim.x == head_size / elements_per_thread (<= MAX_HEAD_SIZE / elements_per_thread). + __shared__ float s_qk_reduce[MAX_HEAD_SIZE / elements_per_thread]; // Determine Head Type and Offset within the packed hidden dimension [Q, K, V] enum HeadType { QUERY, @@ -139,6 +147,38 @@ __global__ void UnpackRoPEAppend( // Non-interleaved RoPE requires full head visibility to pair channels (h, h + d/2). // We use shared memory as a staging buffer to allow any thread to access its pair. const bool is_qk = (head_type == QUERY || head_type == KEY); + + // 1.5 QK-Norm: per-head RMSNorm applied BEFORE RoPE (Qwen3 / Gemma 2-3 / OLMo2 / SmolLM3). + // Each block processes a single head, so head_type (and thus norm_weight) is uniform across the + // block, which makes the __syncthreads below safe. Q heads use q_norm_weight, K heads use + // k_norm_weight, and V heads are skipped. The weight is shared across heads (indexed by channel). + const T* norm_weight = (head_type == QUERY) ? q_norm_weight : ((head_type == KEY) ? k_norm_weight : nullptr); + if (is_qk && norm_weight != nullptr) { + float partial = 0.0f; + if (valid) { +#pragma unroll + for (int i = 0; i < elements_per_thread; ++i) { + const float f = static_cast(vals[i]); + partial += f * f; + } + } + s_qk_reduce[tid] = partial; + __syncthreads(); + // blockDim.x == head_size / elements_per_thread is small (<= 64). A linear reduction is robust + // for any (possibly non-power-of-two) thread count and avoids tree-reduction edge cases. + float sumsq = 0.0f; + for (int t = 0; t < blockDim.x; ++t) { + sumsq += s_qk_reduce[t]; + } + const float inv_rms = rsqrtf(sumsq / static_cast(head_size) + qk_norm_epsilon); + if (valid) { +#pragma unroll + for (int i = 0; i < elements_per_thread; ++i) { + vals[i] = static_cast(static_cast(vals[i]) * inv_rms * static_cast(norm_weight[h + i])); + } + } + } + if (valid && rotary_dim > 0 && is_qk && !interleaved) { T* shared_ptr = &shared_head[h]; *reinterpret_cast(shared_ptr) = *reinterpret_cast(vals); @@ -276,27 +316,32 @@ Status DispatchUnpackRoPEAppendHeadSize( const int num_heads, const int kv_num_heads, const int head_size, const int d, const int max_seqlen, const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, - const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel) { + const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel, + const T* q_norm_weight, const T* k_norm_weight, const float qk_norm_epsilon) { if (head_size <= 64) { UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else if (head_size <= 128) { UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else if (head_size <= 256) { UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else if (head_size <= 512) { UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size (", head_size, ") exceeds maximum supported MAX_HEAD_SIZE (512)."); } @@ -318,6 +363,7 @@ Status LaunchUnpackRoPEAppend( const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const KVQuantizationType k_quant_type, + const T* q_norm_weight, const T* k_norm_weight, const float qk_norm_epsilon, cudaStream_t stream, const int max_threads_per_block) { static_assert(std::is_same::type>::value); static_assert(std::is_same::type>::value); @@ -365,7 +411,8 @@ Status LaunchUnpackRoPEAppend( return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else if constexpr (std::is_same::value #ifdef USE_FP8_KV_CACHE || std::is_same::value @@ -375,14 +422,16 @@ Status LaunchUnpackRoPEAppend( return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); #ifdef USE_INT4_KV_CACHE } else if constexpr (std::is_same::value) { // INT4 quantization (packed 2 elements per byte) return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); #endif } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported cache type U for GQA quantization."); diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 529eae1494e94..3f4f9e4ad0542 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -104,6 +104,10 @@ class GQAConfig: v_quant_type: str = "NONE" kv_cache_bit_width: int = 0 + # Fused per-head Q/K RMSNorm (QK-Norm) applied before RoPE. Weight shape (head_size,) shared across heads. + has_qk_norm: bool = False + qk_norm_epsilon: float = 1e-6 + # ################################################################################################# # Rotary Embedding Implementations (CPU and CUDA) @@ -205,6 +209,25 @@ def make_head_sink_initializer(head_sink, ort_type, num_heads): return helper.make_tensor(name="head_sink", data_type=ort_type, dims=[num_heads], vals=raw, raw=True) +def make_qk_norm_weights(head_size, device, torch_type, seed=7): + """Generate deterministic per-head Q/K RMSNorm weights of shape (head_size,).""" + gen = torch.Generator(device=device).manual_seed(seed) + q_w = (1.0 + 0.1 * torch.randn(head_size, generator=gen, device=device, dtype=torch.float32)).to(torch_type) + k_w = (1.0 + 0.1 * torch.randn(head_size, generator=gen, device=device, dtype=torch.float32)).to(torch_type) + return q_w, k_w + + +def apply_qk_rmsnorm(x, weight, eps): + """Reference per-head RMSNorm over the last (head_size) dim, computed in float32 then cast back. + + x_norm[c] = x[c] * rsqrt(mean(x^2) + eps) * weight[c] + """ + dtype = x.dtype + xf = x.to(torch.float32) + inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps) + return (xf * inv_rms * weight.to(torch.float32)).to(dtype) + + def create_gqa_node_and_io( config: GQAConfig, ort_type, @@ -264,6 +287,8 @@ def create_gqa_node_and_io( "k_scale" if config.share_kv_scale and config.k_quant_type != "NONE" else ("v_scale" if config.v_quant_type != "NONE" else ""), + "q_norm_weight" if config.has_qk_norm else "", + "k_norm_weight" if config.has_qk_norm else "", ] # Remove trailing empty strings @@ -280,6 +305,8 @@ def create_gqa_node_and_io( else {} ) + qk_norm_attributes = {"qk_norm_epsilon": config.qk_norm_epsilon} if config.has_qk_norm else {} + node = helper.make_node( op_type="GroupQueryAttention", inputs=inputs, @@ -294,6 +321,7 @@ def create_gqa_node_and_io( smooth_softmax=1 if config.use_smooth_softmax else 0, qk_output=output_qk, **quantization_attributes, + **qk_norm_attributes, domain="com.microsoft", ) @@ -372,6 +400,10 @@ def create_gqa_node_and_io( else: graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) + if config.has_qk_norm: + graph_input.append(helper.make_tensor_value_info("q_norm_weight", ort_type, [config.head_size])) + graph_input.append(helper.make_tensor_value_info("k_norm_weight", ort_type, [config.head_size])) + # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] if config.kv_cache_type == "int4": @@ -467,6 +499,8 @@ def gqa_prompt_func( device, share_buffer=True, ort_type=TensorProto.FLOAT16, + q_norm_weight=None, + k_norm_weight=None, ): if not config.kv_cache_type: config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" @@ -537,6 +571,10 @@ def gqa_prompt_func( if config.has_head_sink and head_sink is not None: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) + if config.has_qk_norm and q_norm_weight is not None: + bind_tensor(io_binding, "q_norm_weight", q_norm_weight, device, ort_type) + bind_tensor(io_binding, "k_norm_weight", k_norm_weight, device, ort_type) + # 6. Quantization scales if k_scale is not None: k_scale_ort_type = TensorProto.FLOAT @@ -627,6 +665,8 @@ def gqa_past_func( device, share_buffer=True, ort_type=TensorProto.FLOAT16, + q_norm_weight=None, + k_norm_weight=None, ): if not config.kv_cache_type: config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" @@ -702,6 +742,10 @@ def gqa_past_func( if config.has_head_sink and head_sink is not None and not head_sink_as_initializer: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) + if config.has_qk_norm and q_norm_weight is not None: + bind_tensor(io_binding, "q_norm_weight", q_norm_weight, device, ort_type) + bind_tensor(io_binding, "k_norm_weight", k_norm_weight, device, ort_type) + # 6. Quantization if k_scale is not None: k_scale_ort_type = TensorProto.FLOAT @@ -1000,13 +1044,18 @@ def parity_check_gqa_prompt( rotary_seqlens = torch.zeros(config.batch_size, device=device, dtype=torch.long) cos, sin, q_ro, k_ro = None, None, q, new_k + q_norm_weight, k_norm_weight = None, None + if config.has_qk_norm: + q_norm_weight, k_norm_weight = make_qk_norm_weights(config.head_size, device, torch_type) + q_ro = apply_qk_rmsnorm(q, q_norm_weight, config.qk_norm_epsilon) + k_ro = apply_qk_rmsnorm(new_k, k_norm_weight, config.qk_norm_epsilon) if config.rotary: rotary_dim = math.floor(config.head_size / 16) * 16 angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi cos = torch.cos(angle).to(dtype=torch_type) sin = torch.sin(angle).to(dtype=torch_type) - q_ro = apply_rotary_embedding(q.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) - k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) + q_ro = apply_rotary_embedding(q_ro.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) + k_ro = apply_rotary_embedding(k_ro.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) position_ids, attention_bias = None, None if config.has_position_ids: @@ -1096,6 +1145,8 @@ def parity_check_gqa_prompt( device=device, share_buffer=config.share_buffer, ort_type=ort_type, + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() @@ -1318,13 +1369,18 @@ def parity_check_gqa_past( v_cache_ref = v_ref_dequant.clone().transpose(1, 2) cos, sin, q_ro, k_ro = None, None, q, new_k + q_norm_weight, k_norm_weight = None, None + if config.has_qk_norm: + q_norm_weight, k_norm_weight = make_qk_norm_weights(config.head_size, device, torch_type) + q_ro = apply_qk_rmsnorm(q, q_norm_weight, config.qk_norm_epsilon) + k_ro = apply_qk_rmsnorm(new_k, k_norm_weight, config.qk_norm_epsilon) if config.rotary: rotary_dim = math.floor(config.head_size / 16) * 16 angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi cos = torch.cos(angle).to(dtype=torch_type) sin = torch.sin(angle).to(dtype=torch_type) - q_ro = apply_rotary_embedding(q.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) - k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) + q_ro = apply_rotary_embedding(q_ro.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) + k_ro = apply_rotary_embedding(k_ro.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) position_ids, attention_bias = None, None total_seq_len = config.past_kv_sequence_length + config.q_sequence_length @@ -1426,6 +1482,8 @@ def parity_check_gqa_past( device=device, share_buffer=config.share_buffer, ort_type=ort_type, + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() @@ -2105,6 +2163,124 @@ def test_gqa_past_flash_attention_bf16(self, name, config): ) +def gqa_qk_norm_test_cases(is_past: bool): + """Configs exercising the fused per-head Q/K RMSNorm (QK-Norm) prologue before RoPE.""" + head_sizes = [64, 128] + head_groups = [(8, 2), (4, 4)] + rotary_opts = [(False, False), (True, False), (True, True)] + packed_opts = [False, True] + idx = 0 + for h in head_sizes: + for n, n2 in head_groups: + for rotary, interleaved in rotary_opts: + if rotary and h % 16 != 0: + continue + packed = packed_opts[idx % len(packed_opts)] + idx += 1 + if is_past: + b, s, s2 = 2, 1, 127 + config = GQAConfig( + batch_size=b, + q_sequence_length=s, + kv_sequence_length=s, + past_kv_sequence_length=s2, + buffer_sequence_length=s + s2 + 8, + num_heads=n, + kv_num_heads=n2, + head_size=h, + rotary=rotary, + rotary_interleaved=interleaved, + packed=packed, + share_buffer=True, + has_qk_norm=True, + ) + else: + b, s = 2, 64 + config = GQAConfig( + batch_size=b, + q_sequence_length=s, + kv_sequence_length=s, + buffer_sequence_length=s + 8, + num_heads=n, + kv_num_heads=n2, + head_size=h, + rotary=rotary, + rotary_interleaved=interleaved, + packed=packed, + share_buffer=True, + has_qk_norm=True, + ) + name = f"{'past' if is_past else 'prompt'}_b{b}_nh{n}_{n2}_h{h}_rot{rotary}{interleaved}_pkd{packed}" + yield name, config + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestGQAQKNorm(unittest.TestCase): + def tearDown(self): + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + + @parameterized.expand(gqa_qk_norm_test_cases(is_past=False)) + def test_gqa_qk_norm_prompt(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @parameterized.expand(gqa_qk_norm_test_cases(is_past=True)) + def test_gqa_qk_norm_past(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @parameterized.expand(gqa_qk_norm_test_cases(is_past=True)) + def test_gqa_qk_norm_past_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + + config.kv_cache_type = "bfloat16" + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") @unittest.skipIf(not has_quantized_kv_cache(), "Quantized KV Cache is not available, skipping tests.") class TestFlashGQABF16QuantizedKV(unittest.TestCase): From a948c3ebacba9e46e8de23bccf669790730a4195 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 20:34:43 -0700 Subject: [PATCH 3/9] upate profiler helper --- .../python/transformers/gqa_test_helper.py | 29 +++++ .../test/python/transformers/profile_gqa.py | 15 ++- .../test/python/transformers/profile_gqa.sh | 113 +++++++++++++++--- 3 files changed, 141 insertions(+), 16 deletions(-) diff --git a/onnxruntime/test/python/transformers/gqa_test_helper.py b/onnxruntime/test/python/transformers/gqa_test_helper.py index d3dd86ea9bbc6..cc50c763eb087 100644 --- a/onnxruntime/test/python/transformers/gqa_test_helper.py +++ b/onnxruntime/test/python/transformers/gqa_test_helper.py @@ -311,6 +311,8 @@ def __init__( kv_cache_type: str = "float16", share_kv_scale: bool = False, has_head_sink: bool = False, + has_qk_norm: bool = False, + qk_norm_epsilon: float = 1e-6, ): super().__init__( "GroupQueryAttention", @@ -343,6 +345,8 @@ def __init__( self.v_quant_type = v_quant_type self.share_kv_scale = share_kv_scale self.has_head_sink = has_head_sink + self.has_qk_norm = has_qk_norm + self.qk_norm_epsilon = qk_norm_epsilon # Determine bit width from cache type if applicable if kv_cache_type == "int4": self.kv_cache_bit_width = 4 @@ -363,6 +367,9 @@ def shape_dict(self): ) if self.has_head_sink: shapes["head_sink"] = (self.num_heads,) + if self.has_qk_norm: + shapes["q_norm_weight"] = (self.head_size,) + shapes["k_norm_weight"] = (self.head_size,) # Note: We don't adjust shapes for int4 here because the parent's random_inputs # creates float tensors first, then quantization will pack them return shapes @@ -378,6 +385,15 @@ def random_inputs(self): if self.has_head_sink: feeds["head_sink"] = torch.rand((self.num_heads,), device=self.device, dtype=self.dtype) + if self.has_qk_norm: + generator = torch.Generator(device=self.device).manual_seed(7) + feeds["q_norm_weight"] = (1.0 + 0.1 * torch.randn( + self.head_size, generator=generator, device=self.device, dtype=torch.float32 + )).to(self.dtype) + feeds["k_norm_weight"] = (1.0 + 0.1 * torch.randn( + self.head_size, generator=generator, device=self.device, dtype=torch.float32 + )).to(self.dtype) + # Generate quantized cache and scales if quantization is enabled if self.k_quant_type != "NONE": # Compute scales from the generated float cache @@ -432,6 +448,8 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): "head_sink" if config.has_head_sink else "", "k_scale" if config.k_quant_type != "NONE" else "", "v_scale" if config.v_quant_type != "NONE" else "", + "q_norm_weight" if config.has_qk_norm else "", + "k_norm_weight" if config.has_qk_norm else "", ] # Remove trailing empty strings while node_inputs and node_inputs[-1] == "": @@ -449,6 +467,9 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): "domain": "com.microsoft", } + if config.has_qk_norm: + node_attrs["qk_norm_epsilon"] = config.qk_norm_epsilon + # Add quantization attributes if enabled if config.k_quant_type != "NONE": node_attrs["k_quant_type"] = config.k_quant_type @@ -521,6 +542,14 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): if config.has_head_sink: graph_input.append(helper.make_tensor_value_info("head_sink", float_type, list(shape_dict["head_sink"]))) + if config.has_qk_norm: + graph_input.extend( + [ + helper.make_tensor_value_info("q_norm_weight", float_type, list(shape_dict["q_norm_weight"])), + helper.make_tensor_value_info("k_norm_weight", float_type, list(shape_dict["k_norm_weight"])), + ] + ) + # Add scale inputs for quantization # Shape depends on quantization type: # - PER_TENSOR: [1] diff --git a/onnxruntime/test/python/transformers/profile_gqa.py b/onnxruntime/test/python/transformers/profile_gqa.py index 49ce26b5126b8..394e2d6689cc9 100644 --- a/onnxruntime/test/python/transformers/profile_gqa.py +++ b/onnxruntime/test/python/transformers/profile_gqa.py @@ -73,6 +73,8 @@ def create_gqa_config( has_head_sink: bool = False, device: str = "cuda", share_kv_scale: bool = False, + has_qk_norm: bool = False, + qk_norm_epsilon: float = 1e-6, ) -> GroupQueryAttentionConfig: """Create a GQA config based on the mode.""" if mode == "fp16": @@ -118,6 +120,8 @@ def create_gqa_config( v_quant_type=v_quant_type, kv_cache_type=kv_cache_type, share_kv_scale=share_kv_scale, + has_qk_norm=has_qk_norm, + qk_norm_epsilon=qk_norm_epsilon, ) return config @@ -158,6 +162,7 @@ def run_comparison(args): print(f"Config: batch={args.batch_size}, seq_len={args.sequence_length}, past_seq={args.past_sequence_length}") print(f" num_heads={args.num_heads}, kv_heads={args.kv_num_heads}, head_size={args.head_size}") print(f" packed_qkv={args.is_packed_qkv}, rotary={not args.no_rotary}, head_sink={args.head_sink}") + print(f" qk_norm={args.qk_norm}, qk_norm_epsilon={args.qk_norm_epsilon}") print(f" warmup={args.warmup}, repeat={args.repeat}") print(f"{'=' * 70}\n") @@ -179,10 +184,14 @@ def run_comparison(args): do_rotary=not args.no_rotary, has_head_sink=args.head_sink, share_kv_scale=args.share_kv_scale, + has_qk_norm=args.qk_norm, + qk_norm_epsilon=args.qk_norm_epsilon, ) - avg_ms = benchmark_gqa(config, warmup=args.warmup, repeat=args.repeat, mode=mode) + range_name = f"{mode}_qknorm" if args.qk_norm else mode + avg_ms = benchmark_gqa(config, warmup=args.warmup, repeat=args.repeat, mode=range_name) results[mode] = avg_ms - print(f" {mode.upper():6s} (dtype={config.dtype}): {avg_ms:.4f} ms") + suffix = "+QKNorm" if args.qk_norm else "" + print(f" {mode.upper() + suffix:13s} (dtype={config.dtype}): {avg_ms:.4f} ms") # Print comparison if we have baseline baseline = "fp16" if "fp16" in results else ("bf16" if "bf16" in results else None) @@ -216,6 +225,8 @@ def main(): parser.add_argument("--repeat", type=int, default=100, help="Benchmark iterations") parser.add_argument("--is-packed-qkv", action="store_true", help="Use packed QKV") parser.add_argument("--head-sink", action="store_true", help="Add a head_sink input") + parser.add_argument("--qk-norm", action="store_true", help="Add q_norm_weight/k_norm_weight inputs") + parser.add_argument("--qk-norm-epsilon", type=float, default=1e-6, help="QK-Norm epsilon") parser.add_argument("--no-rotary", action="store_true", help="Disable rotary embeddings") parser.add_argument("--share-kv-scale", action="store_true", help="Share KV scale tensor for XQA") diff --git a/onnxruntime/test/python/transformers/profile_gqa.sh b/onnxruntime/test/python/transformers/profile_gqa.sh index 3fe10eaf35ad2..0dfb820fecbef 100644 --- a/onnxruntime/test/python/transformers/profile_gqa.sh +++ b/onnxruntime/test/python/transformers/profile_gqa.sh @@ -10,7 +10,9 @@ # ./profile_gqa.sh --all # ./profile_gqa.sh --fp16 --int8 # ./profile_gqa.sh --fp16 --past-sequence-length 8192 --local-window-size 128 -# ./profile_gqa.sh --bf16 --num-heads 64 --kv-num-heads 8 +# ./profile_gqa.sh --bf16 --num-heads 64 --kv-num-heads 8 --head-size 128 +# ./profile_gqa.sh --fp16 --compare-qk-norm --past-sequence-length 2048 +# NSYS=~/cuda13.0/bin/nsys ./profile_gqa.sh --fp16 --qk-norm # CUDA_VISIBLE_DEVICES=1 PYTHON=python3 ./profile_gqa.sh --int4 # @@ -19,6 +21,7 @@ set -o pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PY="${PYTHON:-python}" +NSYS="${NSYS:-}" # Parse arguments RUN_FP16=false @@ -31,11 +34,20 @@ RUN_BF16=false BATCH_SIZE="" SEQUENCE_LENGTH="" PAST_SEQUENCE_LENGTH="" +MAX_SEQUENCE_LENGTH="" PACKED_QKV="" SHARE_KV_SCALE="" NUM_HEADS="" KV_NUM_HEADS="" +HEAD_SIZE="" LOCAL_WINDOW_SIZE="" +HEAD_SINK="" +NO_ROTARY="" +QK_NORM=false +COMPARE_QK_NORM=false +QK_NORM_EPSILON="" +WARMUP=5 +REPEAT=100 while [[ "$#" -gt 0 ]]; do case $1 in --fp16) @@ -66,6 +78,29 @@ while [[ "$#" -gt 0 ]]; do RUN_BF16=true echo "==== 🚀 All runs enabled ====" ;; + --qk-norm) + QK_NORM=true + echo "==== QK-Norm enabled ====" + ;; + --compare-qk-norm) + COMPARE_QK_NORM=true + echo "==== Compare baseline vs QK-Norm enabled ====" + ;; + --qk-norm-epsilon) + QK_NORM_EPSILON="--qk-norm-epsilon $2" + echo "==== QK-Norm epsilon: $2 ====" + shift + ;; + --warmup) + WARMUP="$2" + echo "==== Warmup iterations: $2 ====" + shift + ;; + --repeat) + REPEAT="$2" + echo "==== Repeat iterations: $2 ====" + shift + ;; -b|--batch-size) BATCH_SIZE="--batch-size $2" echo "==== Batch size: $2 ====" @@ -81,6 +116,11 @@ while [[ "$#" -gt 0 ]]; do echo "==== Past sequence length: $2 ====" shift ;; + --max-sequence-length) + MAX_SEQUENCE_LENGTH="--max-sequence-length $2" + echo "==== Max sequence length: $2 ====" + shift + ;; --qkv) PACKED_QKV="--is-packed-qkv" echo "==== Packed QKV enabled ====" @@ -99,11 +139,24 @@ while [[ "$#" -gt 0 ]]; do echo "==== KV Num Heads: $2 ====" shift ;; + --head-size) + HEAD_SIZE="--head-size $2" + echo "==== Head size: $2 ====" + shift + ;; -w|--local-window-size) LOCAL_WINDOW_SIZE="--local-window-size $2" echo "==== Local window size: $2 ====" shift ;; + --head-sink) + HEAD_SINK="--head-sink" + echo "==== Head sink enabled ====" + ;; + --no-rotary) + NO_ROTARY="--no-rotary" + echo "==== Rotary disabled ====" + ;; *) echo "Unknown option: $1" exit 1 @@ -113,10 +166,19 @@ while [[ "$#" -gt 0 ]]; do done # Build extra args string -EXTRA_ARGS="${BATCH_SIZE} ${SEQUENCE_LENGTH} ${PAST_SEQUENCE_LENGTH} ${PACKED_QKV} ${SHARE_KV_SCALE} ${NUM_HEADS} ${KV_NUM_HEADS} ${LOCAL_WINDOW_SIZE}" +EXTRA_ARGS="${BATCH_SIZE} ${SEQUENCE_LENGTH} ${PAST_SEQUENCE_LENGTH} ${MAX_SEQUENCE_LENGTH} ${PACKED_QKV} ${SHARE_KV_SCALE} ${NUM_HEADS} ${KV_NUM_HEADS} ${HEAD_SIZE} ${LOCAL_WINDOW_SIZE} ${HEAD_SINK} ${NO_ROTARY} ${QK_NORM_EPSILON}" -if ! command -v nsys >/dev/null; then +if [[ -z "${NSYS}" ]]; then + if command -v nsys >/dev/null; then + NSYS="$(command -v nsys)" + elif [[ -x "${HOME}/cuda13.0/bin/nsys" ]]; then + NSYS="${HOME}/cuda13.0/bin/nsys" + fi +fi + +if [[ -z "${NSYS}" || ! -x "${NSYS}" ]]; then echo "Error: nsys not found. Install NVIDIA Nsight Systems or add it to PATH." + echo " Or set NSYS=/path/to/nsys (for example NSYS=~/cuda13.0/bin/nsys)." exit 1 fi @@ -129,12 +191,13 @@ else echo " Falling back to --skip-first to exclude warmup-like first calls." fi -# profile_one [env_var=value ...] +# profile_one [env_var=value ...] profile_one() { local mode="$1" local tag="$2" local base="$3" - shift 3 + local cli_extra="$4" + shift 4 local env_args=() local e @@ -145,34 +208,56 @@ profile_one() { echo "" echo "---- Profiling ${mode} ----" rm -f "${base}.nsys-rep" "${base}.sqlite" - nsys profile -t cuda,nvtx --force-overwrite true "${env_args[@]}" -o "${base}" --export=sqlite \ - "${PY}" "${SCRIPT_DIR}/profile_gqa.py" --mode "${mode}" --warmup 5 --repeat 100 ${EXTRA_ARGS} + "${NSYS}" profile -t cuda,nvtx --force-overwrite true "${env_args[@]}" -o "${base}" --export=sqlite \ + "${PY}" "${SCRIPT_DIR}/profile_gqa.py" --mode "${mode}" --warmup "${WARMUP}" --repeat "${REPEAT}" ${EXTRA_ARGS} ${cli_extra} echo "" echo "---- Kernel results (${mode}) ----" + local range_name="benchmark_${mode}" + if [[ "${cli_extra}" == *"--qk-norm"* ]]; then + range_name="benchmark_${mode}_qknorm" + fi if [[ "${HAVE_NVTX}" -eq 1 ]]; then - "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --nvtx-range "benchmark_${mode}" --tag "${tag}" + "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --nvtx-range "${range_name}" --tag "${tag}" \ + --pattern "%onnxruntime%" --pattern "%cudnn%" + else + "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --skip-first "${WARMUP}" --tag "${tag}" \ + --pattern "%onnxruntime%" --pattern "%cudnn%" + fi +} + +profile_mode() { + local mode="$1" + local tag="$2" + local base="$3" + shift 3 + + if [[ "${COMPARE_QK_NORM}" == true ]]; then + profile_one "${mode}" "${tag}" "${base}" "" "$@" + profile_one "${mode}" "${tag}QK" "${base}_qknorm" "--qk-norm" "$@" + elif [[ "${QK_NORM}" == true ]]; then + profile_one "${mode}" "${tag}QK" "${base}_qknorm" "--qk-norm" "$@" else - "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --skip-first 5 --tag "${tag}" + profile_one "${mode}" "${tag}" "${base}" "" "$@" fi } if [ "$RUN_FP16" = true ]; then - profile_one fp16 Fp16 gqa_fp16 + profile_mode fp16 Fp16 gqa_fp16 fi if [ "$RUN_BF16" = true ]; then - profile_one bf16 Bf16 gqa_bf16 + profile_mode bf16 Bf16 gqa_bf16 fi if [ "$RUN_INT8" = true ]; then - profile_one int8 Int8 gqa_int8 ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=0 + profile_mode int8 Int8 gqa_int8 ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=0 fi if [ "$RUN_INT8_QUANT" = true ]; then - profile_one int8 Int8Q gqa_int8_quant ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=1 + profile_mode int8 Int8Q gqa_int8_quant ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=1 fi if [ "$RUN_INT4" = true ]; then - profile_one int4 Int4 gqa_int4 + profile_mode int4 Int4 gqa_int4 fi From ba42f9cedacf0084f1635d7b7220f111b636b671 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 22:19:06 -0700 Subject: [PATCH 4/9] Enable XQA for GQA QK-Norm decode --- docs/ContribOperators.md | 2 +- docs/contrib_ops/cuda/gqa.md | 20 +++--- .../cpu/bert/group_query_attention.cc | 4 +- .../cuda/bert/group_query_attention.cc | 10 +-- .../cuda/bert/group_query_attention_impl.cu | 4 +- .../cuda/bert/group_query_attention_impl.h | 6 +- .../core/graph/contrib_ops/bert_defs.cc | 2 +- .../core/optimizer/graph_transformer_utils.cc | 3 +- .../group_query_attention_pre_norm_fusion.h | 8 +-- .../group_query_attention_op_test.cc | 2 +- ...up_query_attention_pre_norm_fusion_test.cc | 32 +++++++--- .../python/transformers/gqa_test_helper.py | 12 ++-- .../test/python/transformers/test_gqa.py | 62 +++++++++++++++++++ 13 files changed, 125 insertions(+), 42 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 38d101786b41a..79296736faf97 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2675,7 +2675,7 @@ This version of the operator has been available since version 1 of the 'com.micr
v_scale (optional) : T_KV_SCALE
Scale tensor for past_value.
q_norm_weight (optional) : T
-
Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the native WebGPU execution provider only; JSEP WebGPU/JS and other EPs must reject the node when this input is set.
+
Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the CUDA and native WebGPU execution providers; JSEP WebGPU/JS and other EPs must reject the node when this input is set.
k_norm_weight (optional) : T
Optional 1D tensor of shape (head_size). See q_norm_weight. Must be provided together with q_norm_weight.
diff --git a/docs/contrib_ops/cuda/gqa.md b/docs/contrib_ops/cuda/gqa.md index 9aa9ab1cd464c..3ca27ae388498 100644 --- a/docs/contrib_ops/cuda/gqa.md +++ b/docs/contrib_ops/cuda/gqa.md @@ -141,8 +141,9 @@ numerical stability and the result is cast back to the operator type `T`. - The normalization is fused into the `PrepareQKV` prologue (`UnpackRoPEAppend` for the new-KV path, or a standalone per-head RMSNorm kernel for the shared-buffer Q-only decode case), so it composes with packed QKV, RoPE, KV-head expansion, and the quantized KV cache. -- Because the fused decode kernels (XQA and Flash-Decoding) do their own RoPE/append internally and - bypass `PrepareQKV`, they are disabled when QK-Norm is present (see §6). +- Because the Flash-Decoding fast path does its own RoPE/append internally and bypasses `PrepareQKV`, + it is disabled when QK-Norm is present (see §6). The non-quantized XQA decode path can still run + with QK-Norm: CUDA normalizes Q/K in the `UnpackRoPEAppend` preprocess before launching XQA. ## 4. KV Cache and Quantization @@ -223,10 +224,12 @@ order and the first eligible backend wins: The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is enabled (see §10). -> **QK-Norm interaction.** When `q_norm_weight` / `k_norm_weight` are present (see §3), the fused -> decode paths that bypass `PrepareQKV` — **XQA** and the **Flash-Decoding fast path** — are -> disabled so the QK-Norm prologue always runs. Such nodes therefore route to Flash Attention -> (or cuDNN SDPA / MEA / Unfused), all of which consume the normalized Q/K produced by `PrepareQKV`. +> **QK-Norm interaction.** When `q_norm_weight` / `k_norm_weight` are present (see §3), the +> Flash-Decoding fast path is disabled so the QK-Norm prologue always runs. Non-quantized XQA decode +> remains eligible for supported shapes: the `UnpackRoPEAppend` preprocess normalizes Q/K, applies +> RoPE, appends K/V, and then XQA consumes the normalized Q and cache. Quantized-cache QK-Norm decode +> still falls back to Flash Attention (or cuDNN SDPA / MEA / Unfused) until normalized-K scale +> handling is validated for XQA. ### 6.1 XQA @@ -451,9 +454,8 @@ popular LLMs. Listed roughly by impact. 1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** *Implemented.* The CUDA kernel applies the fused per-head RMSNorm to Q and K before RoPE when `q_norm_weight` / `k_norm_weight` are provided - (see §3), matching **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. Remaining limitation: QK-Norm - disables the fused decode kernels (XQA and Flash-Decoding) and routes through Flash / cuDNN / - MEA instead, so QK-Norm decode does not yet get the XQA fast path. + (see §3), matching **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. Remaining limitation: QK-Norm + disables Flash-Decoding, and quantized-cache QK-Norm does not yet get the XQA fast path. 2. **Sliding-window + attention-sink on the fused decode path.** XQA requires global attention (`local_window_size == -1`), so **GPT-OSS / Mistral / Gemma 2** layers that combine a sliding window with a `head_sink` fall back to Flash / Flash-Decoding instead of XQA. Unifying sliding diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 61ae474703213..e36bdb2de263a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -84,7 +84,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { "kv_cache_bit_width must be 0 when quantization is disabled, got ", kv_cache_bit_width_); } - // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only + // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the CUDA/WebGPU // GroupQueryAttentionPreNormFusion optimizer pass. The CPU kernel does not implement // the fused per-head Q/K RMS normalization prologue, so reject the node if either input // is present rather than silently dropping the normalization. @@ -93,7 +93,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { return ORT_MAKE_STATUS( ONNXRUNTIME, INVALID_ARGUMENT, "GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are not supported. " - "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP."); + "The per-head Q/K RMS normalization prologue is implemented only on the CUDA and WebGPU EPs."); } GroupQueryAttentionParameters parameters = {}; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index c9dbe1c73d29a..23ebad6c22589 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -427,11 +427,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // An explicit ORT_ENABLE_XQA=0 (xqa_force_disabled_) still wins and turns XQA off entirely. // The dtype guard mirrors enable_xqa_ (XQA only supports fp16/bf16); ineligible cases fall back below. constexpr bool kIsFp16OrBf16 = std::is_same_v || std::is_same_v; - // The XQA decode kernel performs its own RoPE/append and bypasses PrepareQKV, so it cannot apply - // the fused QK-Norm prologue. Disable XQA when q/k norm weights are present and fall back to the - // Flash/cuDNN/MEA paths (all of which route through PrepareQKV's normalization). + // QK-Norm can use XQA for the non-quantized KV-cache path: ExtremeDecoding runs the same + // UnpackRoPEAppend preprocess before XQA, so Q/K can be normalized before the XQA kernel consumes + // Q and the appended cache. Keep quantized QK-Norm off the XQA route until scale correctness is + // validated for normalized K before quantized-cache append. + const bool xqa_qk_norm_supported = !parameters.use_qk_norm || !is_inputs_quantized; const bool xqa_enabled_for_run = - !xqa_force_disabled_ && !parameters.use_qk_norm && + !xqa_force_disabled_ && xqa_qk_norm_supported && (enable_xqa_ || (kIsFp16OrBf16 && use_xqa_attention_sinks)); if (xqa_enabled_for_run && (device_prop.major >= 8) && diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 13b013b9069b2..fda4ae66417fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -736,9 +736,7 @@ Status ExtremeDecoding( parameters.rotary_interleaved, !past_bsnh, // is_cache_bnsh parameters.k_quant_type, - // XQA does not support the QK-Norm prologue and is disabled when q/k norm weights are present, - // so no normalization is applied on this path. - nullptr, nullptr, parameters.qk_norm_epsilon, + data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon, stream, device_prop.maxThreadsPerBlock))); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index a38156c7fbffe..b1f979fef4169 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -76,9 +76,9 @@ struct GQABufferRequirements { const size_t v_elements = k_elements; if (use_xqa) { - if (params.do_rotary || params.is_packed_qkv) { - // XQA need scratch for rotated/unpacked Q. - // RoPE K is written directly to cache by the fused kernel. + if (params.do_rotary || params.is_packed_qkv || params.use_qk_norm) { + // XQA needs scratch for rotated/unpacked/normalized Q. + // RoPE/QK-Norm K is written directly to cache by the fused preprocess kernel. req.qkv_buffer_bytes = elem_size * q_elements; } return req; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 5d537fa59bfab..896774fb5c8d8 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1324,7 +1324,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a " "per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap " "their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion " - "folds that pattern into this input. Currently honored by the native WebGPU execution provider only; " + "folds that pattern into this input. Currently honored by the CUDA and native WebGPU execution providers; " "JSEP WebGPU/JS and other EPs must reject the node when this input is set.", "T", OpSchema::Optional) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index a998eeacda734..56fc3ca7a9855 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -453,7 +453,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique( - InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); + InlinedHashSet{onnxruntime::kCudaExecutionProvider, + onnxruntime::kWebGpuExecutionProvider})); bool has_matmul_nbits_mlp_kernel = false; bool has_matmul_nbits_qkv_kernel = false; if (execution_providers != nullptr) { diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h index b69199bb5324d..f3cb131a0b8fb 100644 --- a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h +++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h @@ -27,12 +27,12 @@ on inputs 0 (query) and 1 (key) of an unfused GroupQueryAttention node: When matched, the six Reshape/SLN nodes are removed and the pre-norm Q and K projections feed GQA directly. The kernel is responsible for applying the RMS -norm internally (currently the WebGPU EP). +norm internally (currently the CUDA and WebGPU EPs). Only fires for execution providers passed in `compatible_execution_providers`. -At present this fusion is registered for the WebGPU EP only, because the -in-kernel norm path is currently implemented there. The CPU, CUDA, and JSEP -GroupQueryAttention kernels reject q_norm_weight / k_norm_weight inputs. +At present this fusion is registered for the CUDA and WebGPU EPs, because those +kernels implement the in-kernel norm path. CPU and JSEP GroupQueryAttention +kernels reject q_norm_weight / k_norm_weight inputs. */ class GroupQueryAttentionPreNormFusion : public GraphTransformer { public: diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 821f43971848a..f72cac1aa66b8 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -492,7 +492,7 @@ static void ApplyPerHeadRmsNormBSNH(std::vector& data, } // Runs GroupQueryAttention with do_rotary=1 and optional q/k norm weights. -// If q_norm_weight/k_norm_weight are provided, this exercises the WebGPU-only +// If q_norm_weight/k_norm_weight are provided, this exercises the EP-specific // q/k norm input contract. CPU callers should pass nullptr for both and feed // pre-normalized Q/K values instead. static std::vector RunGQARotaryWithOptionalQKNorm( diff --git a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc index ef0f40a51c838..d5e2a8deb877d 100644 --- a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc +++ b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc @@ -246,16 +246,34 @@ Status CheckUnfusedGraph(const Graph& graph) { } // namespace -// Helper: build the transformer registered for the WebGPU EP only (matches production). +// Helper: build the transformer for the original WebGPU-only tests. std::unique_ptr MakeWebGpuTransformer() { return std::make_unique( InlinedHashSet{kWebGpuExecutionProvider}); } +// Helper: build the production-compatible transformer. +std::unique_ptr MakeCudaWebGpuTransformer() { + return std::make_unique( + InlinedHashSet{kCudaExecutionProvider, kWebGpuExecutionProvider}); +} + TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPattern) { auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); }; ASSERT_STATUS_OK(TestGraphTransformer( - build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesCudaAssignedQwenPattern) { + auto build = [](ModelTestBuilder& builder) { + BuildQwenQkPostNormPattern(builder, BuildOptions{}); + for (auto& node : builder.graph_.Nodes()) { + const_cast(node).SetExecutionProviderType(kCudaExecutionProvider); + } + }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(), TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph)); } @@ -272,7 +290,7 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPatter // reject this *CurrentOpset model at load; allow the unreleased opset. Remove once opset 27 ships. // Tracked by #28966; this is the WebGPU attention-fusion path that surfaced #28969. ASSERT_STATUS_OK(TestGraphTransformer( - build, /*opset_version=*/current_opset, *logger_, MakeWebGpuTransformer(), + build, /*opset_version=*/current_opset, *logger_, MakeCudaWebGpuTransformer(), TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph, ModelOptions{kAllowReleasedOpsetsOnly, /*strict_shape_type_inference*/ false})); } @@ -363,7 +381,7 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNon1DNor } TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) { - // Build the pattern but assign all nodes to CPU EP. The fusion is gated to WebGPU only, + // Build the pattern but assign all nodes to CPU EP. The fusion is gated to CUDA/WebGPU only, // so the graph must remain unfused. auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); @@ -372,13 +390,13 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) { } }; ASSERT_STATUS_OK(TestGraphTransformer( - build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(), TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); } TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsJsEp) { // JSEP does not implement the fused per-head Q/K RMSNorm prologue, so the optimizer - // (which we now register for WebGPU only) must leave JSEP-assigned graphs alone. + // (which we now register for CUDA/WebGPU only) must leave JSEP-assigned graphs alone. auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); for (auto& node : builder.graph_.Nodes()) { @@ -386,7 +404,7 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsJsEp) { } }; ASSERT_STATUS_OK(TestGraphTransformer( - build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(), TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); } diff --git a/onnxruntime/test/python/transformers/gqa_test_helper.py b/onnxruntime/test/python/transformers/gqa_test_helper.py index cc50c763eb087..73fe45ac8fc7e 100644 --- a/onnxruntime/test/python/transformers/gqa_test_helper.py +++ b/onnxruntime/test/python/transformers/gqa_test_helper.py @@ -387,12 +387,12 @@ def random_inputs(self): if self.has_qk_norm: generator = torch.Generator(device=self.device).manual_seed(7) - feeds["q_norm_weight"] = (1.0 + 0.1 * torch.randn( - self.head_size, generator=generator, device=self.device, dtype=torch.float32 - )).to(self.dtype) - feeds["k_norm_weight"] = (1.0 + 0.1 * torch.randn( - self.head_size, generator=generator, device=self.device, dtype=torch.float32 - )).to(self.dtype) + feeds["q_norm_weight"] = ( + 1.0 + 0.1 * torch.randn(self.head_size, generator=generator, device=self.device, dtype=torch.float32) + ).to(self.dtype) + feeds["k_norm_weight"] = ( + 1.0 + 0.1 * torch.randn(self.head_size, generator=generator, device=self.device, dtype=torch.float32) + ).to(self.dtype) # Generate quantized cache and scales if quantization is enabled if self.k_quant_type != "NONE": diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 3f4f9e4ad0542..2a8de08daaa0d 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -2258,6 +2258,68 @@ def test_gqa_qk_norm_past(self, name, config): atol=atol["fp16"], ) + def test_gqa_qk_norm_past_xqa(self): + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=127, + buffer_sequence_length=136, + num_heads=32, + kv_num_heads=8, + head_size=128, + rotary=True, + rotary_interleaved=False, + packed=False, + share_buffer=True, + has_qk_norm=True, + ) + + with scoped_env_var("ORT_ENABLE_XQA", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_gqa_qk_norm_past_xqa_bf16(self): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=127, + buffer_sequence_length=136, + num_heads=32, + kv_num_heads=8, + head_size=128, + rotary=True, + rotary_interleaved=False, + packed=False, + share_buffer=True, + has_qk_norm=True, + ) + config.kv_cache_type = "bfloat16" + + with scoped_env_var("ORT_ENABLE_XQA", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + @parameterized.expand(gqa_qk_norm_test_cases(is_past=True)) def test_gqa_qk_norm_past_bf16(self, name, config): if not torch.cuda.is_bf16_supported(): From db4b25496a06a5ab9b61844cee6da07fc720955c Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 20 Jun 2026 01:10:49 -0700 Subject: [PATCH 5/9] fix test --- .../contrib_ops/group_query_attention_op_test.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index f72cac1aa66b8..b6ce51ad0fd3a 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -93,7 +93,7 @@ static void RunGQASeqlensKTest( tester.Run(expect, expected_message, {}, nullptr, &execution_providers); } -// CPU GroupQueryAttention does not implement the WebGPU-only fused Q/K RMS-norm prologue +// CPU GroupQueryAttention does not implement the CUDA/WebGPU fused Q/K RMS-norm prologue // inputs (q_norm_weight/k_norm_weight at indices 14/15). Ensure we reject these explicitly. TEST(GroupQueryAttentionTest, CpuRejectsQKNormWeightInputs) { constexpr int batch_size = 1; @@ -145,9 +145,9 @@ TEST(GroupQueryAttentionTest, CpuRejectsQKNormWeightInputs) { {}, nullptr, &execution_providers); } -// CUDA GroupQueryAttention also does not implement the WebGPU-only fused Q/K RMS-norm -// prologue inputs (q_norm_weight/k_norm_weight at indices 14/15). Ensure the guard is covered. -TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) { +// CUDA GroupQueryAttention implements the fused Q/K RMS-norm prologue inputs +// (q_norm_weight/k_norm_weight at indices 14/15). Ensure the input contract is accepted. +TEST(GroupQueryAttentionTest, CudaAcceptsQKNormWeightInputs) { auto cuda_ep = DefaultCudaExecutionProvider(); if (!cuda_ep) { GTEST_SKIP() << "CUDA EP not available"; @@ -195,10 +195,13 @@ TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) { tester.AddOutput("present_value", {batch_size, kv_num_heads, sequence_length, head_size}, std::vector(batch_size * kv_num_heads * sequence_length * head_size, MLFloat16(0.0f))); + // This is an input-contract smoke test. The dedicated QK-Norm functional tests cover numerical equivalence. + tester.SetOutputTolerance(1e6f); + std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectFailure, - "q_norm_weight / k_norm_weight inputs are not supported", + tester.Run(OpTester::ExpectResult::kExpectSuccess, + "", {}, nullptr, &execution_providers); } From 2379cbbab6b54f0dfae1155ac7a3fbdf8228614d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 20 Jun 2026 18:24:14 +0000 Subject: [PATCH 6/9] docs: fix relative links in CUDA GQA documentation --- docs/contrib_ops/cuda/gqa.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/contrib_ops/cuda/gqa.md b/docs/contrib_ops/cuda/gqa.md index 3ca27ae388498..aa289a0de9a3c 100644 --- a/docs/contrib_ops/cuda/gqa.md +++ b/docs/contrib_ops/cuda/gqa.md @@ -5,7 +5,7 @@ the CUDA kernel backends and how one is selected, and the attention-sink (`head_ that is accelerated by the XQA kernel. For CPU-specific implementation details (including the quantized KV-cache flash path), see -[cpu/gqa.md](cpu/gqa.md). +[cpu/gqa.md](../cpu/gqa.md). --- @@ -40,10 +40,10 @@ group of `num_heads / kv_num_heads` query heads. The operator also supports: - Smooth softmax, including a per-head attention sink (`head_sink`) The operator schema is defined in -[onnxruntime/core/graph/contrib_ops/bert_defs.cc](../../onnxruntime/core/graph/contrib_ops/bert_defs.cc). +[onnxruntime/core/graph/contrib_ops/bert_defs.cc](../../../onnxruntime/core/graph/contrib_ops/bert_defs.cc). The CUDA kernel is implemented in -[onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc) -and [group_query_attention_impl.cu](../../onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu). +[onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc](../../../onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc) +and [group_query_attention_impl.cu](../../../onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu). ## 2. Operator Schema @@ -383,10 +383,10 @@ Located in `onnxruntime/test/python/transformers/`: | Script | Purpose | |--------|---------| -| [profile_gqa.py](../../onnxruntime/test/python/transformers/profile_gqa.py) | Profile GQA (incl. quantized KV cache) with NVTX markers; examples for Nsight Compute (`ncu`) and Nsight Systems (`nsys`). | -| [benchmark_gqa.py](../../onnxruntime/test/python/transformers/benchmark_gqa.py) | Triton-based throughput comparison across dense / local / packed-QKV and INT4/INT8/FP8 variants. | -| [benchmark_gqa_windows.py](../../onnxruntime/test/python/transformers/benchmark_gqa_windows.py) | GQA benchmark variant for Windows. | -| [benchmark_gqa_cpu_flash.py](../../onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py) | CPU flash-vs-naive GQA benchmark. | +| [profile_gqa.py](../../../onnxruntime/test/python/transformers/profile_gqa.py) | Profile GQA (incl. quantized KV cache) with NVTX markers; examples for Nsight Compute (`ncu`) and Nsight Systems (`nsys`). | +| [benchmark_gqa.py](../../../onnxruntime/test/python/transformers/benchmark_gqa.py) | Triton-based throughput comparison across dense / local / packed-QKV and INT4/INT8/FP8 variants. | +| [benchmark_gqa_windows.py](../../../onnxruntime/test/python/transformers/benchmark_gqa_windows.py) | GQA benchmark variant for Windows. | +| [benchmark_gqa_cpu_flash.py](../../../onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py) | CPU flash-vs-naive GQA benchmark. | Example kernel-level and timeline profiling: @@ -432,7 +432,7 @@ Other ways to shorten the iteration loop: ## 12. Testing CUDA parity tests live in -[onnxruntime/test/python/transformers/test_gqa.py](../../onnxruntime/test/python/transformers/test_gqa.py): +[onnxruntime/test/python/transformers/test_gqa.py](../../../onnxruntime/test/python/transformers/test_gqa.py): - `TestXQAQuantizedParity` — XQA per-tensor int8 quantized decode parity. - `TestXQAHeadSinkParity` — non-quantized XQA decode parity with a `head_sink` (attention sink) input. From 3ae346b71aa3750148d5fc0c0fec15f11ca2390e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 20 Jun 2026 11:44:40 -0700 Subject: [PATCH 7/9] address feedback --- .../optimizer/group_query_attention_pre_norm_fusion_test.cc | 2 +- onnxruntime/test/python/transformers/test_gqa.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc index d5e2a8deb877d..ba9656d304773 100644 --- a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc +++ b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc @@ -269,7 +269,7 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesCudaAssign auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); for (auto& node : builder.graph_.Nodes()) { - const_cast(node).SetExecutionProviderType(kCudaExecutionProvider); + node.SetExecutionProviderType(kCudaExecutionProvider); } }; ASSERT_STATUS_OK(TestGraphTransformer( diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 2a8de08daaa0d..c9e739faefe72 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -571,7 +571,7 @@ def gqa_prompt_func( if config.has_head_sink and head_sink is not None: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) - if config.has_qk_norm and q_norm_weight is not None: + if config.has_qk_norm and q_norm_weight is not None and k_norm_weight is not None: bind_tensor(io_binding, "q_norm_weight", q_norm_weight, device, ort_type) bind_tensor(io_binding, "k_norm_weight", k_norm_weight, device, ort_type) @@ -742,7 +742,7 @@ def gqa_past_func( if config.has_head_sink and head_sink is not None and not head_sink_as_initializer: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) - if config.has_qk_norm and q_norm_weight is not None: + if config.has_qk_norm and q_norm_weight is not None and k_norm_weight is not None: bind_tensor(io_binding, "q_norm_weight", q_norm_weight, device, ort_type) bind_tensor(io_binding, "k_norm_weight", k_norm_weight, device, ort_type) From eefc52803579efbbcec7e309d05bd16082945004 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 20 Jun 2026 12:04:00 -0700 Subject: [PATCH 8/9] Address review: optimize QK-Norm reduction and clarify test wording - group_query_attention_qkv.cuh: reduce QK-Norm sum once in tid==0 and broadcast inv_rms via shared memory to avoid redundant O(blockDim.x^2) shared reads. - pre_norm_fusion_test.cc: comment now says CUDA+WebGPU fusion path (test runs both). - test_gqa.py: TestGQAQKNorm now gated on has_cuda_device(80) with an accurate skip message instead of the misleading Flash Attention check. --- .../cuda/bert/group_query_attention_qkv.cuh | 14 ++++++++++---- .../group_query_attention_pre_norm_fusion_test.cc | 2 +- onnxruntime/test/python/transformers/test_gqa.py | 2 +- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 3825f37ce1220..0c62aef11d53a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -166,11 +166,17 @@ __global__ void UnpackRoPEAppend( __syncthreads(); // blockDim.x == head_size / elements_per_thread is small (<= 64). A linear reduction is robust // for any (possibly non-power-of-two) thread count and avoids tree-reduction edge cases. - float sumsq = 0.0f; - for (int t = 0; t < blockDim.x; ++t) { - sumsq += s_qk_reduce[t]; + // Reduce once in tid==0 and broadcast inv_rms via shared memory to avoid the redundant + // O(blockDim.x^2) shared reads that result from every thread summing the partials. + if (tid == 0) { + float sumsq = 0.0f; + for (int t = 0; t < blockDim.x; ++t) { + sumsq += s_qk_reduce[t]; + } + s_qk_reduce[0] = rsqrtf(sumsq / static_cast(head_size) + qk_norm_epsilon); } - const float inv_rms = rsqrtf(sumsq / static_cast(head_size) + qk_norm_epsilon); + __syncthreads(); + const float inv_rms = s_qk_reduce[0]; if (valid) { #pragma unroll for (int i = 0; i < elements_per_thread; ++i) { diff --git a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc index ba9656d304773..aa89c79030dd0 100644 --- a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc +++ b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc @@ -288,7 +288,7 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPatter auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); }; // opset 27 is under development in ONNX 1.22 (released map-max 27 > last release 26), so strict legs // reject this *CurrentOpset model at load; allow the unreleased opset. Remove once opset 27 ships. - // Tracked by #28966; this is the WebGPU attention-fusion path that surfaced #28969. + // Tracked by #28966; this runs the CUDA+WebGPU attention-fusion path that surfaced #28969. ASSERT_STATUS_OK(TestGraphTransformer( build, /*opset_version=*/current_opset, *logger_, MakeCudaWebGpuTransformer(), TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph, diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index c9e739faefe72..22259a4070899 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -2214,7 +2214,7 @@ def gqa_qk_norm_test_cases(is_past: bool): yield name, config -@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +@unittest.skipIf(not has_cuda_device(80), "CUDA GQA QK-Norm requires Ampere or higher GPU, skipping tests.") class TestGQAQKNorm(unittest.TestCase): def tearDown(self): if torch.cuda.is_available(): From 81ffb7189ff8cf3d5700e6f0914f607aa42555d7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jun 2026 13:25:57 -0700 Subject: [PATCH 9/9] address feedbacks --- .../jsep/webgpu/ops/group-query-attention.ts | 4 +- .../cuda/bert/group_query_attention.cc | 13 +++- .../group_query_attention_pre_norm_fusion.cc | 3 + .../DmlOperatorGroupQueryAttention.cpp | 7 ++ .../group_query_attention_op_test.cc | 74 ++++++++++++++++--- ...up_query_attention_pre_norm_fusion_test.cc | 21 ++++++ .../test/python/transformers/test_gqa.py | 50 ++++++++++--- 7 files changed, 145 insertions(+), 27 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index ada0c65bdd8a7..4e2d3b82ce155 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -328,14 +328,14 @@ const generatePositionIdsProgramInfo = ( }; export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => { - // q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the WebGPU-only + // q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the CUDA/native WebGPU // GroupQueryAttentionPreNormFusion optimizer pass. JSEP does not implement the fused // per-head Q/K RMS normalization prologue, so reject the node if either input is present // (regardless of rank, including scalars) rather than silently dropping the normalization. if ((context.inputs.length > 14 && context.inputs[14]) || (context.inputs.length > 15 && context.inputs[15])) { throw new Error( 'GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported. ' + - 'The per-head Q/K RMS normalization prologue is implemented only on the native WebGPU EP.', + 'The per-head Q/K RMS normalization prologue is implemented only on the CUDA and native WebGPU EPs.', ); } const params = validateInputs(context.inputs, attributes); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index ee70cf892949b..fb82bbee2a84b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -1,9 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include #include #include -#include #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_type_conversion.h" @@ -106,6 +107,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) softcap_ = info.GetAttrOrDefault("softcap", 0.0f); use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; qk_norm_epsilon_ = info.GetAttrOrDefault("qk_norm_epsilon", 1e-6f); + ORT_ENFORCE(std::isfinite(qk_norm_epsilon_) && qk_norm_epsilon_ > 0.0f, + "GroupQueryAttention (CUDA): qk_norm_epsilon must be finite and positive."); k_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("k_quant_type", "NONE")); v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("v_quant_type", "NONE")); @@ -304,7 +307,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // Validate and enable the per-head Q/K RMSNorm (QK-Norm) prologue (inputs 14/15). Both weights // must be 1D tensors of shape (head_size) with element type T (shared across all heads). - parameters.use_qk_norm = false; if (q_norm_weight != nullptr) { const auto& q_norm_shape = q_norm_weight->Shape(); const auto& k_norm_shape = k_norm_weight->Shape(); @@ -417,6 +419,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // UnpackRoPEAppend preprocess before XQA, so Q/K can be normalized before the XQA kernel consumes // Q and the appended cache. Keep quantized QK-Norm off the XQA route until scale correctness is // validated for normalized K before quantized-cache append. + const bool xqa_qk_norm_ok = !parameters.use_qk_norm || !is_inputs_quantized; + const bool use_xqa_attention_sinks = parameters.use_smooth_softmax && head_sink != nullptr && !is_inputs_quantized; + const bool xqa_smooth_softmax_ok = !parameters.use_smooth_softmax || (head_sink != nullptr && !is_inputs_quantized); if (enable_xqa_ && (device_prop.major >= 8) && !parameters.is_first_prompt && @@ -425,8 +430,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.past_present_share_buffer && parameters.softcap == 0.0f && parameters.local_window_size == -1 && - (!parameters.use_qk_norm || !is_inputs_quantized) && - (!parameters.use_smooth_softmax || (head_sink != nullptr && !is_inputs_quantized))) { + xqa_qk_norm_ok && + xqa_smooth_softmax_ok) { int group_size = parameters.num_heads / parameters.kv_num_heads; bool is_int8_quantized_supported = is_int8 && diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc index 3271c4cc9a22f..f8763306ced46 100644 --- a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc +++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc @@ -167,6 +167,9 @@ bool MatchPreNormReshapeChain(Graph& graph, } const auto* sln_eps_attr = graph_utils::GetNodeAttribute(*sln, "epsilon"); const float sln_eps = (sln_eps_attr == nullptr) ? 1e-5f : sln_eps_attr->f(); + if (!std::isfinite(sln_eps) || sln_eps <= 0.0f) { + return false; + } // Inner reshape (between projection and SLN). if (sln->InputDefs().empty() || sln->InputDefs()[0] == nullptr) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp index bf5a182f9662c..ae5434b2bae04 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp @@ -34,6 +34,13 @@ class DmlOperatorGroupQueryAttention : public DmlOperator, public GroupQueryAtte ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 1); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); + constexpr uint32_t qNormWeightIndex = 14; + constexpr uint32_t kNormWeightIndex = 15; + const bool hasQNormWeight = kernelCreationContext.GetInputCount() > qNormWeightIndex && kernelCreationContext.IsInputValid(qNormWeightIndex); + const bool hasKNormWeight = kernelCreationContext.GetInputCount() > kNormWeightIndex && kernelCreationContext.IsInputValid(kNormWeightIndex); + ML_CHECK_VALID_ARGUMENT(!hasQNormWeight && !hasKNormWeight, + "GroupQueryAttention (DML): q_norm_weight / k_norm_weight inputs are not supported."); + std::vector> inputIndices(inputCount); inputIndices[queryIndex] = queryIndex; inputIndices[keyIndex] = keyIndex; diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index aeb63f7edbcc1..3e1f87af8b36b 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -168,9 +168,14 @@ TEST(GroupQueryAttentionTest, CpuRejectsQKNormWeightInputs) { {}, nullptr, &execution_providers); } -// CUDA GroupQueryAttention implements the fused Q/K RMS-norm prologue inputs -// (q_norm_weight/k_norm_weight at indices 14/15). Ensure the input contract is accepted. -TEST(GroupQueryAttentionTest, CudaAcceptsQKNormWeightInputs) { +static void RunCudaQKNormInputContractTest( + OpTester::ExpectResult expect, + const std::string& expected_message, + std::optional qk_norm_epsilon = std::nullopt, + bool include_q_norm_weight = true, + bool include_k_norm_weight = true, + int q_norm_weight_size = 8, + int k_norm_weight_size = 8) { auto cuda_ep = DefaultCudaExecutionProvider(); if (!cuda_ep) { GTEST_SKIP() << "CUDA EP not available"; @@ -187,6 +192,9 @@ TEST(GroupQueryAttentionTest, CudaAcceptsQKNormWeightInputs) { OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + if (qk_norm_epsilon.has_value()) { + tester.AddAttribute("qk_norm_epsilon", *qk_norm_epsilon); + } tester.AddInput("query", {batch_size, sequence_length, hidden_size}, std::vector(batch_size * sequence_length * hidden_size, MLFloat16(0.1f))); @@ -208,8 +216,16 @@ TEST(GroupQueryAttentionTest, CudaAcceptsQKNormWeightInputs) { tester.AddOptionalInputEdge(); // k_scale tester.AddOptionalInputEdge(); // v_scale - tester.AddInput("q_norm_weight", {head_size}, std::vector(head_size, MLFloat16(1.0f))); - tester.AddInput("k_norm_weight", {head_size}, std::vector(head_size, MLFloat16(1.0f))); + if (include_q_norm_weight) { + tester.AddInput("q_norm_weight", {q_norm_weight_size}, + std::vector(q_norm_weight_size, MLFloat16(1.0f))); + } else if (include_k_norm_weight) { + tester.AddOptionalInputEdge(); + } + if (include_k_norm_weight) { + tester.AddInput("k_norm_weight", {k_norm_weight_size}, + std::vector(k_norm_weight_size, MLFloat16(1.0f))); + } tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, std::vector(batch_size * sequence_length * hidden_size, MLFloat16(0.0f))); @@ -218,14 +234,50 @@ TEST(GroupQueryAttentionTest, CudaAcceptsQKNormWeightInputs) { tester.AddOutput("present_value", {batch_size, kv_num_heads, sequence_length, head_size}, std::vector(batch_size * kv_num_heads * sequence_length * head_size, MLFloat16(0.0f))); - // This is an input-contract smoke test. The dedicated QK-Norm functional tests cover numerical equivalence. - tester.SetOutputTolerance(1e6f); + if (expect == OpTester::ExpectResult::kExpectSuccess) { + // This is an input-contract smoke test. The dedicated QK-Norm functional tests cover numerical equivalence. + tester.SetOutputTolerance(1e6f); + } std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, - "", - {}, nullptr, &execution_providers); + execution_providers.push_back(std::move(cuda_ep)); + tester.Run(expect, expected_message, {}, nullptr, &execution_providers); +} + +// CUDA GroupQueryAttention implements the fused Q/K RMS-norm prologue inputs +// (q_norm_weight/k_norm_weight at indices 14/15). Ensure the input contract is accepted. +TEST(GroupQueryAttentionTest, CudaAcceptsQKNormWeightInputs) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectSuccess, ""); +} + +TEST(GroupQueryAttentionTest, CudaRejectsQKNormMissingKWeight) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure, + "q_norm_weight and k_norm_weight must be provided together", + std::nullopt, + /*include_q_norm_weight=*/true, + /*include_k_norm_weight=*/false); +} + +TEST(GroupQueryAttentionTest, CudaRejectsQKNormWrongKWeightShape) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure, + "k_norm_weight must be a 1D tensor of shape", + std::nullopt, + /*include_q_norm_weight=*/true, + /*include_k_norm_weight=*/true, + /*q_norm_weight_size=*/8, + /*k_norm_weight_size=*/9); +} + +TEST(GroupQueryAttentionTest, CudaRejectsZeroQKNormEpsilon) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure, + "qk_norm_epsilon must be finite and positive", + 0.0f); +} + +TEST(GroupQueryAttentionTest, CudaRejectsNegativeQKNormEpsilon) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure, + "qk_norm_epsilon must be finite and positive", + -1.0e-6f); } // Regression: negative seqlens_k wraps to huge size_t, causing GEMM OOB. diff --git a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc index aa89c79030dd0..2ac880a5459c5 100644 --- a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc +++ b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include #include "core/graph/node_attr_utils.h" @@ -362,6 +363,26 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsEpsilonM TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); } +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNonPositiveEpsilon) { + BuildOptions opts; + opts.q_epsilon = 0.0f; + opts.k_epsilon = 0.0f; + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNonFiniteEpsilon) { + BuildOptions opts; + opts.q_epsilon = std::numeric_limits::quiet_NaN(); + opts.k_epsilon = std::numeric_limits::quiet_NaN(); + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsBadInnerReshape) { BuildOptions opts; opts.break_k_inner_reshape_shape = true; diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index fb60ed897528b..f0f6865bb463d 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -513,8 +513,9 @@ def gqa_prompt_func( q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) if new_k is not None: - new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + kv_hidden_size = config.kv_num_heads * config.head_size + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, kv_hidden_size)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, kv_hidden_size)) sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=[resolve_cuda_plugin_ep(ep)]) @@ -681,8 +682,9 @@ def gqa_past_func( q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) if new_k is not None: - new_k = torch.reshape(new_k, (config.batch_size, config.q_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.q_sequence_length, -1)) + kv_hidden_size = config.kv_num_heads * config.head_size + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, kv_hidden_size)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, kv_hidden_size)) sess_options = SessionOptions() # sess_options.log_severity_level = 0 @@ -690,7 +692,7 @@ def gqa_past_func( io_binding = ort_session.io_binding() # Common inputs - total_seq_len = config.past_kv_sequence_length + config.q_sequence_length + total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length # 1. Bind 'query' bind_tensor(io_binding, "query", q, device, ort_type) @@ -1328,7 +1330,7 @@ def parity_check_gqa_past( new_k = ( torch.randn( config.batch_size, - config.q_sequence_length, + config.kv_sequence_length, config.kv_num_heads, config.head_size, device=device, @@ -1383,7 +1385,7 @@ def parity_check_gqa_past( k_ro = apply_rotary_embedding(k_ro.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) position_ids, attention_bias = None, None - total_seq_len = config.past_kv_sequence_length + config.q_sequence_length + total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length if config.has_position_ids: position_ids = (cache_seqlens.unsqueeze(1) + torch.arange(config.q_sequence_length, device=device)).long() if config.has_attention_bias: @@ -1397,7 +1399,7 @@ def parity_check_gqa_past( arange = rearrange(torch.arange(config.buffer_sequence_length, device=device), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.q_sequence_length + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.kv_sequence_length ) k_to_cache = k_ro @@ -1425,7 +1427,7 @@ def parity_check_gqa_past( k_cache_ref[update_mask] = rearrange(k_to_cache, "b s ... -> (b s) ...").to(k_cache_ref.dtype) v_cache_ref[update_mask] = rearrange(v_to_cache, "b s ... -> (b s) ...").to(v_cache_ref.dtype) - key_padding_mask = arange < cache_seqlens_expanded + config.q_sequence_length + key_padding_mask = arange < cache_seqlens_expanded + config.kv_sequence_length out_ref, _ = attention_ref( q=q_ro, @@ -1461,7 +1463,7 @@ def parity_check_gqa_past( k_ort = k_ort.contiguous() v_ort = v_ort.contiguous() - ort_seqlens = cache_seqlens + config.q_sequence_length - 1 + ort_seqlens = cache_seqlens + config.kv_sequence_length - 1 out, present_k, present_v = gqa_past_func( q=q_ort, @@ -2287,6 +2289,34 @@ def test_gqa_qk_norm_past_xqa(self): atol=atol["fp16"], ) + def test_gqa_qk_norm_past_shared_kv(self): + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=0, + past_kv_sequence_length=127, + buffer_sequence_length=135, + num_heads=8, + kv_num_heads=2, + head_size=64, + rotary=False, + packed=False, + share_buffer=True, + has_qk_norm=True, + ) + + with scoped_env_var("ORT_ENABLE_XQA", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + def test_gqa_qk_norm_past_xqa_bf16(self): if not torch.cuda.is_bf16_supported(): self.skipTest("BFloat16 not supported on this device")