Merged
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -2675,7 +2675,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>v_scale</tt> (optional) : T_KV_SCALE</dt>
<dd>Scale tensor for past_value.</dd>
<dt><tt>q_norm_weight</tt> (optional) : T</dt>
<dd>Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the native WebGPU execution provider only; JSEP WebGPU/JS and other EPs must reject the node when this input is set.</dd>
<dd>Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the CUDA and native WebGPU execution providers; JSEP WebGPU/JS and other EPs must reject the node when this input is set.</dd>
<dt><tt>k_norm_weight</tt> (optional) : T</dt>
<dd>Optional 1D tensor of shape (head_size). See q_norm_weight. Must be provided together with q_norm_weight.</dd>
</dl>
45 changes: 42 additions & 3 deletions docs/contrib_ops/cuda/gqa.md
@@ -58,6 +58,7 @@ Selected attributes:
| `local_window_size` | Left window size for local attention. `-1` means global attention. |
| `do_rotary` / `rotary_interleaved` | Enable RoPE and select interleaved vs. half-rotary layout. |
| `smooth_softmax` | Add a smooth factor to the softmax denominator. |
| `qk_norm_epsilon` | Epsilon for the fused per-head Q/K RMSNorm (QK-Norm) prologue. Defaults to `1e-6`. |
| `k_quant_type` / `v_quant_type` | KV cache quantization mode: `NONE`, `PER_TENSOR`, or `PER_CHANNEL`. |
| `kv_cache_bit_width` | Bit width of the quantized KV cache (`8` or `4`). |

@@ -73,6 +74,7 @@ Selected inputs (see the schema for the full list and shapes):
| 7, 8 | `cos_cache`, `sin_cache` | RoPE caches. |
| 11 | `head_sink` | `(num_heads,)` per-head attention sink (see §5). |
| 12, 13 | `k_scale`, `v_scale` | FP32 dequant scales for the quantized KV cache. |
| 14, 15 | `q_norm_weight`, `k_norm_weight` | `(head_size,)` per-head Q/K RMSNorm weights (QK-Norm, see §3). Both must be present together. |

Outputs are `output`, `present_key`, `present_value`, and optional `output_qk`.

@@ -118,6 +120,31 @@ last dimension is `(head_size + 1) / 2` because two nibbles are packed per byte.
- RoPE, packed-QKV unpacking, and KV-head expansion are handled internally (`PrepareQKV`) before the
selected backend runs, so every backend sees a consistent layout.

### Fused QK-Norm (per-head Q/K RMSNorm)

When the optional `q_norm_weight` (input 14) and `k_norm_weight` (input 15) tensors are provided, the
CUDA kernel applies a fused per-head RMS normalization to Q and K **before** RoPE. This matches the
QK-Norm used by **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. For each head, over the `head_size`
channels:

$$
x_\text{norm}[c] = x[c] \cdot \frac{1}{\sqrt{\frac{1}{H}\sum_{j} x[j]^2 + \epsilon}} \cdot w[c]
$$

where `H = head_size`, `w` is the per-head weight vector (`q_norm_weight` for Q, `k_norm_weight` for
K), and `epsilon = qk_norm_epsilon` (default `1e-6`). The sum of squares is reduced in FP32 for
numerical stability and the result is cast back to the operator type `T`.
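
As a rough illustration of the formula above, a minimal pure-Python sketch follows. It is not the CUDA kernel: the function name `rms_norm_per_head` and the list-of-lists layout are assumptions made for this example, and Python floats stand in for the FP32 reduction.

```python
import math


def rms_norm_per_head(x, weight, epsilon=1e-6):
    """Apply RMSNorm independently to each head.

    x:      list of heads, each a list of head_size floats (hypothetical layout).
    weight: list of head_size floats, reused for every head.
    """
    head_size = len(weight)
    out = []
    for head in x:
        if len(head) != head_size:
            raise ValueError("each head must have head_size channels")
        # Mean of squares over the head_size channels of this head.
        mean_sq = sum(v * v for v in head) / head_size
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        # Scale every channel by the reciprocal RMS and its weight.
        out.append([v * inv_rms * w for v, w in zip(head, weight)])
    return out
```

With unit weights, each normalized head has a root mean square close to 1, which is a quick sanity check for any reference implementation.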

- Both weights are 1D tensors of shape `(head_size,)`, share the operator's element type `T`
(`float16`/`bfloat16`), and are **shared across all heads**. They must be supplied together —
providing only one is rejected.
- The normalization is fused into the `PrepareQKV` prologue (`UnpackRoPEAppend` for the new-KV path,
or a standalone per-head RMSNorm kernel for the shared-buffer Q-only decode case), so it composes
with packed QKV, RoPE, KV-head expansion, and the quantized KV cache.
- Because the Flash-Decoding fast path does its own RoPE/append internally and bypasses `PrepareQKV`,
it is disabled when QK-Norm is present (see §6). The non-quantized XQA decode path can still run
with QK-Norm: CUDA normalizes Q/K in the `UnpackRoPEAppend` preprocess before launching XQA.
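
The input checks that gate QK-Norm can be sketched as below. This is a hedged illustration rather than the kernel's code: `check_qk_norm_inputs` is a hypothetical helper, shapes are plain tuples with `None` standing for an absent optional input, and the element-type check against `T` is omitted.

```python
import math


def check_qk_norm_inputs(q_norm_shape, k_norm_shape, head_size, epsilon=1e-6):
    """Return True when QK-Norm should run, False when both weights are absent.

    Shapes are tuples, or None for an absent optional input (assumed convention).
    """
    if not (math.isfinite(epsilon) and epsilon > 0.0):
        raise ValueError("qk_norm_epsilon must be finite and positive")
    # Either both weights are present or neither is.
    if (q_norm_shape is None) != (k_norm_shape is None):
        raise ValueError("q_norm_weight and k_norm_weight must be provided together")
    if q_norm_shape is None:
        return False
    for name, shape in (("q_norm_weight", q_norm_shape),
                        ("k_norm_weight", k_norm_shape)):
        if tuple(shape) != (head_size,):
            raise ValueError(f"{name} must be a 1D tensor of shape (head_size={head_size})")
    return True
```

Returning a boolean keeps the "no weights supplied" case cheap: callers can skip the normalization prologue entirely when the result is `False`.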

## 4. KV Cache and Quantization

### Layout and shared buffer
@@ -197,6 +224,13 @@ order and the first eligible backend wins:
The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is
enabled (see §10).

> **QK-Norm interaction.** When `q_norm_weight` / `k_norm_weight` are present (see §3), the
> Flash-Decoding fast path is disabled so the QK-Norm prologue always runs. Non-quantized XQA decode
> remains eligible for supported shapes: the `UnpackRoPEAppend` preprocess normalizes Q/K, applies
> RoPE, appends K/V, and then XQA consumes the normalized Q and cache. Quantized-cache QK-Norm decode
> still falls back to Flash Attention (or cuDNN SDPA / MEA / Unfused) until normalized-K scale
> handling is validated for XQA.
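
The two QK-Norm gates described in the note above can be written as small predicates. This is a sketch under stated assumptions: the function names are hypothetical, the flags are plain booleans, and `xqa_qk_norm_ok` is only one of several conditions XQA must satisfy (see §7 for the rest).

```python
def xqa_qk_norm_ok(use_qk_norm, is_inputs_quantized):
    """QK-Norm keeps XQA eligible only when the KV cache is not quantized."""
    return (not use_qk_norm) or (not is_inputs_quantized)


def flash_fast_decode_enabled(use_flash_attention, disable_flash_decode,
                              is_first_prompt, kv_sequence_length,
                              past_present_share_buffer, is_inputs_quantized,
                              use_qk_norm):
    """Flash-Decoding fast path: every condition must hold, and QK-Norm turns it off."""
    return (use_flash_attention
            and not disable_flash_decode
            and not is_first_prompt
            and kv_sequence_length > 0
            and past_present_share_buffer
            and not is_inputs_quantized
            and not use_qk_norm)
```

For example, a non-quantized decode step with QK-Norm passes `xqa_qk_norm_ok` but fails `flash_fast_decode_enabled`, so it either runs on XQA (if the other XQA conditions hold) or on the regular Flash Attention path.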

### 6.1 XQA

Checked first. Used only for single-token global decode under the conditions detailed in §7. When
@@ -402,10 +436,14 @@ CUDA parity tests live in

- `TestXQAQuantizedParity` — XQA per-tensor int8 quantized decode parity.
- `TestXQAHeadSinkParity` — non-quantized XQA decode parity with a `head_sink` (attention sink) input.
- `TestGQAQKNorm` — fused per-head Q/K RMSNorm (QK-Norm) parity for prompt and decode (past), FP16 and
BF16, across packed/unpacked Q/K/V and with/without RoPE.

`TestXQAQuantizedParity` sets `ORT_ENABLE_XQA=1` to force the XQA path. `TestXQAHeadSinkParity`
instead clears `ORT_ENABLE_XQA` to validate that XQA is enabled by default when a `head_sink` input
is present. Both compare against a PyTorch reference (`attention_ref` with `smooth_softmax_ref`).
`TestGQAQKNorm` applies the RMSNorm-before-RoPE reference to Q and K and compares against the CUDA
output.

## 13. Future Work and Known Limitations

@@ -414,9 +452,10 @@ popular LLMs. Listed roughly by impact.

### High impact

1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** The CUDA kernel rejects `q_norm_weight` /
`k_norm_weight`; only the WebGPU EP implements the fused prologue. Required by **Qwen3,
Gemma 2/3, OLMo2, SmolLM3**, etc., which otherwise run the normalization unfused.
1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** *Implemented.* The CUDA kernel applies the
fused per-head RMSNorm to Q and K before RoPE when `q_norm_weight` / `k_norm_weight` are provided
(see §3), matching **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. Remaining limitation: QK-Norm
disables Flash-Decoding, and quantized-cache QK-Norm does not yet get the XQA fast path.
2. **Sliding-window + attention-sink on the fused decode path.** XQA requires global attention
(`local_window_size == -1`), so **GPT-OSS / Mistral / Gemma 2** layers that combine a sliding
window with a `head_sink` fall back to Flash / Flash-Decoding instead of XQA. Unifying sliding
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -328,14 +328,14 @@ const generatePositionIdsProgramInfo = (
};

export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
// q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the WebGPU-only
// q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the CUDA/native WebGPU
// GroupQueryAttentionPreNormFusion optimizer pass. JSEP does not implement the fused
// per-head Q/K RMS normalization prologue, so reject the node if either input is present
// (regardless of rank, including scalars) rather than silently dropping the normalization.
if ((context.inputs.length > 14 && context.inputs[14]) || (context.inputs.length > 15 && context.inputs[15])) {
throw new Error(
'GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported. ' +
'The per-head Q/K RMS normalization prologue is implemented only on the native WebGPU EP.',
'The per-head Q/K RMS normalization prologue is implemented only on the CUDA and native WebGPU EPs.',
);
}
const params = validateInputs(context.inputs, attributes);
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -93,6 +93,8 @@ struct GroupQueryAttentionParameters : AttentionParameters {
bool is_first_prompt; // indicates whether this is first decoding step
bool rotary_interleaved;
bool use_smooth_softmax;
bool use_qk_norm = false; // per-head Q/K RMSNorm (QK-Norm) prologue before RoPE (inputs 14/15)
float qk_norm_epsilon = 1e-6f; // epsilon for the QK-Norm RMSNorm
float softcap;
AttentionQkvFormat past_kv_format;
int zeros_count;
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -84,7 +84,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
"kv_cache_bit_width must be 0 when quantization is disabled, got ", kv_cache_bit_width_);
}

// q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only
// q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the CUDA/WebGPU
// GroupQueryAttentionPreNormFusion optimizer pass. The CPU kernel does not implement
// the fused per-head Q/K RMS normalization prologue, so reject the node if either input
// is present rather than silently dropping the normalization.
@@ -93,7 +93,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are not supported. "
"The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP.");
"The per-head Q/K RMS normalization prologue is implemented only on the CUDA and WebGPU EPs.");
}

GroupQueryAttentionParameters parameters = {};
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_data.h
@@ -157,6 +157,12 @@ struct GroupQueryAttentionData {
const T* sin_cache = nullptr;
const T* head_sink = nullptr;

// Optional per-head Q/K RMSNorm (QK-Norm) weights, shape (head_size,), shared across heads.
// Both are non-null together (validated in the op) and trigger the fused normalization before RoPE.
const T* q_norm_weight = nullptr;
const T* k_norm_weight = nullptr;
float qk_norm_epsilon = 1e-6f;

const float* k_scale = nullptr;
const float* v_scale = nullptr;

70 changes: 55 additions & 15 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -1,9 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>
#include <algorithm>
#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cuda_type_conversion.h"
@@ -105,6 +106,9 @@ GroupQueryAttention<T, U>::GroupQueryAttention(const OpKernelInfo& info)
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
softcap_ = info.GetAttrOrDefault<float>("softcap", 0.0f);
use_smooth_softmax_ = info.GetAttrOrDefault<int64_t>("smooth_softmax", 0) == 1;
qk_norm_epsilon_ = info.GetAttrOrDefault<float>("qk_norm_epsilon", 1e-6f);
ORT_ENFORCE(std::isfinite(qk_norm_epsilon_) && qk_norm_epsilon_ > 0.0f,
"GroupQueryAttention (CUDA): qk_norm_epsilon must be finite and positive.");

k_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault<std::string>("k_quant_type", "NONE"));
v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault<std::string>("v_quant_type", "NONE"));
@@ -222,16 +226,17 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
const Tensor* k_scale = context->Input<Tensor>(12);
const Tensor* v_scale = context->Input<Tensor>(13);

// q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only
// GroupQueryAttentionPreNormFusion optimizer pass. The CUDA kernel does not implement
// the fused per-head Q/K RMS normalization prologue, so reject the node if either input
// is present rather than silently dropping the normalization.
if ((context->InputCount() > 14 && context->Input<Tensor>(14) != nullptr) ||
(context->InputCount() > 15 && context->Input<Tensor>(15) != nullptr)) {
// q_norm_weight (input 14) / k_norm_weight (input 15) carry the per-head Q/K RMSNorm (QK-Norm)
// prologue weights, each of shape (head_size,) and shared across heads. They are populated by the
// GroupQueryAttentionPreNormFusion optimizer pass for Qwen3 / Gemma 2-3 / OLMo2 / SmolLM3 style
// models. Both must be present together; shape validation and wiring happen after CheckInputs,
// where head_size is known.
const Tensor* q_norm_weight = (context->InputCount() > 14) ? context->Input<Tensor>(14) : nullptr;
const Tensor* k_norm_weight = (context->InputCount() > 15) ? context->Input<Tensor>(15) : nullptr;
if ((q_norm_weight != nullptr) != (k_norm_weight != nullptr)) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"GroupQueryAttention (CUDA): q_norm_weight / k_norm_weight inputs are not supported. "
"The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP.");
"GroupQueryAttention (CUDA): q_norm_weight and k_norm_weight must be provided together.");
}

if (k_quant_type_ != KVQuantizationType::NONE) {
@@ -300,6 +305,27 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
head_sink,
parameters));

// Validate and enable the per-head Q/K RMSNorm (QK-Norm) prologue (inputs 14/15). Both weights
// must be 1D tensors of shape (head_size) with element type T (shared across all heads).
if (q_norm_weight != nullptr) {
const auto& q_norm_shape = q_norm_weight->Shape();
const auto& k_norm_shape = k_norm_weight->Shape();
if (q_norm_shape.NumDimensions() != 1 || q_norm_shape[0] != parameters.head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"q_norm_weight must be a 1D tensor of shape (head_size=", parameters.head_size, ")");
}
if (k_norm_shape.NumDimensions() != 1 || k_norm_shape[0] != parameters.head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"k_norm_weight must be a 1D tensor of shape (head_size=", parameters.head_size, ")");
}
if (!q_norm_weight->IsDataType<T>() || !k_norm_weight->IsDataType<T>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"q_norm_weight/k_norm_weight type must match GroupQueryAttention input type T.");
}
parameters.use_qk_norm = true;
}
parameters.qk_norm_epsilon = qk_norm_epsilon_;

parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr;
@@ -389,18 +415,23 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
// 5. No Softcap (XQA doesn't support softcap).
// 6. Standard Softmax, or smooth softmax represented by a head_sink tensor.
// 7. No local window attention (global attention only).
const bool use_xqa_attention_sinks = head_sink != nullptr && !is_inputs_quantized;
const bool is_xqa_smooth_softmax_supported = !parameters.use_smooth_softmax || use_xqa_attention_sinks;
// XQA is enabled when enable_xqa_=true; ineligible shapes/group sizes fall back via data.use_xqa below.
// QK-Norm can use XQA for the non-quantized KV-cache path: ExtremeDecoding runs the same
// UnpackRoPEAppend preprocess before XQA, so Q/K can be normalized before the XQA kernel consumes
// Q and the appended cache. Keep quantized QK-Norm off the XQA route until scale correctness is
// validated for normalized K before quantized-cache append.
const bool xqa_qk_norm_ok = !parameters.use_qk_norm || !is_inputs_quantized;
const bool use_xqa_attention_sinks = parameters.use_smooth_softmax && head_sink != nullptr && !is_inputs_quantized;
const bool xqa_smooth_softmax_ok = !parameters.use_smooth_softmax || (head_sink != nullptr && !is_inputs_quantized);
if (enable_xqa_ &&
(device_prop.major >= 8) &&
!parameters.is_first_prompt &&
parameters.sequence_length == 1 &&
parameters.kv_sequence_length > 0 && // Shared KV (kv_seq=0) has no new K/V to append
parameters.past_present_share_buffer &&
parameters.softcap == 0.0f &&
is_xqa_smooth_softmax_supported &&
parameters.local_window_size == -1) {
parameters.local_window_size == -1 &&
xqa_qk_norm_ok &&
xqa_smooth_softmax_ok) {
int group_size = parameters.num_heads / parameters.kv_num_heads;

bool is_int8_quantized_supported = is_int8 &&
@@ -518,7 +549,10 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
parameters.kv_num_heads);

data.use_flash_attention = use_flash_attention;
data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_sequence_length > 0 && parameters.past_present_share_buffer && !is_inputs_quantized;
// The fast-decode path lets the flash kernel perform RoPE and KV-append internally, bypassing
// PrepareQKV (and therefore the fused QK-Norm prologue). Disable it when q/k norm weights are
// present so the regular FlashAttention path (which normalizes via PrepareQKV) is used instead.
data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_sequence_length > 0 && parameters.past_present_share_buffer && !is_inputs_quantized && !parameters.use_qk_norm;

if (use_flash_attention) {
// Allocate Flash specific buffers (Softmax LSE, Accum)
@@ -701,6 +735,12 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
data.head_sink = reinterpret_cast<const CudaT*>(head_sink->Data<T>());
}

if (parameters.use_qk_norm) {
data.q_norm_weight = reinterpret_cast<const CudaT*>(q_norm_weight->Data<T>());
data.k_norm_weight = reinterpret_cast<const CudaT*>(k_norm_weight->Data<T>());
data.qk_norm_epsilon = qk_norm_epsilon_;
}

#if DUMP_TENSOR_LEVEL > 0
DUMP_TENSOR_INIT();
// Dump Scales
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -33,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel {
bool do_rotary_;
bool rotary_interleaved_;
bool use_smooth_softmax_;
float qk_norm_epsilon_; // epsilon for the per-head Q/K RMSNorm (QK-Norm) prologue
float scale_;
float softcap_;
bool disable_flash_attention_;