diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 38d101786b41a..79296736faf97 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2675,7 +2675,7 @@ This version of the operator has been available since version 1 of the 'com.micr
v_scale (optional) : T_KV_SCALE
Scale tensor for past_value.
q_norm_weight (optional) : T
-Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the native WebGPU execution provider only; JSEP WebGPU/JS and other EPs must reject the node when this input is set.
+Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the CUDA and native WebGPU execution providers; JSEP WebGPU/JS and other EPs must reject the node when this input is set.
k_norm_weight (optional) : T
Optional 1D tensor of shape (head_size). See q_norm_weight. Must be provided together with q_norm_weight.
diff --git a/docs/contrib_ops/cuda/gqa.md b/docs/contrib_ops/cuda/gqa.md
index aeb88dbd07588..aa289a0de9a3c 100644
--- a/docs/contrib_ops/cuda/gqa.md
+++ b/docs/contrib_ops/cuda/gqa.md
@@ -58,6 +58,7 @@ Selected attributes:
| `local_window_size` | Left window size for local attention. `-1` means global attention. |
| `do_rotary` / `rotary_interleaved` | Enable RoPE and select interleaved vs. half-rotary layout. |
| `smooth_softmax` | Add a smooth factor to the softmax denominator. |
+| `qk_norm_epsilon` | Epsilon for the fused per-head Q/K RMSNorm (QK-Norm) prologue. Defaults to `1e-6`. |
| `k_quant_type` / `v_quant_type` | KV cache quantization mode: `NONE`, `PER_TENSOR`, or `PER_CHANNEL`. |
| `kv_cache_bit_width` | Bit width of the quantized KV cache (`8` or `4`). |
@@ -73,6 +74,7 @@ Selected inputs (see the schema for the full list and shapes):
| 7, 8 | `cos_cache`, `sin_cache` | RoPE caches. |
| 11 | `head_sink` | `(num_heads,)` per-head attention sink (see §5). |
| 12, 13 | `k_scale`, `v_scale` | FP32 dequant scales for the quantized KV cache. |
+| 14, 15 | `q_norm_weight`, `k_norm_weight` | `(head_size,)` per-head Q/K RMSNorm weights (QK-Norm, see §3). Both must be present together. |
Outputs are `output`, `present_key`, `present_value`, and optional `output_qk`.
@@ -118,6 +120,31 @@ last dimension is `(head_size + 1) / 2` because two nibbles are packed per byte.
- RoPE, packed-QKV unpacking, and KV-head expansion are handled internally (`PrepareQKV`) before the
selected backend runs, so every backend sees a consistent layout.
+### Fused QK-Norm (per-head Q/K RMSNorm)
+
+When the optional `q_norm_weight` (input 14) and `k_norm_weight` (input 15) tensors are provided, the
+CUDA kernel applies a fused per-head RMS normalization to Q and K **before** RoPE. This matches the
+QK-Norm used by **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. For each head, over the `head_size`
+channels:
+
+$$
+x_\text{norm}[c] = x[c] \cdot \frac{1}{\sqrt{\frac{1}{H}\sum_{j} x[j]^2 + \epsilon}} \cdot w[c]
+$$
+
+where `H = head_size`, `w` is the per-head weight vector (`q_norm_weight` for Q, `k_norm_weight` for
+K), and `epsilon = qk_norm_epsilon` (default `1e-6`). The sum of squares is reduced in FP32 for
+numerical stability and the result is cast back to the operator type `T`.
+
+- Both weights are 1D tensors of shape `(head_size,)`, share the operator's element type `T`
+ (`float16`/`bfloat16`), and are **shared across all heads**. They must be supplied together —
+ providing only one is rejected.
+- The normalization is fused into the `PrepareQKV` prologue (`UnpackRoPEAppend` for the new-KV path,
+ or a standalone per-head RMSNorm kernel for the shared-buffer Q-only decode case), so it composes
+ with packed QKV, RoPE, KV-head expansion, and the quantized KV cache.
+- Because the Flash-Decoding fast path does its own RoPE/append internally and bypasses `PrepareQKV`,
+ it is disabled when QK-Norm is present (see §6). The non-quantized XQA decode path can still run
+ with QK-Norm: CUDA normalizes Q/K in the `UnpackRoPEAppend` preprocess before launching XQA.
+
## 4. KV Cache and Quantization
### Layout and shared buffer
@@ -197,6 +224,13 @@ order and the first eligible backend wins:
The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is
enabled (see §10).
+> **QK-Norm interaction.** When `q_norm_weight` / `k_norm_weight` are present (see §3), the
+> Flash-Decoding fast path is disabled so the QK-Norm prologue always runs. Non-quantized XQA decode
+> remains eligible for supported shapes: the `UnpackRoPEAppend` preprocess normalizes Q/K, applies
+> RoPE, appends K/V, and then XQA consumes the normalized Q and cache. Quantized-cache QK-Norm decode
+> still falls back to Flash Attention (or cuDNN SDPA / MEA / Unfused) until normalized-K scale
+> handling is validated for XQA.
+
### 6.1 XQA
Checked first. Used only for single-token global decode under the conditions detailed in §7. When
@@ -402,10 +436,14 @@ CUDA parity tests live in
- `TestXQAQuantizedParity` — XQA per-tensor int8 quantized decode parity.
- `TestXQAHeadSinkParity` — non-quantized XQA decode parity with a `head_sink` (attention sink) input.
+- `TestGQAQKNorm` — fused per-head Q/K RMSNorm (QK-Norm) parity for prompt and decode (past), FP16 and
+ BF16, across packed/unpacked Q/K/V and with/without RoPE.
`TestXQAQuantizedParity` sets `ORT_ENABLE_XQA=1` to force the XQA path. `TestXQAHeadSinkParity`
instead clears `ORT_ENABLE_XQA` to validate that XQA is enabled by default when a `head_sink` input
is present. Both compare against a PyTorch reference (`attention_ref` with `smooth_softmax_ref`).
+`TestGQAQKNorm` applies the RMSNorm-before-RoPE reference to Q and K and compares against the CUDA
+output.
## 13. Future Work and Known Limitations
@@ -414,9 +452,10 @@ popular LLMs. Listed roughly by impact.
### High impact
-1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** The CUDA kernel rejects `q_norm_weight` /
- `k_norm_weight`; only the WebGPU EP implements the fused prologue. Required by **Qwen3,
- Gemma 2/3, OLMo2, SmolLM3**, etc., which otherwise run the normalization unfused.
+1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** *Implemented.* The CUDA kernel applies the
+ fused per-head RMSNorm to Q and K before RoPE when `q_norm_weight` / `k_norm_weight` are provided
+ (see §3), matching **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. Remaining limitation: QK-Norm
+ disables Flash-Decoding, and quantized-cache QK-Norm does not yet get the XQA fast path.
2. **Sliding-window + attention-sink on the fused decode path.** XQA requires global attention
(`local_window_size == -1`), so **GPT-OSS / Mistral / Gemma 2** layers that combine a sliding
window with a `head_sink` fall back to Flash / Flash-Decoding instead of XQA. Unifying sliding
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index ada0c65bdd8a7..4e2d3b82ce155 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -328,14 +328,14 @@ const generatePositionIdsProgramInfo = (
};
export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
- // q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the WebGPU-only
+ // q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the CUDA/native WebGPU
// GroupQueryAttentionPreNormFusion optimizer pass. JSEP does not implement the fused
// per-head Q/K RMS normalization prologue, so reject the node if either input is present
// (regardless of rank, including scalars) rather than silently dropping the normalization.
if ((context.inputs.length > 14 && context.inputs[14]) || (context.inputs.length > 15 && context.inputs[15])) {
throw new Error(
'GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported. ' +
- 'The per-head Q/K RMS normalization prologue is implemented only on the native WebGPU EP.',
+ 'The per-head Q/K RMS normalization prologue is implemented only on the CUDA and native WebGPU EPs.',
);
}
const params = validateInputs(context.inputs, attributes);
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
index 5b7624d11c6fd..11be4be1638df 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -93,6 +93,8 @@ struct GroupQueryAttentionParameters : AttentionParameters {
bool is_first_prompt; // indicates whether this is first decoding step
bool rotary_interleaved;
bool use_smooth_softmax;
+ bool use_qk_norm = false; // per-head Q/K RMSNorm (QK-Norm) prologue before RoPE (inputs 14/15)
+ float qk_norm_epsilon = 1e-6f; // epsilon for the QK-Norm RMSNorm
float softcap;
AttentionQkvFormat past_kv_format;
int zeros_count;
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 61ae474703213..e36bdb2de263a 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -84,7 +84,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
"kv_cache_bit_width must be 0 when quantization is disabled, got ", kv_cache_bit_width_);
}
- // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only
+ // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the CUDA/WebGPU
// GroupQueryAttentionPreNormFusion optimizer pass. The CPU kernel does not implement
// the fused per-head Q/K RMS normalization prologue, so reject the node if either input
// is present rather than silently dropping the normalization.
@@ -93,7 +93,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are not supported. "
- "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP.");
+ "The per-head Q/K RMS normalization prologue is implemented only on the CUDA and WebGPU EPs.");
}
GroupQueryAttentionParameters parameters = {};
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
index 74aeaf1285e8a..f6a7b788819f2 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
@@ -157,6 +157,12 @@ struct GroupQueryAttentionData {
const T* sin_cache = nullptr;
const T* head_sink = nullptr;
+ // Optional per-head Q/K RMSNorm (QK-Norm) weights, shape (head_size,), shared across heads.
+ // Both are non-null together (validated in the op) and trigger the fused normalization before RoPE.
+ const T* q_norm_weight = nullptr;
+ const T* k_norm_weight = nullptr;
+ float qk_norm_epsilon = 1e-6f;
+
const float* k_scale = nullptr;
const float* v_scale = nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 60b806019cc57..fb82bbee2a84b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -1,9 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#include
+#include
#include
#include
-#include
#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cuda_type_conversion.h"
@@ -105,6 +106,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
scale_ = info.GetAttrOrDefault("scale", 0.0f);
softcap_ = info.GetAttrOrDefault("softcap", 0.0f);
use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1;
+ qk_norm_epsilon_ = info.GetAttrOrDefault("qk_norm_epsilon", 1e-6f);
+ ORT_ENFORCE(std::isfinite(qk_norm_epsilon_) && qk_norm_epsilon_ > 0.0f,
+ "GroupQueryAttention (CUDA): qk_norm_epsilon must be finite and positive.");
k_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("k_quant_type", "NONE"));
v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("v_quant_type", "NONE"));
@@ -222,16 +226,17 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons
const Tensor* k_scale = context->Input(12);
const Tensor* v_scale = context->Input(13);
- // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only
- // GroupQueryAttentionPreNormFusion optimizer pass. The CUDA kernel does not implement
- // the fused per-head Q/K RMS normalization prologue, so reject the node if either input
- // is present rather than silently dropping the normalization.
- if ((context->InputCount() > 14 && context->Input(14) != nullptr) ||
- (context->InputCount() > 15 && context->Input(15) != nullptr)) {
+ // q_norm_weight (input 14) / k_norm_weight (input 15) carry the per-head Q/K RMSNorm (QK-Norm)
+ // prologue weights, each of shape (head_size,) and shared across heads. They are populated by the
+ // GroupQueryAttentionPreNormFusion optimizer pass for Qwen3 / Gemma 2-3 / OLMo2 / SmolLM3 style
+ // models. Both must be present together; shape validation and wiring happen after CheckInputs,
+ // where head_size is known.
+ const Tensor* q_norm_weight = (context->InputCount() > 14) ? context->Input(14) : nullptr;
+ const Tensor* k_norm_weight = (context->InputCount() > 15) ? context->Input(15) : nullptr;
+ if ((q_norm_weight != nullptr) != (k_norm_weight != nullptr)) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
- "GroupQueryAttention (CUDA): q_norm_weight / k_norm_weight inputs are not supported. "
- "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP.");
+ "GroupQueryAttention (CUDA): q_norm_weight and k_norm_weight must be provided together.");
}
if (k_quant_type_ != KVQuantizationType::NONE) {
@@ -300,6 +305,27 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons
head_sink,
parameters));
+ // Validate and enable the per-head Q/K RMSNorm (QK-Norm) prologue (inputs 14/15). Both weights
+ // must be 1D tensors of shape (head_size) with element type T (shared across all heads).
+ if (q_norm_weight != nullptr) {
+ const auto& q_norm_shape = q_norm_weight->Shape();
+ const auto& k_norm_shape = k_norm_weight->Shape();
+ if (q_norm_shape.NumDimensions() != 1 || q_norm_shape[0] != parameters.head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "q_norm_weight must be a 1D tensor of shape (head_size=", parameters.head_size, ")");
+ }
+ if (k_norm_shape.NumDimensions() != 1 || k_norm_shape[0] != parameters.head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "k_norm_weight must be a 1D tensor of shape (head_size=", parameters.head_size, ")");
+ }
+ if (!q_norm_weight->IsDataType() || !k_norm_weight->IsDataType()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "q_norm_weight/k_norm_weight type must match GroupQueryAttention input type T.");
+ }
+ parameters.use_qk_norm = true;
+ }
+ parameters.qk_norm_epsilon = qk_norm_epsilon_;
+
parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr;
@@ -389,9 +415,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons
// 5. No Softcap (XQA doesn't support softcap).
// 6. Standard Softmax, or smooth softmax represented by a head_sink tensor.
// 7. No local window attention (global attention only).
- const bool use_xqa_attention_sinks = head_sink != nullptr && !is_inputs_quantized;
- const bool is_xqa_smooth_softmax_supported = !parameters.use_smooth_softmax || use_xqa_attention_sinks;
- // XQA is enabled when enable_xqa_=true; ineligible shapes/group sizes fall back via data.use_xqa below.
+ // QK-Norm can use XQA for the non-quantized KV-cache path: ExtremeDecoding runs the same
+ // UnpackRoPEAppend preprocess before XQA, so Q/K can be normalized before the XQA kernel consumes
+ // Q and the appended cache. Keep quantized QK-Norm off the XQA route until scale correctness is
+ // validated for normalized K before quantized-cache append.
+ const bool xqa_qk_norm_ok = !parameters.use_qk_norm || !is_inputs_quantized;
+ const bool use_xqa_attention_sinks = parameters.use_smooth_softmax && head_sink != nullptr && !is_inputs_quantized;
+ const bool xqa_smooth_softmax_ok = !parameters.use_smooth_softmax || (head_sink != nullptr && !is_inputs_quantized);
if (enable_xqa_ &&
(device_prop.major >= 8) &&
!parameters.is_first_prompt &&
@@ -399,8 +429,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons
parameters.kv_sequence_length > 0 && // Shared KV (kv_seq=0) has no new K/V to append
parameters.past_present_share_buffer &&
parameters.softcap == 0.0f &&
- is_xqa_smooth_softmax_supported &&
- parameters.local_window_size == -1) {
+ parameters.local_window_size == -1 &&
+ xqa_qk_norm_ok &&
+ xqa_smooth_softmax_ok) {
int group_size = parameters.num_heads / parameters.kv_num_heads;
bool is_int8_quantized_supported = is_int8 &&
@@ -518,7 +549,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons
parameters.kv_num_heads);
data.use_flash_attention = use_flash_attention;
- data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_sequence_length > 0 && parameters.past_present_share_buffer && !is_inputs_quantized;
+ // The fast-decode path lets the flash kernel perform RoPE and KV-append internally, bypassing
+ // PrepareQKV (and therefore the fused QK-Norm prologue). Disable it when q/k norm weights are
+ // present so the regular FlashAttention path (which normalizes via PrepareQKV) is used instead.
+ data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_sequence_length > 0 && parameters.past_present_share_buffer && !is_inputs_quantized && !parameters.use_qk_norm;
if (use_flash_attention) {
// Allocate Flash specific buffers (Softmax LSE, Accum)
@@ -701,6 +735,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons
data.head_sink = reinterpret_cast(head_sink->Data());
}
+ if (parameters.use_qk_norm) {
+ data.q_norm_weight = reinterpret_cast(q_norm_weight->Data());
+ data.k_norm_weight = reinterpret_cast(k_norm_weight->Data());
+ data.qk_norm_epsilon = qk_norm_epsilon_;
+ }
+
#if DUMP_TENSOR_LEVEL > 0
DUMP_TENSOR_INIT();
// Dump Scales
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
index 2fe6268e53c68..d219bd83474d2 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -33,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel {
bool do_rotary_;
bool rotary_interleaved_;
bool use_smooth_softmax_;
+ float qk_norm_epsilon_; // epsilon for the per-head Q/K RMSNorm (QK-Norm) prologue
float scale_;
float softcap_;
bool disable_flash_attention_;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 6a55f18bd939a..fda4ae66417fc 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -88,6 +88,63 @@ Status LaunchConvertHeadSinkToFloat(
return CUDA_CALL(cudaGetLastError());
}
+// Standalone per-head RMS normalization (QK-Norm prologue) for Q in BSNH layout.
+// Each block handles one (b, s, head) vector and normalizes over head_size:
+// out[c] = in[c] * rsqrt(mean(in^2) + epsilon) * weight[c]
+// where weight has shape (head_size,) and is shared across heads. This is used by the shared-KV
+// (kv_sequence_length == 0) path, which processes Q only; the new-KV path folds the equivalent
+// normalization into UnpackRoPEAppend before RoPE.
+template
+__global__ void PerHeadRMSNormBSNHKernel(
+ T* output, const T* input, const T* weight, const int head_size, const float epsilon) {
+ const int s = blockIdx.x;
+ const int b = blockIdx.y;
+ const int n = blockIdx.z;
+ const int sequence_length = gridDim.x;
+ const int num_heads = gridDim.z;
+ const int i = threadIdx.x;
+
+ const int64_t base = (((static_cast(b) * sequence_length + s) * num_heads) + n) * head_size;
+
+ extern __shared__ float s_sumsq[];
+ const float x = (i < head_size) ? static_cast(input[base + i]) : 0.0f;
+ s_sumsq[i] = x * x;
+ __syncthreads();
+
+ // Tree reduction over a power-of-two block size; padded entries (i >= head_size) are zero.
+ for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
+ if (i < stride) {
+ s_sumsq[i] += s_sumsq[i + stride];
+ }
+ __syncthreads();
+ }
+
+ if (i < head_size) {
+ const float inv_rms = rsqrtf(s_sumsq[0] / static_cast(head_size) + epsilon);
+ output[base + i] = static_cast(x * inv_rms * static_cast(weight[i]));
+ }
+}
+
+template
+Status LaunchPerHeadRMSNorm(
+ cudaStream_t stream, T* output, const T* input, const T* weight, const float epsilon,
+ const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+ const int max_threads_per_block) {
+ // Round the thread count up to a power of two so the tree reduction is exact.
+ int tpb = 1;
+ while (tpb < head_size) {
+ tpb <<= 1;
+ }
+ if (tpb > max_threads_per_block) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "head_size (", head_size, ") exceeds max threads for QK-Norm.");
+ }
+ const dim3 grid(sequence_length, batch_size, num_heads);
+ const dim3 block(tpb);
+ const size_t smem = static_cast(tpb) * sizeof(float);
+ PerHeadRMSNormBSNHKernel<<>>(output, input, weight, head_size, epsilon);
+ return CUDA_CALL(cudaGetLastError());
+}
+
// Internal helper to get Q, K, V pointers, handling packed input
//
// This function orchestrates the preparation of Q, K, and V tensors for attention kernels.
@@ -116,7 +173,7 @@ Status PrepareQKV(
T* q_out = reinterpret_cast(data.qkv_buffer);
- if (!parameters.is_packed_qkv && !parameters.do_rotary) {
+ if (!parameters.is_packed_qkv && !parameters.do_rotary && !parameters.use_qk_norm) {
q_out = nullptr;
}
@@ -153,6 +210,16 @@ Status PrepareQKV(
// above has already populated the present buffer with the shared KV data.
// In both cases, only Q processing (RoPE if configured) is needed here.
if (kv_sequence_length == 0) {
+ // QK-Norm: normalize Q (BSNH) into q_out before RoPE. K is already normalized in the shared cache
+ // (it was normalized when first appended), so only Q needs processing on this path.
+ const T* q_rope_input = data.query;
+ if (parameters.use_qk_norm) {
+ ORT_RETURN_IF_ERROR((LaunchPerHeadRMSNorm(
+ stream, q_out, data.query, data.q_norm_weight, parameters.qk_norm_epsilon,
+ batch_size, sequence_length, num_heads, head_size, max_threads_per_block)));
+ // RoPE (if any) runs in-place on the normalized Q; the rotary kernel handles in-place safely.
+ q_rope_input = q_out;
+ }
if (parameters.do_rotary && data.cos_cache != nullptr && data.sin_cache != nullptr) {
// Apply RoPE to Q only using the standalone rotary embedding kernel.
// Q is in BSNH format; the kernel writes rotated Q to q_out.
@@ -161,7 +228,7 @@ Status PrepareQKV(
const int pos_format = data.position_ids != nullptr ? 1 : 2;
if constexpr (std::is_same::value) {
ORT_RETURN_IF_ERROR((LaunchRotaryEmbeddingKernel(
- stream, reinterpret_cast(q_out), reinterpret_cast(data.query),
+ stream, reinterpret_cast(q_out), reinterpret_cast(q_rope_input),
data.position_ids, data.past_seq_lens,
reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache),
batch_size, sequence_length, num_heads, head_size, parameters.rotary_dim, max_cache_length,
@@ -169,7 +236,7 @@ Status PrepareQKV(
max_threads_per_block, false /* is_input_bnsh_format: Q is BSNH */)));
} else if constexpr (std::is_same::value) {
ORT_RETURN_IF_ERROR((LaunchRotaryEmbeddingKernel(
- stream, reinterpret_cast(q_out), reinterpret_cast(data.query),
+ stream, reinterpret_cast(q_out), reinterpret_cast(q_rope_input),
data.position_ids, data.past_seq_lens,
reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache),
batch_size, sequence_length, num_heads, head_size, parameters.rotary_dim, max_cache_length,
@@ -177,7 +244,7 @@ Status PrepareQKV(
max_threads_per_block, false /* is_input_bnsh_format: Q is BSNH */)));
}
}
- // If do_rotary is false, Q is used directly from data.query (q_out == nullptr).
+ // If do_rotary is false and QK-Norm is off, Q is used directly from data.query (q_out == nullptr).
// K/V present buffers already point to the shared past — no work needed.
} else {
ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend(
@@ -191,6 +258,7 @@ Status PrepareQKV(
reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache),
parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
is_cache_bnsh, parameters.k_quant_type,
+ data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon,
stream, max_threads_per_block)));
}
@@ -668,6 +736,7 @@ Status ExtremeDecoding(
parameters.rotary_interleaved,
!past_bsnh, // is_cache_bnsh
parameters.k_quant_type,
+ data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon,
stream,
device_prop.maxThreadsPerBlock)));
@@ -894,6 +963,7 @@ Status DequantizeFlashAttentionFallback(
parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH),
parameters.k_quant_type,
+ data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon,
stream, device_prop.maxThreadsPerBlock)));
// Step 2: Dequantize Entire Cache
@@ -978,6 +1048,7 @@ Status FlashAttentionAndQuantizeKV(
parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
false, // BSNH for scratch
KVQuantizationType::NONE,
+ data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon,
stream, max_threads_per_block)));
// 2. Run Float Flash Attention
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
index 348dc0832d3ba..b1f979fef4169 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -76,9 +76,9 @@ struct GQABufferRequirements {
const size_t v_elements = k_elements;
if (use_xqa) {
- if (params.do_rotary || params.is_packed_qkv) {
- // XQA need scratch for rotated/unpacked Q.
- // RoPE K is written directly to cache by the fused kernel.
+ if (params.do_rotary || params.is_packed_qkv || params.use_qk_norm) {
+ // XQA needs scratch for rotated/unpacked/normalized Q.
+ // RoPE/QK-Norm K is written directly to cache by the fused preprocess kernel.
req.qkv_buffer_bytes = elem_size * q_elements;
}
return req;
@@ -127,7 +127,9 @@ struct GQABufferRequirements {
}
// Unfused fallback: needs Q buffer for rotary embedding output.
- if (req.qkv_buffer_bytes == 0 && (params.do_rotary || params.is_packed_qkv)) {
+ // QK-Norm also requires a materialized Q buffer to hold the normalized (and optionally rotated) Q,
+ // even when rotary is disabled and the input is not packed.
+ if (req.qkv_buffer_bytes == 0 && (params.do_rotary || params.is_packed_qkv || params.use_qk_norm)) {
req.qkv_buffer_bytes = elem_size * q_elements;
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh
index 5fb36e094482b..0c62aef11d53a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh
@@ -61,7 +61,12 @@ __global__ void UnpackRoPEAppend(
const int64_t* position_ids,
const bool interleaved,
const bool is_cache_bnsh,
- const bool per_channel) {
+ const bool per_channel,
+ // QK-Norm (per-head Q/K RMSNorm) weights of shape (head_size,), shared across heads.
+ // nullptr disables normalization for the corresponding head type; V heads are never normalized.
+ const T* q_norm_weight,
+ const T* k_norm_weight,
+ const float qk_norm_epsilon) {
using LoadT = float4;
constexpr int elements_per_thread = sizeof(LoadT) / sizeof(T);
@@ -84,6 +89,9 @@ __global__ void UnpackRoPEAppend(
const int sequence_length = gridDim.x; // Number of new tokens in this launch
__shared__ T shared_head[MAX_HEAD_SIZE];
+ // Per-block reduction buffer for the QK-Norm sum-of-squares. One block handles one (b, s, head),
+ // and blockDim.x == head_size / elements_per_thread (<= MAX_HEAD_SIZE / elements_per_thread).
+ __shared__ float s_qk_reduce[MAX_HEAD_SIZE / elements_per_thread];
// Determine Head Type and Offset within the packed hidden dimension [Q, K, V]
enum HeadType { QUERY,
@@ -139,6 +147,44 @@ __global__ void UnpackRoPEAppend(
// Non-interleaved RoPE requires full head visibility to pair channels (h, h + d/2).
// We use shared memory as a staging buffer to allow any thread to access its pair.
const bool is_qk = (head_type == QUERY || head_type == KEY);
+
+ // 1.5 QK-Norm: per-head RMSNorm applied BEFORE RoPE (Qwen3 / Gemma 2-3 / OLMo2 / SmolLM3).
+ // Each block processes a single head, so head_type (and thus norm_weight) is uniform across the
+ // block, which makes the __syncthreads below safe. Q heads use q_norm_weight, K heads use
+ // k_norm_weight, and V heads are skipped. The weight is shared across heads (indexed by channel).
+ const T* norm_weight = (head_type == QUERY) ? q_norm_weight : ((head_type == KEY) ? k_norm_weight : nullptr);
+ if (is_qk && norm_weight != nullptr) {
+ float partial = 0.0f;
+ if (valid) {
+#pragma unroll
+ for (int i = 0; i < elements_per_thread; ++i) {
+ const float f = static_cast(vals[i]);
+ partial += f * f;
+ }
+ }
+ s_qk_reduce[tid] = partial;
+ __syncthreads();
+ // blockDim.x == head_size / elements_per_thread is small (<= 64). A linear reduction is robust
+ // for any (possibly non-power-of-two) thread count and avoids tree-reduction edge cases.
+ // Reduce once in tid==0 and broadcast inv_rms via shared memory to avoid the redundant
+ // O(blockDim.x^2) shared reads that result from every thread summing the partials.
+ if (tid == 0) {
+ float sumsq = 0.0f;
+ for (int t = 0; t < blockDim.x; ++t) {
+ sumsq += s_qk_reduce[t];
+ }
+ s_qk_reduce[0] = rsqrtf(sumsq / static_cast(head_size) + qk_norm_epsilon);
+ }
+ __syncthreads();
+ const float inv_rms = s_qk_reduce[0];
+ if (valid) {
+#pragma unroll
+ for (int i = 0; i < elements_per_thread; ++i) {
+ vals[i] = static_cast(static_cast(vals[i]) * inv_rms * static_cast(norm_weight[h + i]));
+ }
+ }
+ }
+
if (valid && rotary_dim > 0 && is_qk && !interleaved) {
T* shared_ptr = &shared_head[h];
*reinterpret_cast(shared_ptr) = *reinterpret_cast(vals);
@@ -276,27 +322,32 @@ Status DispatchUnpackRoPEAppendHeadSize(
const int num_heads, const int kv_num_heads, const int head_size, const int d,
const int max_seqlen, const int* past_seq_lens,
const T* cos_cache, const T* sin_cache, const int rotary_dim,
- const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel) {
+ const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel,
+ const T* q_norm_weight, const T* k_norm_weight, const float qk_norm_epsilon) {
if (head_size <= 64) {
UnpackRoPEAppend<<>>(
packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale,
num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
- cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
+ cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel,
+ q_norm_weight, k_norm_weight, qk_norm_epsilon);
} else if (head_size <= 128) {
UnpackRoPEAppend<<>>(
packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale,
num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
- cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
+ cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel,
+ q_norm_weight, k_norm_weight, qk_norm_epsilon);
} else if (head_size <= 256) {
UnpackRoPEAppend<<>>(
packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale,
num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
- cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
+ cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel,
+ q_norm_weight, k_norm_weight, qk_norm_epsilon);
} else if (head_size <= 512) {
UnpackRoPEAppend<<>>(
packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale,
num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
- cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
+ cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel,
+ q_norm_weight, k_norm_weight, qk_norm_epsilon);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size (", head_size, ") exceeds maximum supported MAX_HEAD_SIZE (512).");
}
@@ -318,6 +369,7 @@ Status LaunchUnpackRoPEAppend(
const int* past_seq_lens, const T* cos_cache, const T* sin_cache,
const int rotary_dim, const int64_t* position_ids, const bool interleaved,
const bool is_cache_bnsh, const KVQuantizationType k_quant_type,
+ const T* q_norm_weight, const T* k_norm_weight, const float qk_norm_epsilon,
cudaStream_t stream, const int max_threads_per_block) {
static_assert(std::is_same::type>::value);
static_assert(std::is_same::type>::value);
@@ -365,7 +417,8 @@ Status LaunchUnpackRoPEAppend(
return DispatchUnpackRoPEAppendHeadSize(
grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache,
k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
- cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
+ cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel,
+ q_norm_weight, k_norm_weight, qk_norm_epsilon);
} else if constexpr (std::is_same::value
#ifdef USE_FP8_KV_CACHE
|| std::is_same::value
@@ -375,14 +428,16 @@ Status LaunchUnpackRoPEAppend(
return DispatchUnpackRoPEAppendHeadSize(
grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache,
k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
- cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
+ cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel,
+ q_norm_weight, k_norm_weight, qk_norm_epsilon);
#ifdef USE_INT4_KV_CACHE
} else if constexpr (std::is_same::value) {
// INT4 quantization (packed 2 elements per byte)
return DispatchUnpackRoPEAppendHeadSize(
grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache,
k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens,
- cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel);
+ cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel,
+ q_norm_weight, k_norm_weight, qk_norm_epsilon);
#endif
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported cache type U for GQA quantization.");
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 5d537fa59bfab..896774fb5c8d8 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1324,7 +1324,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a "
"per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap "
"their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion "
- "folds that pattern into this input. Currently honored by the native WebGPU execution provider only; "
+ "folds that pattern into this input. Currently honored by the CUDA and native WebGPU execution providers; "
"JSEP WebGPU/JS and other EPs must reject the node when this input is set.",
"T",
OpSchema::Optional)
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index a998eeacda734..56fc3ca7a9855 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -453,7 +453,8 @@ InlinedVector> GenerateTransformers(
transformers.emplace_back(std::make_unique(cpu_ep));
transformers.emplace_back(std::make_unique(
- InlinedHashSet{onnxruntime::kWebGpuExecutionProvider}));
+ InlinedHashSet{onnxruntime::kCudaExecutionProvider,
+ onnxruntime::kWebGpuExecutionProvider}));
bool has_matmul_nbits_mlp_kernel = false;
bool has_matmul_nbits_qkv_kernel = false;
if (execution_providers != nullptr) {
diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc
index 3271c4cc9a22f..f8763306ced46 100644
--- a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc
@@ -167,6 +167,9 @@ bool MatchPreNormReshapeChain(Graph& graph,
}
const auto* sln_eps_attr = graph_utils::GetNodeAttribute(*sln, "epsilon");
const float sln_eps = (sln_eps_attr == nullptr) ? 1e-5f : sln_eps_attr->f();
+ if (!std::isfinite(sln_eps) || sln_eps <= 0.0f) {
+ return false;
+ }
// Inner reshape (between projection and SLN).
if (sln->InputDefs().empty() || sln->InputDefs()[0] == nullptr) {
diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h
index b69199bb5324d..f3cb131a0b8fb 100644
--- a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h
+++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h
@@ -27,12 +27,12 @@ on inputs 0 (query) and 1 (key) of an unfused GroupQueryAttention node:
When matched, the six Reshape/SLN nodes are removed and the pre-norm Q and K
projections feed GQA directly. The kernel is responsible for applying the RMS
-norm internally (currently the WebGPU EP).
+norm internally (currently the CUDA and WebGPU EPs).
Only fires for execution providers passed in `compatible_execution_providers`.
-At present this fusion is registered for the WebGPU EP only, because the
-in-kernel norm path is currently implemented there. The CPU, CUDA, and JSEP
-GroupQueryAttention kernels reject q_norm_weight / k_norm_weight inputs.
+At present this fusion is registered for the CUDA and WebGPU EPs, because those
+kernels implement the in-kernel norm path. CPU and JSEP GroupQueryAttention
+kernels reject q_norm_weight / k_norm_weight inputs.
*/
class GroupQueryAttentionPreNormFusion : public GraphTransformer {
public:
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp
index bf5a182f9662c..ae5434b2bae04 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp
@@ -34,6 +34,13 @@ class DmlOperatorGroupQueryAttention : public DmlOperator, public GroupQueryAtte
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 1);
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1);
+ constexpr uint32_t qNormWeightIndex = 14;
+ constexpr uint32_t kNormWeightIndex = 15;
+ const bool hasQNormWeight = kernelCreationContext.GetInputCount() > qNormWeightIndex && kernelCreationContext.IsInputValid(qNormWeightIndex);
+ const bool hasKNormWeight = kernelCreationContext.GetInputCount() > kNormWeightIndex && kernelCreationContext.IsInputValid(kNormWeightIndex);
+ ML_CHECK_VALID_ARGUMENT(!hasQNormWeight && !hasKNormWeight,
+ "GroupQueryAttention (DML): q_norm_weight / k_norm_weight inputs are not supported.");
+
std::vector> inputIndices(inputCount);
inputIndices[queryIndex] = queryIndex;
inputIndices[keyIndex] = keyIndex;
diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
index 405b6133fa474..94408f4ed24f0 100644
--- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
@@ -116,7 +116,7 @@ static void RunGQASeqlensKTest(
tester.Run(expect, expected_message, {}, nullptr, &execution_providers);
}
-// CPU GroupQueryAttention does not implement the WebGPU-only fused Q/K RMS-norm prologue
+// CPU GroupQueryAttention does not implement the CUDA/WebGPU fused Q/K RMS-norm prologue
// inputs (q_norm_weight/k_norm_weight at indices 14/15). Ensure we reject these explicitly.
TEST(GroupQueryAttentionTest, CpuRejectsQKNormWeightInputs) {
constexpr int batch_size = 1;
@@ -168,9 +168,14 @@ TEST(GroupQueryAttentionTest, CpuRejectsQKNormWeightInputs) {
{}, nullptr, &execution_providers);
}
-// CUDA GroupQueryAttention also does not implement the WebGPU-only fused Q/K RMS-norm
-// prologue inputs (q_norm_weight/k_norm_weight at indices 14/15). Ensure the guard is covered.
-TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) {
+static void RunCudaQKNormInputContractTest(
+ OpTester::ExpectResult expect,
+ const std::string& expected_message,
+ std::optional qk_norm_epsilon = std::nullopt,
+ bool include_q_norm_weight = true,
+ bool include_k_norm_weight = true,
+ int q_norm_weight_size = 8,
+ int k_norm_weight_size = 8) {
auto cuda_ep = DefaultCudaExecutionProvider();
if (!cuda_ep) {
GTEST_SKIP() << "CUDA EP not available";
@@ -187,6 +192,9 @@ TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) {
OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
tester.AddAttribute("num_heads", static_cast(num_heads));
tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads));
+ if (qk_norm_epsilon.has_value()) {
+ tester.AddAttribute("qk_norm_epsilon", *qk_norm_epsilon);
+ }
tester.AddInput("query", {batch_size, sequence_length, hidden_size},
std::vector(batch_size * sequence_length * hidden_size, MLFloat16(0.1f)));
@@ -208,8 +216,16 @@ TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) {
tester.AddOptionalInputEdge(); // k_scale
tester.AddOptionalInputEdge(); // v_scale
- tester.AddInput("q_norm_weight", {head_size}, std::vector(head_size, MLFloat16(1.0f)));
- tester.AddInput("k_norm_weight", {head_size}, std::vector(head_size, MLFloat16(1.0f)));
+ if (include_q_norm_weight) {
+ tester.AddInput("q_norm_weight", {q_norm_weight_size},
+ std::vector(q_norm_weight_size, MLFloat16(1.0f)));
+ } else if (include_k_norm_weight) {
+ tester.AddOptionalInputEdge();
+ }
+ if (include_k_norm_weight) {
+ tester.AddInput("k_norm_weight", {k_norm_weight_size},
+ std::vector(k_norm_weight_size, MLFloat16(1.0f)));
+ }
tester.AddOutput("output", {batch_size, sequence_length, hidden_size},
std::vector(batch_size * sequence_length * hidden_size, MLFloat16(0.0f)));
@@ -218,11 +234,50 @@ TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) {
tester.AddOutput("present_value", {batch_size, kv_num_heads, sequence_length, head_size},
std::vector(batch_size * kv_num_heads * sequence_length * head_size, MLFloat16(0.0f)));
+ if (expect == OpTester::ExpectResult::kExpectSuccess) {
+ // This is an input-contract smoke test. The dedicated QK-Norm functional tests cover numerical equivalence.
+ tester.SetOutputTolerance(1e6f);
+ }
+
std::vector> execution_providers;
- execution_providers.push_back(DefaultCudaExecutionProvider());
- tester.Run(OpTester::ExpectResult::kExpectFailure,
- "q_norm_weight / k_norm_weight inputs are not supported",
- {}, nullptr, &execution_providers);
+ execution_providers.push_back(std::move(cuda_ep));
+ tester.Run(expect, expected_message, {}, nullptr, &execution_providers);
+}
+
+// CUDA GroupQueryAttention implements the fused Q/K RMS-norm prologue inputs
+// (q_norm_weight/k_norm_weight at indices 14/15). Ensure the input contract is accepted.
+TEST(GroupQueryAttentionTest, CudaAcceptsQKNormWeightInputs) {
+ RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectSuccess, "");
+}
+
+TEST(GroupQueryAttentionTest, CudaRejectsQKNormMissingKWeight) {
+ RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure,
+ "q_norm_weight and k_norm_weight must be provided together",
+ std::nullopt,
+ /*include_q_norm_weight=*/true,
+ /*include_k_norm_weight=*/false);
+}
+
+TEST(GroupQueryAttentionTest, CudaRejectsQKNormWrongKWeightShape) {
+ RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure,
+ "k_norm_weight must be a 1D tensor of shape",
+ std::nullopt,
+ /*include_q_norm_weight=*/true,
+ /*include_k_norm_weight=*/true,
+ /*q_norm_weight_size=*/8,
+ /*k_norm_weight_size=*/9);
+}
+
+TEST(GroupQueryAttentionTest, CudaRejectsZeroQKNormEpsilon) {
+ RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure,
+ "qk_norm_epsilon must be finite and positive",
+ 0.0f);
+}
+
+TEST(GroupQueryAttentionTest, CudaRejectsNegativeQKNormEpsilon) {
+ RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure,
+ "qk_norm_epsilon must be finite and positive",
+ -1.0e-6f);
}
// Regression: negative seqlens_k wraps to huge size_t, causing GEMM OOB.
@@ -515,7 +570,7 @@ static void ApplyPerHeadRmsNormBSNH(std::vector& data,
}
// Runs GroupQueryAttention with do_rotary=1 and optional q/k norm weights.
-// If q_norm_weight/k_norm_weight are provided, this exercises the WebGPU-only
+// If q_norm_weight/k_norm_weight are provided, this exercises the EP-specific
// q/k norm input contract. CPU callers should pass nullptr for both and feed
// pre-normalized Q/K values instead.
static std::vector RunGQARotaryWithOptionalQKNorm(
diff --git a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc
index ef0f40a51c838..2ac880a5459c5 100644
--- a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc
+++ b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include
+#include
#include
#include "core/graph/node_attr_utils.h"
@@ -246,16 +247,34 @@ Status CheckUnfusedGraph(const Graph& graph) {
} // namespace
-// Helper: build the transformer registered for the WebGPU EP only (matches production).
+// Helper: build the transformer for the original WebGPU-only tests.
std::unique_ptr MakeWebGpuTransformer() {
return std::make_unique(
InlinedHashSet{kWebGpuExecutionProvider});
}
+// Helper: build the production-compatible transformer.
+std::unique_ptr MakeCudaWebGpuTransformer() {
+ return std::make_unique(
+ InlinedHashSet{kCudaExecutionProvider, kWebGpuExecutionProvider});
+}
+
TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPattern) {
auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); };
ASSERT_STATUS_OK(TestGraphTransformer(
- build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+ build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(),
+ TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph));
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesCudaAssignedQwenPattern) {
+ auto build = [](ModelTestBuilder& builder) {
+ BuildQwenQkPostNormPattern(builder, BuildOptions{});
+ for (auto& node : builder.graph_.Nodes()) {
+ node.SetExecutionProviderType(kCudaExecutionProvider);
+ }
+ };
+ ASSERT_STATUS_OK(TestGraphTransformer(
+ build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(),
TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph));
}
@@ -270,9 +289,9 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPatter
auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); };
// opset 27 is under development in ONNX 1.22 (released map-max 27 > last release 26), so strict legs
// reject this *CurrentOpset model at load; allow the unreleased opset. Remove once opset 27 ships.
- // Tracked by #28966; this is the WebGPU attention-fusion path that surfaced #28969.
+ // Tracked by #28966; this runs the CUDA+WebGPU attention-fusion path that surfaced #28969.
ASSERT_STATUS_OK(TestGraphTransformer(
- build, /*opset_version=*/current_opset, *logger_, MakeWebGpuTransformer(),
+ build, /*opset_version=*/current_opset, *logger_, MakeCudaWebGpuTransformer(),
TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph,
ModelOptions{kAllowReleasedOpsetsOnly, /*strict_shape_type_inference*/ false}));
}
@@ -344,6 +363,26 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsEpsilonM
TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
}
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNonPositiveEpsilon) {
+ BuildOptions opts;
+ opts.q_epsilon = 0.0f;
+ opts.k_epsilon = 0.0f;
+ auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); };
+ ASSERT_STATUS_OK(TestGraphTransformer(
+ build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+ TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNonFiniteEpsilon) {
+ BuildOptions opts;
+ opts.q_epsilon = std::numeric_limits::quiet_NaN();
+ opts.k_epsilon = std::numeric_limits::quiet_NaN();
+ auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); };
+ ASSERT_STATUS_OK(TestGraphTransformer(
+ build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+ TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
+}
+
TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsBadInnerReshape) {
BuildOptions opts;
opts.break_k_inner_reshape_shape = true;
@@ -363,7 +402,7 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNon1DNor
}
TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) {
- // Build the pattern but assign all nodes to CPU EP. The fusion is gated to WebGPU only,
+ // Build the pattern but assign all nodes to CPU EP. The fusion is gated to CUDA/WebGPU only,
// so the graph must remain unfused.
auto build = [](ModelTestBuilder& builder) {
BuildQwenQkPostNormPattern(builder, BuildOptions{});
@@ -372,13 +411,13 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) {
}
};
ASSERT_STATUS_OK(TestGraphTransformer(
- build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+ build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(),
TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
}
TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsJsEp) {
// JSEP does not implement the fused per-head Q/K RMSNorm prologue, so the optimizer
- // (which we now register for WebGPU only) must leave JSEP-assigned graphs alone.
+ // (which we now register for CUDA/WebGPU only) must leave JSEP-assigned graphs alone.
auto build = [](ModelTestBuilder& builder) {
BuildQwenQkPostNormPattern(builder, BuildOptions{});
for (auto& node : builder.graph_.Nodes()) {
@@ -386,7 +425,7 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsJsEp) {
}
};
ASSERT_STATUS_OK(TestGraphTransformer(
- build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+ build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(),
TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
}
diff --git a/onnxruntime/test/python/transformers/gqa_test_helper.py b/onnxruntime/test/python/transformers/gqa_test_helper.py
index d3dd86ea9bbc6..73fe45ac8fc7e 100644
--- a/onnxruntime/test/python/transformers/gqa_test_helper.py
+++ b/onnxruntime/test/python/transformers/gqa_test_helper.py
@@ -311,6 +311,8 @@ def __init__(
kv_cache_type: str = "float16",
share_kv_scale: bool = False,
has_head_sink: bool = False,
+ has_qk_norm: bool = False,
+ qk_norm_epsilon: float = 1e-6,
):
super().__init__(
"GroupQueryAttention",
@@ -343,6 +345,8 @@ def __init__(
self.v_quant_type = v_quant_type
self.share_kv_scale = share_kv_scale
self.has_head_sink = has_head_sink
+ self.has_qk_norm = has_qk_norm
+ self.qk_norm_epsilon = qk_norm_epsilon
# Determine bit width from cache type if applicable
if kv_cache_type == "int4":
self.kv_cache_bit_width = 4
@@ -363,6 +367,9 @@ def shape_dict(self):
)
if self.has_head_sink:
shapes["head_sink"] = (self.num_heads,)
+ if self.has_qk_norm:
+ shapes["q_norm_weight"] = (self.head_size,)
+ shapes["k_norm_weight"] = (self.head_size,)
# Note: We don't adjust shapes for int4 here because the parent's random_inputs
# creates float tensors first, then quantization will pack them
return shapes
@@ -378,6 +385,15 @@ def random_inputs(self):
if self.has_head_sink:
feeds["head_sink"] = torch.rand((self.num_heads,), device=self.device, dtype=self.dtype)
+ if self.has_qk_norm:
+ generator = torch.Generator(device=self.device).manual_seed(7)
+ feeds["q_norm_weight"] = (
+ 1.0 + 0.1 * torch.randn(self.head_size, generator=generator, device=self.device, dtype=torch.float32)
+ ).to(self.dtype)
+ feeds["k_norm_weight"] = (
+ 1.0 + 0.1 * torch.randn(self.head_size, generator=generator, device=self.device, dtype=torch.float32)
+ ).to(self.dtype)
+
# Generate quantized cache and scales if quantization is enabled
if self.k_quant_type != "NONE":
# Compute scales from the generated float cache
@@ -432,6 +448,8 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig):
"head_sink" if config.has_head_sink else "",
"k_scale" if config.k_quant_type != "NONE" else "",
"v_scale" if config.v_quant_type != "NONE" else "",
+ "q_norm_weight" if config.has_qk_norm else "",
+ "k_norm_weight" if config.has_qk_norm else "",
]
# Remove trailing empty strings
while node_inputs and node_inputs[-1] == "":
@@ -449,6 +467,9 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig):
"domain": "com.microsoft",
}
+ if config.has_qk_norm:
+ node_attrs["qk_norm_epsilon"] = config.qk_norm_epsilon
+
# Add quantization attributes if enabled
if config.k_quant_type != "NONE":
node_attrs["k_quant_type"] = config.k_quant_type
@@ -521,6 +542,14 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig):
if config.has_head_sink:
graph_input.append(helper.make_tensor_value_info("head_sink", float_type, list(shape_dict["head_sink"])))
+ if config.has_qk_norm:
+ graph_input.extend(
+ [
+ helper.make_tensor_value_info("q_norm_weight", float_type, list(shape_dict["q_norm_weight"])),
+ helper.make_tensor_value_info("k_norm_weight", float_type, list(shape_dict["k_norm_weight"])),
+ ]
+ )
+
# Add scale inputs for quantization
# Shape depends on quantization type:
# - PER_TENSOR: [1]
diff --git a/onnxruntime/test/python/transformers/profile_gqa.py b/onnxruntime/test/python/transformers/profile_gqa.py
index 49ce26b5126b8..394e2d6689cc9 100644
--- a/onnxruntime/test/python/transformers/profile_gqa.py
+++ b/onnxruntime/test/python/transformers/profile_gqa.py
@@ -73,6 +73,8 @@ def create_gqa_config(
has_head_sink: bool = False,
device: str = "cuda",
share_kv_scale: bool = False,
+ has_qk_norm: bool = False,
+ qk_norm_epsilon: float = 1e-6,
) -> GroupQueryAttentionConfig:
"""Create a GQA config based on the mode."""
if mode == "fp16":
@@ -118,6 +120,8 @@ def create_gqa_config(
v_quant_type=v_quant_type,
kv_cache_type=kv_cache_type,
share_kv_scale=share_kv_scale,
+ has_qk_norm=has_qk_norm,
+ qk_norm_epsilon=qk_norm_epsilon,
)
return config
@@ -158,6 +162,7 @@ def run_comparison(args):
print(f"Config: batch={args.batch_size}, seq_len={args.sequence_length}, past_seq={args.past_sequence_length}")
print(f" num_heads={args.num_heads}, kv_heads={args.kv_num_heads}, head_size={args.head_size}")
print(f" packed_qkv={args.is_packed_qkv}, rotary={not args.no_rotary}, head_sink={args.head_sink}")
+ print(f" qk_norm={args.qk_norm}, qk_norm_epsilon={args.qk_norm_epsilon}")
print(f" warmup={args.warmup}, repeat={args.repeat}")
print(f"{'=' * 70}\n")
@@ -179,10 +184,14 @@ def run_comparison(args):
do_rotary=not args.no_rotary,
has_head_sink=args.head_sink,
share_kv_scale=args.share_kv_scale,
+ has_qk_norm=args.qk_norm,
+ qk_norm_epsilon=args.qk_norm_epsilon,
)
- avg_ms = benchmark_gqa(config, warmup=args.warmup, repeat=args.repeat, mode=mode)
+ range_name = f"{mode}_qknorm" if args.qk_norm else mode
+ avg_ms = benchmark_gqa(config, warmup=args.warmup, repeat=args.repeat, mode=range_name)
results[mode] = avg_ms
- print(f" {mode.upper():6s} (dtype={config.dtype}): {avg_ms:.4f} ms")
+ suffix = "+QKNorm" if args.qk_norm else ""
+ print(f" {mode.upper() + suffix:13s} (dtype={config.dtype}): {avg_ms:.4f} ms")
# Print comparison if we have baseline
baseline = "fp16" if "fp16" in results else ("bf16" if "bf16" in results else None)
@@ -216,6 +225,8 @@ def main():
parser.add_argument("--repeat", type=int, default=100, help="Benchmark iterations")
parser.add_argument("--is-packed-qkv", action="store_true", help="Use packed QKV")
parser.add_argument("--head-sink", action="store_true", help="Add a head_sink input")
+ parser.add_argument("--qk-norm", action="store_true", help="Add q_norm_weight/k_norm_weight inputs")
+ parser.add_argument("--qk-norm-epsilon", type=float, default=1e-6, help="QK-Norm epsilon")
parser.add_argument("--no-rotary", action="store_true", help="Disable rotary embeddings")
parser.add_argument("--share-kv-scale", action="store_true", help="Share KV scale tensor for XQA")
diff --git a/onnxruntime/test/python/transformers/profile_gqa.sh b/onnxruntime/test/python/transformers/profile_gqa.sh
index 3fe10eaf35ad2..0dfb820fecbef 100644
--- a/onnxruntime/test/python/transformers/profile_gqa.sh
+++ b/onnxruntime/test/python/transformers/profile_gqa.sh
@@ -10,7 +10,9 @@
# ./profile_gqa.sh --all
# ./profile_gqa.sh --fp16 --int8
# ./profile_gqa.sh --fp16 --past-sequence-length 8192 --local-window-size 128
-# ./profile_gqa.sh --bf16 --num-heads 64 --kv-num-heads 8
+# ./profile_gqa.sh --bf16 --num-heads 64 --kv-num-heads 8 --head-size 128
+# ./profile_gqa.sh --fp16 --compare-qk-norm --past-sequence-length 2048
+# NSYS=~/cuda13.0/bin/nsys ./profile_gqa.sh --fp16 --qk-norm
# CUDA_VISIBLE_DEVICES=1 PYTHON=python3 ./profile_gqa.sh --int4
#
@@ -19,6 +21,7 @@ set -o pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PY="${PYTHON:-python}"
+NSYS="${NSYS:-}"
# Parse arguments
RUN_FP16=false
@@ -31,11 +34,20 @@ RUN_BF16=false
BATCH_SIZE=""
SEQUENCE_LENGTH=""
PAST_SEQUENCE_LENGTH=""
+MAX_SEQUENCE_LENGTH=""
PACKED_QKV=""
SHARE_KV_SCALE=""
NUM_HEADS=""
KV_NUM_HEADS=""
+HEAD_SIZE=""
LOCAL_WINDOW_SIZE=""
+HEAD_SINK=""
+NO_ROTARY=""
+QK_NORM=false
+COMPARE_QK_NORM=false
+QK_NORM_EPSILON=""
+WARMUP=5
+REPEAT=100
while [[ "$#" -gt 0 ]]; do
case $1 in
--fp16)
@@ -66,6 +78,29 @@ while [[ "$#" -gt 0 ]]; do
RUN_BF16=true
echo "==== 🚀 All runs enabled ===="
;;
+ --qk-norm)
+ QK_NORM=true
+ echo "==== QK-Norm enabled ===="
+ ;;
+ --compare-qk-norm)
+ COMPARE_QK_NORM=true
+ echo "==== Compare baseline vs QK-Norm enabled ===="
+ ;;
+ --qk-norm-epsilon)
+ QK_NORM_EPSILON="--qk-norm-epsilon $2"
+ echo "==== QK-Norm epsilon: $2 ===="
+ shift
+ ;;
+ --warmup)
+ WARMUP="$2"
+ echo "==== Warmup iterations: $2 ===="
+ shift
+ ;;
+ --repeat)
+ REPEAT="$2"
+ echo "==== Repeat iterations: $2 ===="
+ shift
+ ;;
-b|--batch-size)
BATCH_SIZE="--batch-size $2"
echo "==== Batch size: $2 ===="
@@ -81,6 +116,11 @@ while [[ "$#" -gt 0 ]]; do
echo "==== Past sequence length: $2 ===="
shift
;;
+ --max-sequence-length)
+ MAX_SEQUENCE_LENGTH="--max-sequence-length $2"
+ echo "==== Max sequence length: $2 ===="
+ shift
+ ;;
--qkv)
PACKED_QKV="--is-packed-qkv"
echo "==== Packed QKV enabled ===="
@@ -99,11 +139,24 @@ while [[ "$#" -gt 0 ]]; do
echo "==== KV Num Heads: $2 ===="
shift
;;
+ --head-size)
+ HEAD_SIZE="--head-size $2"
+ echo "==== Head size: $2 ===="
+ shift
+ ;;
-w|--local-window-size)
LOCAL_WINDOW_SIZE="--local-window-size $2"
echo "==== Local window size: $2 ===="
shift
;;
+ --head-sink)
+ HEAD_SINK="--head-sink"
+ echo "==== Head sink enabled ===="
+ ;;
+ --no-rotary)
+ NO_ROTARY="--no-rotary"
+ echo "==== Rotary disabled ===="
+ ;;
*)
echo "Unknown option: $1"
exit 1
@@ -113,10 +166,19 @@ while [[ "$#" -gt 0 ]]; do
done
# Build extra args string
-EXTRA_ARGS="${BATCH_SIZE} ${SEQUENCE_LENGTH} ${PAST_SEQUENCE_LENGTH} ${PACKED_QKV} ${SHARE_KV_SCALE} ${NUM_HEADS} ${KV_NUM_HEADS} ${LOCAL_WINDOW_SIZE}"
+EXTRA_ARGS="${BATCH_SIZE} ${SEQUENCE_LENGTH} ${PAST_SEQUENCE_LENGTH} ${MAX_SEQUENCE_LENGTH} ${PACKED_QKV} ${SHARE_KV_SCALE} ${NUM_HEADS} ${KV_NUM_HEADS} ${HEAD_SIZE} ${LOCAL_WINDOW_SIZE} ${HEAD_SINK} ${NO_ROTARY} ${QK_NORM_EPSILON}"
-if ! command -v nsys >/dev/null; then
+if [[ -z "${NSYS}" ]]; then
+ if command -v nsys >/dev/null; then
+ NSYS="$(command -v nsys)"
+ elif [[ -x "${HOME}/cuda13.0/bin/nsys" ]]; then
+ NSYS="${HOME}/cuda13.0/bin/nsys"
+ fi
+fi
+
+if [[ -z "${NSYS}" || ! -x "${NSYS}" ]]; then
echo "Error: nsys not found. Install NVIDIA Nsight Systems or add it to PATH."
+ echo " Or set NSYS=/path/to/nsys (for example NSYS=~/cuda13.0/bin/nsys)."
exit 1
fi
@@ -129,12 +191,13 @@ else
echo " Falling back to --skip-first to exclude warmup-like first calls."
fi
-# profile_one [env_var=value ...]
+# profile_one [env_var=value ...]
profile_one() {
local mode="$1"
local tag="$2"
local base="$3"
- shift 3
+ local cli_extra="$4"
+ shift 4
local env_args=()
local e
@@ -145,34 +208,56 @@ profile_one() {
echo ""
echo "---- Profiling ${mode} ----"
rm -f "${base}.nsys-rep" "${base}.sqlite"
- nsys profile -t cuda,nvtx --force-overwrite true "${env_args[@]}" -o "${base}" --export=sqlite \
- "${PY}" "${SCRIPT_DIR}/profile_gqa.py" --mode "${mode}" --warmup 5 --repeat 100 ${EXTRA_ARGS}
+ "${NSYS}" profile -t cuda,nvtx --force-overwrite true "${env_args[@]}" -o "${base}" --export=sqlite \
+ "${PY}" "${SCRIPT_DIR}/profile_gqa.py" --mode "${mode}" --warmup "${WARMUP}" --repeat "${REPEAT}" ${EXTRA_ARGS} ${cli_extra}
echo ""
echo "---- Kernel results (${mode}) ----"
+ local range_name="benchmark_${mode}"
+ if [[ "${cli_extra}" == *"--qk-norm"* ]]; then
+ range_name="benchmark_${mode}_qknorm"
+ fi
if [[ "${HAVE_NVTX}" -eq 1 ]]; then
- "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --nvtx-range "benchmark_${mode}" --tag "${tag}"
+ "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --nvtx-range "${range_name}" --tag "${tag}" \
+ --pattern "%onnxruntime%" --pattern "%cudnn%"
+ else
+ "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --skip-first "${WARMUP}" --tag "${tag}" \
+ --pattern "%onnxruntime%" --pattern "%cudnn%"
+ fi
+}
+
+profile_mode() {
+ local mode="$1"
+ local tag="$2"
+ local base="$3"
+ shift 3
+
+ if [[ "${COMPARE_QK_NORM}" == true ]]; then
+ profile_one "${mode}" "${tag}" "${base}" "" "$@"
+ profile_one "${mode}" "${tag}QK" "${base}_qknorm" "--qk-norm" "$@"
+ elif [[ "${QK_NORM}" == true ]]; then
+ profile_one "${mode}" "${tag}QK" "${base}_qknorm" "--qk-norm" "$@"
else
- "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --skip-first 5 --tag "${tag}"
+ profile_one "${mode}" "${tag}" "${base}" "" "$@"
fi
}
if [ "$RUN_FP16" = true ]; then
- profile_one fp16 Fp16 gqa_fp16
+ profile_mode fp16 Fp16 gqa_fp16
fi
if [ "$RUN_BF16" = true ]; then
- profile_one bf16 Bf16 gqa_bf16
+ profile_mode bf16 Bf16 gqa_bf16
fi
if [ "$RUN_INT8" = true ]; then
- profile_one int8 Int8 gqa_int8 ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=0
+ profile_mode int8 Int8 gqa_int8 ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=0
fi
if [ "$RUN_INT8_QUANT" = true ]; then
- profile_one int8 Int8Q gqa_int8_quant ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=1
+ profile_mode int8 Int8Q gqa_int8_quant ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=1
fi
if [ "$RUN_INT4" = true ]; then
- profile_one int4 Int4 gqa_int4
+ profile_mode int4 Int4 gqa_int4
fi
diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py
index 2ff4cdc988a42..f0f6865bb463d 100644
--- a/onnxruntime/test/python/transformers/test_gqa.py
+++ b/onnxruntime/test/python/transformers/test_gqa.py
@@ -104,6 +104,10 @@ class GQAConfig:
v_quant_type: str = "NONE"
kv_cache_bit_width: int = 0
+ # Fused per-head Q/K RMSNorm (QK-Norm) applied before RoPE. Weight shape (head_size,) shared across heads.
+ has_qk_norm: bool = False
+ qk_norm_epsilon: float = 1e-6
+
# #################################################################################################
# Rotary Embedding Implementations (CPU and CUDA)
@@ -205,6 +209,25 @@ def make_head_sink_initializer(head_sink, ort_type, num_heads):
return helper.make_tensor(name="head_sink", data_type=ort_type, dims=[num_heads], vals=raw, raw=True)
+def make_qk_norm_weights(head_size, device, torch_type, seed=7):
+ """Generate deterministic per-head Q/K RMSNorm weights of shape (head_size,)."""
+ gen = torch.Generator(device=device).manual_seed(seed)
+ q_w = (1.0 + 0.1 * torch.randn(head_size, generator=gen, device=device, dtype=torch.float32)).to(torch_type)
+ k_w = (1.0 + 0.1 * torch.randn(head_size, generator=gen, device=device, dtype=torch.float32)).to(torch_type)
+ return q_w, k_w
+
+
+def apply_qk_rmsnorm(x, weight, eps):
+ """Reference per-head RMSNorm over the last (head_size) dim, computed in float32 then cast back.
+
+ x_norm[c] = x[c] * rsqrt(mean(x^2) + eps) * weight[c]
+ """
+ dtype = x.dtype
+ xf = x.to(torch.float32)
+ inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
+ return (xf * inv_rms * weight.to(torch.float32)).to(dtype)
+
+
def create_gqa_node_and_io(
config: GQAConfig,
ort_type,
@@ -264,6 +287,8 @@ def create_gqa_node_and_io(
"k_scale"
if config.share_kv_scale and config.k_quant_type != "NONE"
else ("v_scale" if config.v_quant_type != "NONE" else ""),
+ "q_norm_weight" if config.has_qk_norm else "",
+ "k_norm_weight" if config.has_qk_norm else "",
]
# Remove trailing empty strings
@@ -280,6 +305,8 @@ def create_gqa_node_and_io(
else {}
)
+ qk_norm_attributes = {"qk_norm_epsilon": config.qk_norm_epsilon} if config.has_qk_norm else {}
+
node = helper.make_node(
op_type="GroupQueryAttention",
inputs=inputs,
@@ -294,6 +321,7 @@ def create_gqa_node_and_io(
smooth_softmax=1 if config.use_smooth_softmax else 0,
qk_output=output_qk,
**quantization_attributes,
+ **qk_norm_attributes,
domain="com.microsoft",
)
@@ -372,6 +400,10 @@ def create_gqa_node_and_io(
else:
graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads]))
+ if config.has_qk_norm:
+ graph_input.append(helper.make_tensor_value_info("q_norm_weight", ort_type, [config.head_size]))
+ graph_input.append(helper.make_tensor_value_info("k_norm_weight", ort_type, [config.head_size]))
+
# --- Graph Outputs ---
output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size]
if config.kv_cache_type == "int4":
@@ -467,6 +499,8 @@ def gqa_prompt_func(
device,
share_buffer=True,
ort_type=TensorProto.FLOAT16,
+ q_norm_weight=None,
+ k_norm_weight=None,
):
if not config.kv_cache_type:
config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16"
@@ -479,8 +513,9 @@ def gqa_prompt_func(
q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1))
if new_k is not None:
- new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1))
- new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1))
+ kv_hidden_size = config.kv_num_heads * config.head_size
+ new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, kv_hidden_size))
+ new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, kv_hidden_size))
sess_options = SessionOptions()
ort_session = InferenceSession(onnx_model_str, sess_options, providers=[resolve_cuda_plugin_ep(ep)])
@@ -537,6 +572,10 @@ def gqa_prompt_func(
if config.has_head_sink and head_sink is not None:
bind_tensor(io_binding, "head_sink", head_sink, device, ort_type)
+ if config.has_qk_norm and q_norm_weight is not None and k_norm_weight is not None:
+ bind_tensor(io_binding, "q_norm_weight", q_norm_weight, device, ort_type)
+ bind_tensor(io_binding, "k_norm_weight", k_norm_weight, device, ort_type)
+
# 6. Quantization scales
if k_scale is not None:
k_scale_ort_type = TensorProto.FLOAT
@@ -627,6 +666,8 @@ def gqa_past_func(
device,
share_buffer=True,
ort_type=TensorProto.FLOAT16,
+ q_norm_weight=None,
+ k_norm_weight=None,
):
if not config.kv_cache_type:
config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16"
@@ -641,8 +682,9 @@ def gqa_past_func(
q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1))
if new_k is not None:
- new_k = torch.reshape(new_k, (config.batch_size, config.q_sequence_length, -1))
- new_v = torch.reshape(new_v, (config.batch_size, config.q_sequence_length, -1))
+ kv_hidden_size = config.kv_num_heads * config.head_size
+ new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, kv_hidden_size))
+ new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, kv_hidden_size))
sess_options = SessionOptions()
# sess_options.log_severity_level = 0
@@ -650,7 +692,7 @@ def gqa_past_func(
io_binding = ort_session.io_binding()
# Common inputs
- total_seq_len = config.past_kv_sequence_length + config.q_sequence_length
+ total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length
# 1. Bind 'query'
bind_tensor(io_binding, "query", q, device, ort_type)
@@ -702,6 +744,10 @@ def gqa_past_func(
if config.has_head_sink and head_sink is not None and not head_sink_as_initializer:
bind_tensor(io_binding, "head_sink", head_sink, device, ort_type)
+ if config.has_qk_norm and q_norm_weight is not None and k_norm_weight is not None:
+ bind_tensor(io_binding, "q_norm_weight", q_norm_weight, device, ort_type)
+ bind_tensor(io_binding, "k_norm_weight", k_norm_weight, device, ort_type)
+
# 6. Quantization
if k_scale is not None:
k_scale_ort_type = TensorProto.FLOAT
@@ -1000,13 +1046,18 @@ def parity_check_gqa_prompt(
rotary_seqlens = torch.zeros(config.batch_size, device=device, dtype=torch.long)
cos, sin, q_ro, k_ro = None, None, q, new_k
+ q_norm_weight, k_norm_weight = None, None
+ if config.has_qk_norm:
+ q_norm_weight, k_norm_weight = make_qk_norm_weights(config.head_size, device, torch_type)
+ q_ro = apply_qk_rmsnorm(q, q_norm_weight, config.qk_norm_epsilon)
+ k_ro = apply_qk_rmsnorm(new_k, k_norm_weight, config.qk_norm_epsilon)
if config.rotary:
rotary_dim = math.floor(config.head_size / 16) * 16
angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=torch_type)
sin = torch.sin(angle).to(dtype=torch_type)
- q_ro = apply_rotary_embedding(q.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device)
- k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device)
+ q_ro = apply_rotary_embedding(q_ro.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device)
+ k_ro = apply_rotary_embedding(k_ro.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device)
position_ids, attention_bias = None, None
if config.has_position_ids:
@@ -1096,6 +1147,8 @@ def parity_check_gqa_prompt(
device=device,
share_buffer=config.share_buffer,
ort_type=ort_type,
+ q_norm_weight=q_norm_weight,
+ k_norm_weight=k_norm_weight,
)
out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
out_np = out.to(torch.float32).detach().cpu().numpy()
@@ -1277,7 +1330,7 @@ def parity_check_gqa_past(
new_k = (
torch.randn(
config.batch_size,
- config.q_sequence_length,
+ config.kv_sequence_length,
config.kv_num_heads,
config.head_size,
device=device,
@@ -1318,16 +1371,21 @@ def parity_check_gqa_past(
v_cache_ref = v_ref_dequant.clone().transpose(1, 2)
cos, sin, q_ro, k_ro = None, None, q, new_k
+ q_norm_weight, k_norm_weight = None, None
+ if config.has_qk_norm:
+ q_norm_weight, k_norm_weight = make_qk_norm_weights(config.head_size, device, torch_type)
+ q_ro = apply_qk_rmsnorm(q, q_norm_weight, config.qk_norm_epsilon)
+ k_ro = apply_qk_rmsnorm(new_k, k_norm_weight, config.qk_norm_epsilon)
if config.rotary:
rotary_dim = math.floor(config.head_size / 16) * 16
angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=torch_type)
sin = torch.sin(angle).to(dtype=torch_type)
- q_ro = apply_rotary_embedding(q.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device)
- k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device)
+ q_ro = apply_rotary_embedding(q_ro.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device)
+ k_ro = apply_rotary_embedding(k_ro.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device)
position_ids, attention_bias = None, None
- total_seq_len = config.past_kv_sequence_length + config.q_sequence_length
+ total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length
if config.has_position_ids:
position_ids = (cache_seqlens.unsqueeze(1) + torch.arange(config.q_sequence_length, device=device)).long()
if config.has_attention_bias:
@@ -1341,7 +1399,7 @@ def parity_check_gqa_past(
arange = rearrange(torch.arange(config.buffer_sequence_length, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
update_mask = torch.logical_and(
- cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.q_sequence_length
+ cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.kv_sequence_length
)
k_to_cache = k_ro
@@ -1369,7 +1427,7 @@ def parity_check_gqa_past(
k_cache_ref[update_mask] = rearrange(k_to_cache, "b s ... -> (b s) ...").to(k_cache_ref.dtype)
v_cache_ref[update_mask] = rearrange(v_to_cache, "b s ... -> (b s) ...").to(v_cache_ref.dtype)
- key_padding_mask = arange < cache_seqlens_expanded + config.q_sequence_length
+ key_padding_mask = arange < cache_seqlens_expanded + config.kv_sequence_length
out_ref, _ = attention_ref(
q=q_ro,
@@ -1405,7 +1463,7 @@ def parity_check_gqa_past(
k_ort = k_ort.contiguous()
v_ort = v_ort.contiguous()
- ort_seqlens = cache_seqlens + config.q_sequence_length - 1
+ ort_seqlens = cache_seqlens + config.kv_sequence_length - 1
out, present_k, present_v = gqa_past_func(
q=q_ort,
@@ -1426,6 +1484,8 @@ def parity_check_gqa_past(
device=device,
share_buffer=config.share_buffer,
ort_type=ort_type,
+ q_norm_weight=q_norm_weight,
+ k_norm_weight=k_norm_weight,
)
out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
out_np = out.to(torch.float32).detach().cpu().numpy()
@@ -2105,6 +2165,214 @@ def test_gqa_past_flash_attention_bf16(self, name, config):
)
+def gqa_qk_norm_test_cases(is_past: bool):
+ """Configs exercising the fused per-head Q/K RMSNorm (QK-Norm) prologue before RoPE."""
+ head_sizes = [64, 128]
+ head_groups = [(8, 2), (4, 4)]
+ rotary_opts = [(False, False), (True, False), (True, True)]
+ packed_opts = [False, True]
+ idx = 0
+ for h in head_sizes:
+ for n, n2 in head_groups:
+ for rotary, interleaved in rotary_opts:
+ if rotary and h % 16 != 0:
+ continue
+ packed = packed_opts[idx % len(packed_opts)]
+ idx += 1
+ if is_past:
+ b, s, s2 = 2, 1, 127
+ config = GQAConfig(
+ batch_size=b,
+ q_sequence_length=s,
+ kv_sequence_length=s,
+ past_kv_sequence_length=s2,
+ buffer_sequence_length=s + s2 + 8,
+ num_heads=n,
+ kv_num_heads=n2,
+ head_size=h,
+ rotary=rotary,
+ rotary_interleaved=interleaved,
+ packed=packed,
+ share_buffer=True,
+ has_qk_norm=True,
+ )
+ else:
+ b, s = 2, 64
+ config = GQAConfig(
+ batch_size=b,
+ q_sequence_length=s,
+ kv_sequence_length=s,
+ buffer_sequence_length=s + 8,
+ num_heads=n,
+ kv_num_heads=n2,
+ head_size=h,
+ rotary=rotary,
+ rotary_interleaved=interleaved,
+ packed=packed,
+ share_buffer=True,
+ has_qk_norm=True,
+ )
+ name = f"{'past' if is_past else 'prompt'}_b{b}_nh{n}_{n2}_h{h}_rot{rotary}{interleaved}_pkd{packed}"
+ yield name, config
+
+
+@unittest.skipIf(not has_cuda_device(80), "CUDA GQA QK-Norm requires Ampere or higher GPU, skipping tests.")
+class TestGQAQKNorm(unittest.TestCase):
+ def tearDown(self):
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ @parameterized.expand(gqa_qk_norm_test_cases(is_past=False))
+ def test_gqa_qk_norm_prompt(self, name, config):
+ if enable_debug_print:
+ print("-" * 20)
+ print(f"test_case: {name}\n{config}")
+
+ with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"):
+ parity_check_gqa_prompt(
+ config=config,
+ ep="CUDAExecutionProvider",
+ device="cuda",
+ torch_type=torch.float16,
+ ort_type=TensorProto.FLOAT16,
+ causal=True,
+ rtol=rtol["fp16"],
+ atol=atol["fp16"],
+ )
+
+ @parameterized.expand(gqa_qk_norm_test_cases(is_past=True))
+ def test_gqa_qk_norm_past(self, name, config):
+ if enable_debug_print:
+ print("-" * 20)
+ print(f"test_case: {name}\n{config}")
+
+ with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"):
+ parity_check_gqa_past(
+ config=config,
+ ep="CUDAExecutionProvider",
+ device="cuda",
+ torch_type=torch.float16,
+ ort_type=TensorProto.FLOAT16,
+ causal=True,
+ rtol=rtol["fp16"],
+ atol=atol["fp16"],
+ )
+
+ def test_gqa_qk_norm_past_xqa(self):
+ config = GQAConfig(
+ batch_size=2,
+ q_sequence_length=1,
+ kv_sequence_length=1,
+ past_kv_sequence_length=127,
+ buffer_sequence_length=136,
+ num_heads=32,
+ kv_num_heads=8,
+ head_size=128,
+ rotary=True,
+ rotary_interleaved=False,
+ packed=False,
+ share_buffer=True,
+ has_qk_norm=True,
+ )
+
+ with scoped_env_var("ORT_ENABLE_XQA", "1"):
+ parity_check_gqa_past(
+ config=config,
+ ep="CUDAExecutionProvider",
+ device="cuda",
+ torch_type=torch.float16,
+ ort_type=TensorProto.FLOAT16,
+ causal=True,
+ rtol=rtol["fp16"],
+ atol=atol["fp16"],
+ )
+
+ def test_gqa_qk_norm_past_shared_kv(self):
+ config = GQAConfig(
+ batch_size=2,
+ q_sequence_length=1,
+ kv_sequence_length=0,
+ past_kv_sequence_length=127,
+ buffer_sequence_length=135,
+ num_heads=8,
+ kv_num_heads=2,
+ head_size=64,
+ rotary=False,
+ packed=False,
+ share_buffer=True,
+ has_qk_norm=True,
+ )
+
+ with scoped_env_var("ORT_ENABLE_XQA", "1"):
+ parity_check_gqa_past(
+ config=config,
+ ep="CUDAExecutionProvider",
+ device="cuda",
+ torch_type=torch.float16,
+ ort_type=TensorProto.FLOAT16,
+ causal=True,
+ rtol=rtol["fp16"],
+ atol=atol["fp16"],
+ )
+
+ def test_gqa_qk_norm_past_xqa_bf16(self):
+ if not torch.cuda.is_bf16_supported():
+ self.skipTest("BFloat16 not supported on this device")
+
+ config = GQAConfig(
+ batch_size=2,
+ q_sequence_length=1,
+ kv_sequence_length=1,
+ past_kv_sequence_length=127,
+ buffer_sequence_length=136,
+ num_heads=32,
+ kv_num_heads=8,
+ head_size=128,
+ rotary=True,
+ rotary_interleaved=False,
+ packed=False,
+ share_buffer=True,
+ has_qk_norm=True,
+ )
+ config.kv_cache_type = "bfloat16"
+
+ with scoped_env_var("ORT_ENABLE_XQA", "1"):
+ parity_check_gqa_past(
+ config=config,
+ ep="CUDAExecutionProvider",
+ device="cuda",
+ torch_type=torch.bfloat16,
+ ort_type=TensorProto.BFLOAT16,
+ causal=True,
+ rtol=rtol["bf16"],
+ atol=atol["bf16"],
+ )
+
+ @parameterized.expand(gqa_qk_norm_test_cases(is_past=True))
+ def test_gqa_qk_norm_past_bf16(self, name, config):
+ if not torch.cuda.is_bf16_supported():
+ self.skipTest("BFloat16 not supported on this device")
+
+ if enable_debug_print:
+ print("-" * 20)
+ print(f"test_case: {name}\n{config}")
+
+ config.kv_cache_type = "bfloat16"
+ with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"):
+ parity_check_gqa_past(
+ config=config,
+ ep="CUDAExecutionProvider",
+ device="cuda",
+ torch_type=torch.bfloat16,
+ ort_type=TensorProto.BFLOAT16,
+ causal=True,
+ rtol=rtol["bf16"],
+ atol=atol["bf16"],
+ )
+
+
@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.")
@unittest.skipIf(not has_quantized_kv_cache(), "Quantized KV Cache is not available, skipping tests.")
class TestFlashGQABF16QuantizedKV(unittest.TestCase):