diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 38d101786b41a..79296736faf97 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2675,7 +2675,7 @@ This version of the operator has been available since version 1 of the 'com.micr
v_scale (optional) : T_KV_SCALE
Scale tensor for past_value.
q_norm_weight (optional) : T
-
Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the native WebGPU execution provider only; JSEP WebGPU/JS and other EPs must reject the node when this input is set.
+
Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the CUDA and native WebGPU execution providers; JSEP WebGPU/JS and other EPs must reject the node when this input is set.
k_norm_weight (optional) : T
Optional 1D tensor of shape (head_size). See q_norm_weight. Must be provided together with q_norm_weight.
diff --git a/docs/contrib_ops/cuda/gqa.md b/docs/contrib_ops/cuda/gqa.md index aeb88dbd07588..aa289a0de9a3c 100644 --- a/docs/contrib_ops/cuda/gqa.md +++ b/docs/contrib_ops/cuda/gqa.md @@ -58,6 +58,7 @@ Selected attributes: | `local_window_size` | Left window size for local attention. `-1` means global attention. | | `do_rotary` / `rotary_interleaved` | Enable RoPE and select interleaved vs. half-rotary layout. | | `smooth_softmax` | Add a smooth factor to the softmax denominator. | +| `qk_norm_epsilon` | Epsilon for the fused per-head Q/K RMSNorm (QK-Norm) prologue. Defaults to `1e-6`. | | `k_quant_type` / `v_quant_type` | KV cache quantization mode: `NONE`, `PER_TENSOR`, or `PER_CHANNEL`. | | `kv_cache_bit_width` | Bit width of the quantized KV cache (`8` or `4`). | @@ -73,6 +74,7 @@ Selected inputs (see the schema for the full list and shapes): | 7, 8 | `cos_cache`, `sin_cache` | RoPE caches. | | 11 | `head_sink` | `(num_heads,)` per-head attention sink (see §5). | | 12, 13 | `k_scale`, `v_scale` | FP32 dequant scales for the quantized KV cache. | +| 14, 15 | `q_norm_weight`, `k_norm_weight` | `(head_size,)` per-head Q/K RMSNorm weights (QK-Norm, see §3). Both must be present together. | Outputs are `output`, `present_key`, `present_value`, and optional `output_qk`. @@ -118,6 +120,31 @@ last dimension is `(head_size + 1) / 2` because two nibbles are packed per byte. - RoPE, packed-QKV unpacking, and KV-head expansion are handled internally (`PrepareQKV`) before the selected backend runs, so every backend sees a consistent layout. +### Fused QK-Norm (per-head Q/K RMSNorm) + +When the optional `q_norm_weight` (input 14) and `k_norm_weight` (input 15) tensors are provided, the +CUDA kernel applies a fused per-head RMS normalization to Q and K **before** RoPE. This matches the +QK-Norm used by **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. For each head, over the `head_size` +channels: + +$$ +x_\text{norm}[c] = x[c] \cdot \frac{1}{\sqrt{\frac{1}{H}\sum_{j} x[j]^2 + \epsilon}} \cdot w[c] +$$ + +where `H = head_size`, `w` is the per-head weight vector (`q_norm_weight` for Q, `k_norm_weight` for +K), and `epsilon = qk_norm_epsilon` (default `1e-6`). The sum of squares is reduced in FP32 for +numerical stability and the result is cast back to the operator type `T`. + +- Both weights are 1D tensors of shape `(head_size,)`, share the operator's element type `T` + (`float16`/`bfloat16`), and are **shared across all heads**. They must be supplied together — + providing only one is rejected. +- The normalization is fused into the `PrepareQKV` prologue (`UnpackRoPEAppend` for the new-KV path, + or a standalone per-head RMSNorm kernel for the shared-buffer Q-only decode case), so it composes + with packed QKV, RoPE, KV-head expansion, and the quantized KV cache. +- Because the Flash-Decoding fast path does its own RoPE/append internally and bypasses `PrepareQKV`, + it is disabled when QK-Norm is present (see §6). The non-quantized XQA decode path can still run + with QK-Norm: CUDA normalizes Q/K in the `UnpackRoPEAppend` preprocess before launching XQA. + ## 4. KV Cache and Quantization ### Layout and shared buffer @@ -197,6 +224,13 @@ order and the first eligible backend wins: The selected backend is reported in the kernel debug info as `SdpaKernel=...` when debug info is enabled (see §10). +> **QK-Norm interaction.** When `q_norm_weight` / `k_norm_weight` are present (see §3), the +> Flash-Decoding fast path is disabled so the QK-Norm prologue always runs. Non-quantized XQA decode +> remains eligible for supported shapes: the `UnpackRoPEAppend` preprocess normalizes Q/K, applies +> RoPE, appends K/V, and then XQA consumes the normalized Q and cache. Quantized-cache QK-Norm decode +> still falls back to Flash Attention (or cuDNN SDPA / MEA / Unfused) until normalized-K scale +> handling is validated for XQA. + ### 6.1 XQA Checked first. Used only for single-token global decode under the conditions detailed in §7. When @@ -402,10 +436,14 @@ CUDA parity tests live in - `TestXQAQuantizedParity` — XQA per-tensor int8 quantized decode parity. - `TestXQAHeadSinkParity` — non-quantized XQA decode parity with a `head_sink` (attention sink) input. +- `TestGQAQKNorm` — fused per-head Q/K RMSNorm (QK-Norm) parity for prompt and decode (past), FP16 and + BF16, across packed/unpacked Q/K/V and with/without RoPE. `TestXQAQuantizedParity` sets `ORT_ENABLE_XQA=1` to force the XQA path. `TestXQAHeadSinkParity` instead clears `ORT_ENABLE_XQA` to validate that XQA is enabled by default when a `head_sink` input is present. Both compare against a PyTorch reference (`attention_ref` with `smooth_softmax_ref`). +`TestGQAQKNorm` applies the RMSNorm-before-RoPE reference to Q and K and compares against the CUDA +output. ## 13. Future Work and Known Limitations @@ -414,9 +452,10 @@ popular LLMs. Listed roughly by impact. ### High impact -1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** The CUDA kernel rejects `q_norm_weight` / - `k_norm_weight`; only the WebGPU EP implements the fused prologue. Required by **Qwen3, - Gemma 2/3, OLMo2, SmolLM3**, etc., which otherwise run the normalization unfused. +1. **Fused QK-Norm (per-head Q/K RMSNorm prologue).** *Implemented.* The CUDA kernel applies the + fused per-head RMSNorm to Q and K before RoPE when `q_norm_weight` / `k_norm_weight` are provided + (see §3), matching **Qwen3, Gemma 2/3, OLMo2, SmolLM3**, etc. Remaining limitation: QK-Norm + disables Flash-Decoding, and quantized-cache QK-Norm does not yet get the XQA fast path. 2. **Sliding-window + attention-sink on the fused decode path.** XQA requires global attention (`local_window_size == -1`), so **GPT-OSS / Mistral / Gemma 2** layers that combine a sliding window with a `head_sink` fall back to Flash / Flash-Decoding instead of XQA. Unifying sliding diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index ada0c65bdd8a7..4e2d3b82ce155 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -328,14 +328,14 @@ const generatePositionIdsProgramInfo = ( }; export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => { - // q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the WebGPU-only + // q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the CUDA/native WebGPU // GroupQueryAttentionPreNormFusion optimizer pass. JSEP does not implement the fused // per-head Q/K RMS normalization prologue, so reject the node if either input is present // (regardless of rank, including scalars) rather than silently dropping the normalization. if ((context.inputs.length > 14 && context.inputs[14]) || (context.inputs.length > 15 && context.inputs[15])) { throw new Error( 'GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported. ' + - 'The per-head Q/K RMS normalization prologue is implemented only on the native WebGPU EP.', + 'The per-head Q/K RMS normalization prologue is implemented only on the CUDA and native WebGPU EPs.', ); } const params = validateInputs(context.inputs, attributes); diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index 5b7624d11c6fd..11be4be1638df 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -93,6 +93,8 @@ struct GroupQueryAttentionParameters : AttentionParameters { bool is_first_prompt; // indicates whether this is first decoding step bool rotary_interleaved; bool use_smooth_softmax; + bool use_qk_norm = false; // per-head Q/K RMSNorm (QK-Norm) prologue before RoPE (inputs 14/15) + float qk_norm_epsilon = 1e-6f; // epsilon for the QK-Norm RMSNorm float softcap; AttentionQkvFormat past_kv_format; int zeros_count; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 61ae474703213..e36bdb2de263a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -84,7 +84,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { "kv_cache_bit_width must be 0 when quantization is disabled, got ", kv_cache_bit_width_); } - // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only + // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the CUDA/WebGPU // GroupQueryAttentionPreNormFusion optimizer pass. The CPU kernel does not implement // the fused per-head Q/K RMS normalization prologue, so reject the node if either input // is present rather than silently dropping the normalization. @@ -93,7 +93,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { return ORT_MAKE_STATUS( ONNXRUNTIME, INVALID_ARGUMENT, "GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are not supported. " - "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP."); + "The per-head Q/K RMS normalization prologue is implemented only on the CUDA and WebGPU EPs."); } GroupQueryAttentionParameters parameters = {}; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 74aeaf1285e8a..f6a7b788819f2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -157,6 +157,12 @@ struct GroupQueryAttentionData { const T* sin_cache = nullptr; const T* head_sink = nullptr; + // Optional per-head Q/K RMSNorm (QK-Norm) weights, shape (head_size,), shared across heads. + // Both are non-null together (validated in the op) and trigger the fused normalization before RoPE. + const T* q_norm_weight = nullptr; + const T* k_norm_weight = nullptr; + float qk_norm_epsilon = 1e-6f; + const float* k_scale = nullptr; const float* v_scale = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 60b806019cc57..fb82bbee2a84b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -1,9 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include #include #include -#include #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_type_conversion.h" @@ -105,6 +106,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); softcap_ = info.GetAttrOrDefault("softcap", 0.0f); use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; + qk_norm_epsilon_ = info.GetAttrOrDefault("qk_norm_epsilon", 1e-6f); + ORT_ENFORCE(std::isfinite(qk_norm_epsilon_) && qk_norm_epsilon_ > 0.0f, + "GroupQueryAttention (CUDA): qk_norm_epsilon must be finite and positive."); k_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("k_quant_type", "NONE")); v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("v_quant_type", "NONE")); @@ -222,16 +226,17 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons const Tensor* k_scale = context->Input(12); const Tensor* v_scale = context->Input(13); - // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only - // GroupQueryAttentionPreNormFusion optimizer pass. The CUDA kernel does not implement - // the fused per-head Q/K RMS normalization prologue, so reject the node if either input - // is present rather than silently dropping the normalization. - if ((context->InputCount() > 14 && context->Input(14) != nullptr) || - (context->InputCount() > 15 && context->Input(15) != nullptr)) { + // q_norm_weight (input 14) / k_norm_weight (input 15) carry the per-head Q/K RMSNorm (QK-Norm) + // prologue weights, each of shape (head_size,) and shared across heads. They are populated by the + // GroupQueryAttentionPreNormFusion optimizer pass for Qwen3 / Gemma 2-3 / OLMo2 / SmolLM3 style + // models. Both must be present together; shape validation and wiring happen after CheckInputs, + // where head_size is known. + const Tensor* q_norm_weight = (context->InputCount() > 14) ? context->Input(14) : nullptr; + const Tensor* k_norm_weight = (context->InputCount() > 15) ? context->Input(15) : nullptr; + if ((q_norm_weight != nullptr) != (k_norm_weight != nullptr)) { return ORT_MAKE_STATUS( ONNXRUNTIME, INVALID_ARGUMENT, - "GroupQueryAttention (CUDA): q_norm_weight / k_norm_weight inputs are not supported. " - "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP."); + "GroupQueryAttention (CUDA): q_norm_weight and k_norm_weight must be provided together."); } if (k_quant_type_ != KVQuantizationType::NONE) { @@ -300,6 +305,27 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons head_sink, parameters)); + // Validate and enable the per-head Q/K RMSNorm (QK-Norm) prologue (inputs 14/15). Both weights + // must be 1D tensors of shape (head_size) with element type T (shared across all heads). + if (q_norm_weight != nullptr) { + const auto& q_norm_shape = q_norm_weight->Shape(); + const auto& k_norm_shape = k_norm_weight->Shape(); + if (q_norm_shape.NumDimensions() != 1 || q_norm_shape[0] != parameters.head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "q_norm_weight must be a 1D tensor of shape (head_size=", parameters.head_size, ")"); + } + if (k_norm_shape.NumDimensions() != 1 || k_norm_shape[0] != parameters.head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "k_norm_weight must be a 1D tensor of shape (head_size=", parameters.head_size, ")"); + } + if (!q_norm_weight->IsDataType() || !k_norm_weight->IsDataType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "q_norm_weight/k_norm_weight type must match GroupQueryAttention input type T."); + } + parameters.use_qk_norm = true; + } + parameters.qk_norm_epsilon = qk_norm_epsilon_; + parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr; @@ -389,9 +415,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // 5. No Softcap (XQA doesn't support softcap). // 6. Standard Softmax, or smooth softmax represented by a head_sink tensor. // 7. No local window attention (global attention only). - const bool use_xqa_attention_sinks = head_sink != nullptr && !is_inputs_quantized; - const bool is_xqa_smooth_softmax_supported = !parameters.use_smooth_softmax || use_xqa_attention_sinks; - // XQA is enabled when enable_xqa_=true; ineligible shapes/group sizes fall back via data.use_xqa below. + // QK-Norm can use XQA for the non-quantized KV-cache path: ExtremeDecoding runs the same + // UnpackRoPEAppend preprocess before XQA, so Q/K can be normalized before the XQA kernel consumes + // Q and the appended cache. Keep quantized QK-Norm off the XQA route until scale correctness is + // validated for normalized K before quantized-cache append. + const bool xqa_qk_norm_ok = !parameters.use_qk_norm || !is_inputs_quantized; + const bool use_xqa_attention_sinks = parameters.use_smooth_softmax && head_sink != nullptr && !is_inputs_quantized; + const bool xqa_smooth_softmax_ok = !parameters.use_smooth_softmax || (head_sink != nullptr && !is_inputs_quantized); if (enable_xqa_ && (device_prop.major >= 8) && !parameters.is_first_prompt && @@ -399,8 +429,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.kv_sequence_length > 0 && // Shared KV (kv_seq=0) has no new K/V to append parameters.past_present_share_buffer && parameters.softcap == 0.0f && - is_xqa_smooth_softmax_supported && - parameters.local_window_size == -1) { + parameters.local_window_size == -1 && + xqa_qk_norm_ok && + xqa_smooth_softmax_ok) { int group_size = parameters.num_heads / parameters.kv_num_heads; bool is_int8_quantized_supported = is_int8 && @@ -518,7 +549,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.kv_num_heads); data.use_flash_attention = use_flash_attention; - data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_sequence_length > 0 && parameters.past_present_share_buffer && !is_inputs_quantized; + // The fast-decode path lets the flash kernel perform RoPE and KV-append internally, bypassing + // PrepareQKV (and therefore the fused QK-Norm prologue). Disable it when q/k norm weights are + // present so the regular FlashAttention path (which normalizes via PrepareQKV) is used instead. + data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_sequence_length > 0 && parameters.past_present_share_buffer && !is_inputs_quantized && !parameters.use_qk_norm; if (use_flash_attention) { // Allocate Flash specific buffers (Softmax LSE, Accum) @@ -701,6 +735,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.head_sink = reinterpret_cast(head_sink->Data()); } + if (parameters.use_qk_norm) { + data.q_norm_weight = reinterpret_cast(q_norm_weight->Data()); + data.k_norm_weight = reinterpret_cast(k_norm_weight->Data()); + data.qk_norm_epsilon = qk_norm_epsilon_; + } + #if DUMP_TENSOR_LEVEL > 0 DUMP_TENSOR_INIT(); // Dump Scales diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 2fe6268e53c68..d219bd83474d2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -33,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel { bool do_rotary_; bool rotary_interleaved_; bool use_smooth_softmax_; + float qk_norm_epsilon_; // epsilon for the per-head Q/K RMSNorm (QK-Norm) prologue float scale_; float softcap_; bool disable_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 6a55f18bd939a..fda4ae66417fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -88,6 +88,63 @@ Status LaunchConvertHeadSinkToFloat( return CUDA_CALL(cudaGetLastError()); } +// Standalone per-head RMS normalization (QK-Norm prologue) for Q in BSNH layout. +// Each block handles one (b, s, head) vector and normalizes over head_size: +// out[c] = in[c] * rsqrt(mean(in^2) + epsilon) * weight[c] +// where weight has shape (head_size,) and is shared across heads. This is used by the shared-KV +// (kv_sequence_length == 0) path, which processes Q only; the new-KV path folds the equivalent +// normalization into UnpackRoPEAppend before RoPE. +template +__global__ void PerHeadRMSNormBSNHKernel( + T* output, const T* input, const T* weight, const int head_size, const float epsilon) { + const int s = blockIdx.x; + const int b = blockIdx.y; + const int n = blockIdx.z; + const int sequence_length = gridDim.x; + const int num_heads = gridDim.z; + const int i = threadIdx.x; + + const int64_t base = (((static_cast(b) * sequence_length + s) * num_heads) + n) * head_size; + + extern __shared__ float s_sumsq[]; + const float x = (i < head_size) ? static_cast(input[base + i]) : 0.0f; + s_sumsq[i] = x * x; + __syncthreads(); + + // Tree reduction over a power-of-two block size; padded entries (i >= head_size) are zero. + for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (i < stride) { + s_sumsq[i] += s_sumsq[i + stride]; + } + __syncthreads(); + } + + if (i < head_size) { + const float inv_rms = rsqrtf(s_sumsq[0] / static_cast(head_size) + epsilon); + output[base + i] = static_cast(x * inv_rms * static_cast(weight[i])); + } +} + +template +Status LaunchPerHeadRMSNorm( + cudaStream_t stream, T* output, const T* input, const T* weight, const float epsilon, + const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const int max_threads_per_block) { + // Round the thread count up to a power of two so the tree reduction is exact. + int tpb = 1; + while (tpb < head_size) { + tpb <<= 1; + } + if (tpb > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "head_size (", head_size, ") exceeds max threads for QK-Norm."); + } + const dim3 grid(sequence_length, batch_size, num_heads); + const dim3 block(tpb); + const size_t smem = static_cast(tpb) * sizeof(float); + PerHeadRMSNormBSNHKernel<<>>(output, input, weight, head_size, epsilon); + return CUDA_CALL(cudaGetLastError()); +} + // Internal helper to get Q, K, V pointers, handling packed input // // This function orchestrates the preparation of Q, K, and V tensors for attention kernels. @@ -116,7 +173,7 @@ Status PrepareQKV( T* q_out = reinterpret_cast(data.qkv_buffer); - if (!parameters.is_packed_qkv && !parameters.do_rotary) { + if (!parameters.is_packed_qkv && !parameters.do_rotary && !parameters.use_qk_norm) { q_out = nullptr; } @@ -153,6 +210,16 @@ Status PrepareQKV( // above has already populated the present buffer with the shared KV data. // In both cases, only Q processing (RoPE if configured) is needed here. if (kv_sequence_length == 0) { + // QK-Norm: normalize Q (BSNH) into q_out before RoPE. K is already normalized in the shared cache + // (it was normalized when first appended), so only Q needs processing on this path. + const T* q_rope_input = data.query; + if (parameters.use_qk_norm) { + ORT_RETURN_IF_ERROR((LaunchPerHeadRMSNorm( + stream, q_out, data.query, data.q_norm_weight, parameters.qk_norm_epsilon, + batch_size, sequence_length, num_heads, head_size, max_threads_per_block))); + // RoPE (if any) runs in-place on the normalized Q; the rotary kernel handles in-place safely. + q_rope_input = q_out; + } if (parameters.do_rotary && data.cos_cache != nullptr && data.sin_cache != nullptr) { // Apply RoPE to Q only using the standalone rotary embedding kernel. // Q is in BSNH format; the kernel writes rotated Q to q_out. @@ -161,7 +228,7 @@ Status PrepareQKV( const int pos_format = data.position_ids != nullptr ? 1 : 2; if constexpr (std::is_same::value) { ORT_RETURN_IF_ERROR((LaunchRotaryEmbeddingKernel( - stream, reinterpret_cast(q_out), reinterpret_cast(data.query), + stream, reinterpret_cast(q_out), reinterpret_cast(q_rope_input), data.position_ids, data.past_seq_lens, reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), batch_size, sequence_length, num_heads, head_size, parameters.rotary_dim, max_cache_length, @@ -169,7 +236,7 @@ Status PrepareQKV( max_threads_per_block, false /* is_input_bnsh_format: Q is BSNH */))); } else if constexpr (std::is_same::value) { ORT_RETURN_IF_ERROR((LaunchRotaryEmbeddingKernel( - stream, reinterpret_cast(q_out), reinterpret_cast(data.query), + stream, reinterpret_cast(q_out), reinterpret_cast(q_rope_input), data.position_ids, data.past_seq_lens, reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), batch_size, sequence_length, num_heads, head_size, parameters.rotary_dim, max_cache_length, @@ -177,7 +244,7 @@ Status PrepareQKV( max_threads_per_block, false /* is_input_bnsh_format: Q is BSNH */))); } } - // If do_rotary is false, Q is used directly from data.query (q_out == nullptr). + // If do_rotary is false and QK-Norm is off, Q is used directly from data.query (q_out == nullptr). // K/V present buffers already point to the shared past — no work needed. } else { ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( @@ -191,6 +258,7 @@ Status PrepareQKV( reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, is_cache_bnsh, parameters.k_quant_type, + data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon, stream, max_threads_per_block))); } @@ -668,6 +736,7 @@ Status ExtremeDecoding( parameters.rotary_interleaved, !past_bsnh, // is_cache_bnsh parameters.k_quant_type, + data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon, stream, device_prop.maxThreadsPerBlock))); @@ -894,6 +963,7 @@ Status DequantizeFlashAttentionFallback( parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH), parameters.k_quant_type, + data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon, stream, device_prop.maxThreadsPerBlock))); // Step 2: Dequantize Entire Cache @@ -978,6 +1048,7 @@ Status FlashAttentionAndQuantizeKV( parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, false, // BSNH for scratch KVQuantizationType::NONE, + data.q_norm_weight, data.k_norm_weight, parameters.qk_norm_epsilon, stream, max_threads_per_block))); // 2. Run Float Flash Attention diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 348dc0832d3ba..b1f979fef4169 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -76,9 +76,9 @@ struct GQABufferRequirements { const size_t v_elements = k_elements; if (use_xqa) { - if (params.do_rotary || params.is_packed_qkv) { - // XQA need scratch for rotated/unpacked Q. - // RoPE K is written directly to cache by the fused kernel. + if (params.do_rotary || params.is_packed_qkv || params.use_qk_norm) { + // XQA needs scratch for rotated/unpacked/normalized Q. + // RoPE/QK-Norm K is written directly to cache by the fused preprocess kernel. req.qkv_buffer_bytes = elem_size * q_elements; } return req; @@ -127,7 +127,9 @@ struct GQABufferRequirements { } // Unfused fallback: needs Q buffer for rotary embedding output. - if (req.qkv_buffer_bytes == 0 && (params.do_rotary || params.is_packed_qkv)) { + // QK-Norm also requires a materialized Q buffer to hold the normalized (and optionally rotated) Q, + // even when rotary is disabled and the input is not packed. + if (req.qkv_buffer_bytes == 0 && (params.do_rotary || params.is_packed_qkv || params.use_qk_norm)) { req.qkv_buffer_bytes = elem_size * q_elements; } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 5fb36e094482b..0c62aef11d53a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -61,7 +61,12 @@ __global__ void UnpackRoPEAppend( const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, - const bool per_channel) { + const bool per_channel, + // QK-Norm (per-head Q/K RMSNorm) weights of shape (head_size,), shared across heads. + // nullptr disables normalization for the corresponding head type; V heads are never normalized. + const T* q_norm_weight, + const T* k_norm_weight, + const float qk_norm_epsilon) { using LoadT = float4; constexpr int elements_per_thread = sizeof(LoadT) / sizeof(T); @@ -84,6 +89,9 @@ __global__ void UnpackRoPEAppend( const int sequence_length = gridDim.x; // Number of new tokens in this launch __shared__ T shared_head[MAX_HEAD_SIZE]; + // Per-block reduction buffer for the QK-Norm sum-of-squares. One block handles one (b, s, head), + // and blockDim.x == head_size / elements_per_thread (<= MAX_HEAD_SIZE / elements_per_thread). + __shared__ float s_qk_reduce[MAX_HEAD_SIZE / elements_per_thread]; // Determine Head Type and Offset within the packed hidden dimension [Q, K, V] enum HeadType { QUERY, @@ -139,6 +147,44 @@ __global__ void UnpackRoPEAppend( // Non-interleaved RoPE requires full head visibility to pair channels (h, h + d/2). // We use shared memory as a staging buffer to allow any thread to access its pair. const bool is_qk = (head_type == QUERY || head_type == KEY); + + // 1.5 QK-Norm: per-head RMSNorm applied BEFORE RoPE (Qwen3 / Gemma 2-3 / OLMo2 / SmolLM3). + // Each block processes a single head, so head_type (and thus norm_weight) is uniform across the + // block, which makes the __syncthreads below safe. Q heads use q_norm_weight, K heads use + // k_norm_weight, and V heads are skipped. The weight is shared across heads (indexed by channel). + const T* norm_weight = (head_type == QUERY) ? q_norm_weight : ((head_type == KEY) ? k_norm_weight : nullptr); + if (is_qk && norm_weight != nullptr) { + float partial = 0.0f; + if (valid) { +#pragma unroll + for (int i = 0; i < elements_per_thread; ++i) { + const float f = static_cast(vals[i]); + partial += f * f; + } + } + s_qk_reduce[tid] = partial; + __syncthreads(); + // blockDim.x == head_size / elements_per_thread is small (<= 64). A linear reduction is robust + // for any (possibly non-power-of-two) thread count and avoids tree-reduction edge cases. + // Reduce once in tid==0 and broadcast inv_rms via shared memory to avoid the redundant + // O(blockDim.x^2) shared reads that result from every thread summing the partials. + if (tid == 0) { + float sumsq = 0.0f; + for (int t = 0; t < blockDim.x; ++t) { + sumsq += s_qk_reduce[t]; + } + s_qk_reduce[0] = rsqrtf(sumsq / static_cast(head_size) + qk_norm_epsilon); + } + __syncthreads(); + const float inv_rms = s_qk_reduce[0]; + if (valid) { +#pragma unroll + for (int i = 0; i < elements_per_thread; ++i) { + vals[i] = static_cast(static_cast(vals[i]) * inv_rms * static_cast(norm_weight[h + i])); + } + } + } + if (valid && rotary_dim > 0 && is_qk && !interleaved) { T* shared_ptr = &shared_head[h]; *reinterpret_cast(shared_ptr) = *reinterpret_cast(vals); @@ -276,27 +322,32 @@ Status DispatchUnpackRoPEAppendHeadSize( const int num_heads, const int kv_num_heads, const int head_size, const int d, const int max_seqlen, const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, - const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel) { + const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel, + const T* q_norm_weight, const T* k_norm_weight, const float qk_norm_epsilon) { if (head_size <= 64) { UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else if (head_size <= 128) { UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else if (head_size <= 256) { UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else if (head_size <= 512) { UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size (", head_size, ") exceeds maximum supported MAX_HEAD_SIZE (512)."); } @@ -318,6 +369,7 @@ Status LaunchUnpackRoPEAppend( const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const KVQuantizationType k_quant_type, + const T* q_norm_weight, const T* k_norm_weight, const float qk_norm_epsilon, cudaStream_t stream, const int max_threads_per_block) { static_assert(std::is_same::type>::value); static_assert(std::is_same::type>::value); @@ -365,7 +417,8 @@ Status LaunchUnpackRoPEAppend( return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); } else if constexpr (std::is_same::value #ifdef USE_FP8_KV_CACHE || std::is_same::value @@ -375,14 +428,16 @@ Status LaunchUnpackRoPEAppend( return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); #ifdef USE_INT4_KV_CACHE } else if constexpr (std::is_same::value) { // INT4 quantization (packed 2 elements per byte) return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel, + q_norm_weight, k_norm_weight, qk_norm_epsilon); #endif } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported cache type U for GQA quantization."); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 5d537fa59bfab..896774fb5c8d8 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1324,7 +1324,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a " "per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap " "their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion " - "folds that pattern into this input. Currently honored by the native WebGPU execution provider only; " + "folds that pattern into this input. Currently honored by the CUDA and native WebGPU execution providers; " "JSEP WebGPU/JS and other EPs must reject the node when this input is set.", "T", OpSchema::Optional) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index a998eeacda734..56fc3ca7a9855 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -453,7 +453,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique( - InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); + InlinedHashSet{onnxruntime::kCudaExecutionProvider, + onnxruntime::kWebGpuExecutionProvider})); bool has_matmul_nbits_mlp_kernel = false; bool has_matmul_nbits_qkv_kernel = false; if (execution_providers != nullptr) { diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc index 3271c4cc9a22f..f8763306ced46 100644 --- a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc +++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc @@ -167,6 +167,9 @@ bool MatchPreNormReshapeChain(Graph& graph, } const auto* sln_eps_attr = graph_utils::GetNodeAttribute(*sln, "epsilon"); const float sln_eps = (sln_eps_attr == nullptr) ? 1e-5f : sln_eps_attr->f(); + if (!std::isfinite(sln_eps) || sln_eps <= 0.0f) { + return false; + } // Inner reshape (between projection and SLN). if (sln->InputDefs().empty() || sln->InputDefs()[0] == nullptr) { diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h index b69199bb5324d..f3cb131a0b8fb 100644 --- a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h +++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h @@ -27,12 +27,12 @@ on inputs 0 (query) and 1 (key) of an unfused GroupQueryAttention node: When matched, the six Reshape/SLN nodes are removed and the pre-norm Q and K projections feed GQA directly. The kernel is responsible for applying the RMS -norm internally (currently the WebGPU EP). +norm internally (currently the CUDA and WebGPU EPs). Only fires for execution providers passed in `compatible_execution_providers`. -At present this fusion is registered for the WebGPU EP only, because the -in-kernel norm path is currently implemented there. The CPU, CUDA, and JSEP -GroupQueryAttention kernels reject q_norm_weight / k_norm_weight inputs. +At present this fusion is registered for the CUDA and WebGPU EPs, because those +kernels implement the in-kernel norm path. CPU and JSEP GroupQueryAttention +kernels reject q_norm_weight / k_norm_weight inputs. */ class GroupQueryAttentionPreNormFusion : public GraphTransformer { public: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp index bf5a182f9662c..ae5434b2bae04 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp @@ -34,6 +34,13 @@ class DmlOperatorGroupQueryAttention : public DmlOperator, public GroupQueryAtte ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 1); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); + constexpr uint32_t qNormWeightIndex = 14; + constexpr uint32_t kNormWeightIndex = 15; + const bool hasQNormWeight = kernelCreationContext.GetInputCount() > qNormWeightIndex && kernelCreationContext.IsInputValid(qNormWeightIndex); + const bool hasKNormWeight = kernelCreationContext.GetInputCount() > kNormWeightIndex && kernelCreationContext.IsInputValid(kNormWeightIndex); + ML_CHECK_VALID_ARGUMENT(!hasQNormWeight && !hasKNormWeight, + "GroupQueryAttention (DML): q_norm_weight / k_norm_weight inputs are not supported."); + std::vector> inputIndices(inputCount); inputIndices[queryIndex] = queryIndex; inputIndices[keyIndex] = keyIndex; diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 405b6133fa474..94408f4ed24f0 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -116,7 +116,7 @@ static void RunGQASeqlensKTest( tester.Run(expect, expected_message, {}, nullptr, &execution_providers); } -// CPU GroupQueryAttention does not implement the WebGPU-only fused Q/K RMS-norm prologue +// CPU GroupQueryAttention does not implement the CUDA/WebGPU fused Q/K RMS-norm prologue // inputs (q_norm_weight/k_norm_weight at indices 14/15). Ensure we reject these explicitly. TEST(GroupQueryAttentionTest, CpuRejectsQKNormWeightInputs) { constexpr int batch_size = 1; @@ -168,9 +168,14 @@ TEST(GroupQueryAttentionTest, CpuRejectsQKNormWeightInputs) { {}, nullptr, &execution_providers); } -// CUDA GroupQueryAttention also does not implement the WebGPU-only fused Q/K RMS-norm -// prologue inputs (q_norm_weight/k_norm_weight at indices 14/15). Ensure the guard is covered. -TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) { +static void RunCudaQKNormInputContractTest( + OpTester::ExpectResult expect, + const std::string& expected_message, + std::optional qk_norm_epsilon = std::nullopt, + bool include_q_norm_weight = true, + bool include_k_norm_weight = true, + int q_norm_weight_size = 8, + int k_norm_weight_size = 8) { auto cuda_ep = DefaultCudaExecutionProvider(); if (!cuda_ep) { GTEST_SKIP() << "CUDA EP not available"; @@ -187,6 +192,9 @@ TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) { OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + if (qk_norm_epsilon.has_value()) { + tester.AddAttribute("qk_norm_epsilon", *qk_norm_epsilon); + } tester.AddInput("query", {batch_size, sequence_length, hidden_size}, std::vector(batch_size * sequence_length * hidden_size, MLFloat16(0.1f))); @@ -208,8 +216,16 @@ TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) { tester.AddOptionalInputEdge(); // k_scale tester.AddOptionalInputEdge(); // v_scale - tester.AddInput("q_norm_weight", {head_size}, std::vector(head_size, MLFloat16(1.0f))); - tester.AddInput("k_norm_weight", {head_size}, std::vector(head_size, MLFloat16(1.0f))); + if (include_q_norm_weight) { + tester.AddInput("q_norm_weight", {q_norm_weight_size}, + std::vector(q_norm_weight_size, MLFloat16(1.0f))); + } else if (include_k_norm_weight) { + tester.AddOptionalInputEdge(); + } + if (include_k_norm_weight) { + tester.AddInput("k_norm_weight", {k_norm_weight_size}, + std::vector(k_norm_weight_size, MLFloat16(1.0f))); + } tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, std::vector(batch_size * sequence_length * hidden_size, MLFloat16(0.0f))); @@ -218,11 +234,50 @@ TEST(GroupQueryAttentionTest, CudaRejectsQKNormWeightInputs) { tester.AddOutput("present_value", {batch_size, kv_num_heads, sequence_length, head_size}, std::vector(batch_size * kv_num_heads * sequence_length * head_size, MLFloat16(0.0f))); + if (expect == OpTester::ExpectResult::kExpectSuccess) { + // This is an input-contract smoke test. The dedicated QK-Norm functional tests cover numerical equivalence. + tester.SetOutputTolerance(1e6f); + } + std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectFailure, - "q_norm_weight / k_norm_weight inputs are not supported", - {}, nullptr, &execution_providers); + execution_providers.push_back(std::move(cuda_ep)); + tester.Run(expect, expected_message, {}, nullptr, &execution_providers); +} + +// CUDA GroupQueryAttention implements the fused Q/K RMS-norm prologue inputs +// (q_norm_weight/k_norm_weight at indices 14/15). Ensure the input contract is accepted. +TEST(GroupQueryAttentionTest, CudaAcceptsQKNormWeightInputs) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectSuccess, ""); +} + +TEST(GroupQueryAttentionTest, CudaRejectsQKNormMissingKWeight) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure, + "q_norm_weight and k_norm_weight must be provided together", + std::nullopt, + /*include_q_norm_weight=*/true, + /*include_k_norm_weight=*/false); +} + +TEST(GroupQueryAttentionTest, CudaRejectsQKNormWrongKWeightShape) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure, + "k_norm_weight must be a 1D tensor of shape", + std::nullopt, + /*include_q_norm_weight=*/true, + /*include_k_norm_weight=*/true, + /*q_norm_weight_size=*/8, + /*k_norm_weight_size=*/9); +} + +TEST(GroupQueryAttentionTest, CudaRejectsZeroQKNormEpsilon) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure, + "qk_norm_epsilon must be finite and positive", + 0.0f); +} + +TEST(GroupQueryAttentionTest, CudaRejectsNegativeQKNormEpsilon) { + RunCudaQKNormInputContractTest(OpTester::ExpectResult::kExpectFailure, + "qk_norm_epsilon must be finite and positive", + -1.0e-6f); } // Regression: negative seqlens_k wraps to huge size_t, causing GEMM OOB. @@ -515,7 +570,7 @@ static void ApplyPerHeadRmsNormBSNH(std::vector& data, } // Runs GroupQueryAttention with do_rotary=1 and optional q/k norm weights. -// If q_norm_weight/k_norm_weight are provided, this exercises the WebGPU-only +// If q_norm_weight/k_norm_weight are provided, this exercises the EP-specific // q/k norm input contract. CPU callers should pass nullptr for both and feed // pre-normalized Q/K values instead. static std::vector RunGQARotaryWithOptionalQKNorm( diff --git a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc index ef0f40a51c838..2ac880a5459c5 100644 --- a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc +++ b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include #include "core/graph/node_attr_utils.h" @@ -246,16 +247,34 @@ Status CheckUnfusedGraph(const Graph& graph) { } // namespace -// Helper: build the transformer registered for the WebGPU EP only (matches production). +// Helper: build the transformer for the original WebGPU-only tests. std::unique_ptr MakeWebGpuTransformer() { return std::make_unique( InlinedHashSet{kWebGpuExecutionProvider}); } +// Helper: build the production-compatible transformer. +std::unique_ptr MakeCudaWebGpuTransformer() { + return std::make_unique( + InlinedHashSet{kCudaExecutionProvider, kWebGpuExecutionProvider}); +} + TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPattern) { auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); }; ASSERT_STATUS_OK(TestGraphTransformer( - build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesCudaAssignedQwenPattern) { + auto build = [](ModelTestBuilder& builder) { + BuildQwenQkPostNormPattern(builder, BuildOptions{}); + for (auto& node : builder.graph_.Nodes()) { + node.SetExecutionProviderType(kCudaExecutionProvider); + } + }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(), TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph)); } @@ -270,9 +289,9 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPatter auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); }; // opset 27 is under development in ONNX 1.22 (released map-max 27 > last release 26), so strict legs // reject this *CurrentOpset model at load; allow the unreleased opset. Remove once opset 27 ships. - // Tracked by #28966; this is the WebGPU attention-fusion path that surfaced #28969. + // Tracked by #28966; this runs the CUDA+WebGPU attention-fusion path that surfaced #28969. ASSERT_STATUS_OK(TestGraphTransformer( - build, /*opset_version=*/current_opset, *logger_, MakeWebGpuTransformer(), + build, /*opset_version=*/current_opset, *logger_, MakeCudaWebGpuTransformer(), TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph, ModelOptions{kAllowReleasedOpsetsOnly, /*strict_shape_type_inference*/ false})); } @@ -344,6 +363,26 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsEpsilonM TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); } +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNonPositiveEpsilon) { + BuildOptions opts; + opts.q_epsilon = 0.0f; + opts.k_epsilon = 0.0f; + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNonFiniteEpsilon) { + BuildOptions opts; + opts.q_epsilon = std::numeric_limits::quiet_NaN(); + opts.k_epsilon = std::numeric_limits::quiet_NaN(); + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsBadInnerReshape) { BuildOptions opts; opts.break_k_inner_reshape_shape = true; @@ -363,7 +402,7 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNon1DNor } TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) { - // Build the pattern but assign all nodes to CPU EP. The fusion is gated to WebGPU only, + // Build the pattern but assign all nodes to CPU EP. The fusion is gated to CUDA/WebGPU only, // so the graph must remain unfused. auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); @@ -372,13 +411,13 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) { } }; ASSERT_STATUS_OK(TestGraphTransformer( - build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(), TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); } TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsJsEp) { // JSEP does not implement the fused per-head Q/K RMSNorm prologue, so the optimizer - // (which we now register for WebGPU only) must leave JSEP-assigned graphs alone. + // (which we now register for CUDA/WebGPU only) must leave JSEP-assigned graphs alone. auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); for (auto& node : builder.graph_.Nodes()) { @@ -386,7 +425,7 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsJsEp) { } }; ASSERT_STATUS_OK(TestGraphTransformer( - build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + build, /*opset_version=*/21, *logger_, MakeCudaWebGpuTransformer(), TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); } diff --git a/onnxruntime/test/python/transformers/gqa_test_helper.py b/onnxruntime/test/python/transformers/gqa_test_helper.py index d3dd86ea9bbc6..73fe45ac8fc7e 100644 --- a/onnxruntime/test/python/transformers/gqa_test_helper.py +++ b/onnxruntime/test/python/transformers/gqa_test_helper.py @@ -311,6 +311,8 @@ def __init__( kv_cache_type: str = "float16", share_kv_scale: bool = False, has_head_sink: bool = False, + has_qk_norm: bool = False, + qk_norm_epsilon: float = 1e-6, ): super().__init__( "GroupQueryAttention", @@ -343,6 +345,8 @@ def __init__( self.v_quant_type = v_quant_type self.share_kv_scale = share_kv_scale self.has_head_sink = has_head_sink + self.has_qk_norm = has_qk_norm + self.qk_norm_epsilon = qk_norm_epsilon # Determine bit width from cache type if applicable if kv_cache_type == "int4": self.kv_cache_bit_width = 4 @@ -363,6 +367,9 @@ def shape_dict(self): ) if self.has_head_sink: shapes["head_sink"] = (self.num_heads,) + if self.has_qk_norm: + shapes["q_norm_weight"] = (self.head_size,) + shapes["k_norm_weight"] = (self.head_size,) # Note: We don't adjust shapes for int4 here because the parent's random_inputs # creates float tensors first, then quantization will pack them return shapes @@ -378,6 +385,15 @@ def random_inputs(self): if self.has_head_sink: feeds["head_sink"] = torch.rand((self.num_heads,), device=self.device, dtype=self.dtype) + if self.has_qk_norm: + generator = torch.Generator(device=self.device).manual_seed(7) + feeds["q_norm_weight"] = ( + 1.0 + 0.1 * torch.randn(self.head_size, generator=generator, device=self.device, dtype=torch.float32) + ).to(self.dtype) + feeds["k_norm_weight"] = ( + 1.0 + 0.1 * torch.randn(self.head_size, generator=generator, device=self.device, dtype=torch.float32) + ).to(self.dtype) + # Generate quantized cache and scales if quantization is enabled if self.k_quant_type != "NONE": # Compute scales from the generated float cache @@ -432,6 +448,8 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): "head_sink" if config.has_head_sink else "", "k_scale" if config.k_quant_type != "NONE" else "", "v_scale" if config.v_quant_type != "NONE" else "", + "q_norm_weight" if config.has_qk_norm else "", + "k_norm_weight" if config.has_qk_norm else "", ] # Remove trailing empty strings while node_inputs and node_inputs[-1] == "": @@ -449,6 +467,9 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): "domain": "com.microsoft", } + if config.has_qk_norm: + node_attrs["qk_norm_epsilon"] = config.qk_norm_epsilon + # Add quantization attributes if enabled if config.k_quant_type != "NONE": node_attrs["k_quant_type"] = config.k_quant_type @@ -521,6 +542,14 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): if config.has_head_sink: graph_input.append(helper.make_tensor_value_info("head_sink", float_type, list(shape_dict["head_sink"]))) + if config.has_qk_norm: + graph_input.extend( + [ + helper.make_tensor_value_info("q_norm_weight", float_type, list(shape_dict["q_norm_weight"])), + helper.make_tensor_value_info("k_norm_weight", float_type, list(shape_dict["k_norm_weight"])), + ] + ) + # Add scale inputs for quantization # Shape depends on quantization type: # - PER_TENSOR: [1] diff --git a/onnxruntime/test/python/transformers/profile_gqa.py b/onnxruntime/test/python/transformers/profile_gqa.py index 49ce26b5126b8..394e2d6689cc9 100644 --- a/onnxruntime/test/python/transformers/profile_gqa.py +++ b/onnxruntime/test/python/transformers/profile_gqa.py @@ -73,6 +73,8 @@ def create_gqa_config( has_head_sink: bool = False, device: str = "cuda", share_kv_scale: bool = False, + has_qk_norm: bool = False, + qk_norm_epsilon: float = 1e-6, ) -> GroupQueryAttentionConfig: """Create a GQA config based on the mode.""" if mode == "fp16": @@ -118,6 +120,8 @@ def create_gqa_config( v_quant_type=v_quant_type, kv_cache_type=kv_cache_type, share_kv_scale=share_kv_scale, + has_qk_norm=has_qk_norm, + qk_norm_epsilon=qk_norm_epsilon, ) return config @@ -158,6 +162,7 @@ def run_comparison(args): print(f"Config: batch={args.batch_size}, seq_len={args.sequence_length}, past_seq={args.past_sequence_length}") print(f" num_heads={args.num_heads}, kv_heads={args.kv_num_heads}, head_size={args.head_size}") print(f" packed_qkv={args.is_packed_qkv}, rotary={not args.no_rotary}, head_sink={args.head_sink}") + print(f" qk_norm={args.qk_norm}, qk_norm_epsilon={args.qk_norm_epsilon}") print(f" warmup={args.warmup}, repeat={args.repeat}") print(f"{'=' * 70}\n") @@ -179,10 +184,14 @@ def run_comparison(args): do_rotary=not args.no_rotary, has_head_sink=args.head_sink, share_kv_scale=args.share_kv_scale, + has_qk_norm=args.qk_norm, + qk_norm_epsilon=args.qk_norm_epsilon, ) - avg_ms = benchmark_gqa(config, warmup=args.warmup, repeat=args.repeat, mode=mode) + range_name = f"{mode}_qknorm" if args.qk_norm else mode + avg_ms = benchmark_gqa(config, warmup=args.warmup, repeat=args.repeat, mode=range_name) results[mode] = avg_ms - print(f" {mode.upper():6s} (dtype={config.dtype}): {avg_ms:.4f} ms") + suffix = "+QKNorm" if args.qk_norm else "" + print(f" {mode.upper() + suffix:13s} (dtype={config.dtype}): {avg_ms:.4f} ms") # Print comparison if we have baseline baseline = "fp16" if "fp16" in results else ("bf16" if "bf16" in results else None) @@ -216,6 +225,8 @@ def main(): parser.add_argument("--repeat", type=int, default=100, help="Benchmark iterations") parser.add_argument("--is-packed-qkv", action="store_true", help="Use packed QKV") parser.add_argument("--head-sink", action="store_true", help="Add a head_sink input") + parser.add_argument("--qk-norm", action="store_true", help="Add q_norm_weight/k_norm_weight inputs") + parser.add_argument("--qk-norm-epsilon", type=float, default=1e-6, help="QK-Norm epsilon") parser.add_argument("--no-rotary", action="store_true", help="Disable rotary embeddings") parser.add_argument("--share-kv-scale", action="store_true", help="Share KV scale tensor for XQA") diff --git a/onnxruntime/test/python/transformers/profile_gqa.sh b/onnxruntime/test/python/transformers/profile_gqa.sh index 3fe10eaf35ad2..0dfb820fecbef 100644 --- a/onnxruntime/test/python/transformers/profile_gqa.sh +++ b/onnxruntime/test/python/transformers/profile_gqa.sh @@ -10,7 +10,9 @@ # ./profile_gqa.sh --all # ./profile_gqa.sh --fp16 --int8 # ./profile_gqa.sh --fp16 --past-sequence-length 8192 --local-window-size 128 -# ./profile_gqa.sh --bf16 --num-heads 64 --kv-num-heads 8 +# ./profile_gqa.sh --bf16 --num-heads 64 --kv-num-heads 8 --head-size 128 +# ./profile_gqa.sh --fp16 --compare-qk-norm --past-sequence-length 2048 +# NSYS=~/cuda13.0/bin/nsys ./profile_gqa.sh --fp16 --qk-norm # CUDA_VISIBLE_DEVICES=1 PYTHON=python3 ./profile_gqa.sh --int4 # @@ -19,6 +21,7 @@ set -o pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PY="${PYTHON:-python}" +NSYS="${NSYS:-}" # Parse arguments RUN_FP16=false @@ -31,11 +34,20 @@ RUN_BF16=false BATCH_SIZE="" SEQUENCE_LENGTH="" PAST_SEQUENCE_LENGTH="" +MAX_SEQUENCE_LENGTH="" PACKED_QKV="" SHARE_KV_SCALE="" NUM_HEADS="" KV_NUM_HEADS="" +HEAD_SIZE="" LOCAL_WINDOW_SIZE="" +HEAD_SINK="" +NO_ROTARY="" +QK_NORM=false +COMPARE_QK_NORM=false +QK_NORM_EPSILON="" +WARMUP=5 +REPEAT=100 while [[ "$#" -gt 0 ]]; do case $1 in --fp16) @@ -66,6 +78,29 @@ while [[ "$#" -gt 0 ]]; do RUN_BF16=true echo "==== 🚀 All runs enabled ====" ;; + --qk-norm) + QK_NORM=true + echo "==== QK-Norm enabled ====" + ;; + --compare-qk-norm) + COMPARE_QK_NORM=true + echo "==== Compare baseline vs QK-Norm enabled ====" + ;; + --qk-norm-epsilon) + QK_NORM_EPSILON="--qk-norm-epsilon $2" + echo "==== QK-Norm epsilon: $2 ====" + shift + ;; + --warmup) + WARMUP="$2" + echo "==== Warmup iterations: $2 ====" + shift + ;; + --repeat) + REPEAT="$2" + echo "==== Repeat iterations: $2 ====" + shift + ;; -b|--batch-size) BATCH_SIZE="--batch-size $2" echo "==== Batch size: $2 ====" @@ -81,6 +116,11 @@ while [[ "$#" -gt 0 ]]; do echo "==== Past sequence length: $2 ====" shift ;; + --max-sequence-length) + MAX_SEQUENCE_LENGTH="--max-sequence-length $2" + echo "==== Max sequence length: $2 ====" + shift + ;; --qkv) PACKED_QKV="--is-packed-qkv" echo "==== Packed QKV enabled ====" @@ -99,11 +139,24 @@ while [[ "$#" -gt 0 ]]; do echo "==== KV Num Heads: $2 ====" shift ;; + --head-size) + HEAD_SIZE="--head-size $2" + echo "==== Head size: $2 ====" + shift + ;; -w|--local-window-size) LOCAL_WINDOW_SIZE="--local-window-size $2" echo "==== Local window size: $2 ====" shift ;; + --head-sink) + HEAD_SINK="--head-sink" + echo "==== Head sink enabled ====" + ;; + --no-rotary) + NO_ROTARY="--no-rotary" + echo "==== Rotary disabled ====" + ;; *) echo "Unknown option: $1" exit 1 @@ -113,10 +166,19 @@ while [[ "$#" -gt 0 ]]; do done # Build extra args string -EXTRA_ARGS="${BATCH_SIZE} ${SEQUENCE_LENGTH} ${PAST_SEQUENCE_LENGTH} ${PACKED_QKV} ${SHARE_KV_SCALE} ${NUM_HEADS} ${KV_NUM_HEADS} ${LOCAL_WINDOW_SIZE}" +EXTRA_ARGS="${BATCH_SIZE} ${SEQUENCE_LENGTH} ${PAST_SEQUENCE_LENGTH} ${MAX_SEQUENCE_LENGTH} ${PACKED_QKV} ${SHARE_KV_SCALE} ${NUM_HEADS} ${KV_NUM_HEADS} ${HEAD_SIZE} ${LOCAL_WINDOW_SIZE} ${HEAD_SINK} ${NO_ROTARY} ${QK_NORM_EPSILON}" -if ! command -v nsys >/dev/null; then +if [[ -z "${NSYS}" ]]; then + if command -v nsys >/dev/null; then + NSYS="$(command -v nsys)" + elif [[ -x "${HOME}/cuda13.0/bin/nsys" ]]; then + NSYS="${HOME}/cuda13.0/bin/nsys" + fi +fi + +if [[ -z "${NSYS}" || ! -x "${NSYS}" ]]; then echo "Error: nsys not found. Install NVIDIA Nsight Systems or add it to PATH." + echo " Or set NSYS=/path/to/nsys (for example NSYS=~/cuda13.0/bin/nsys)." exit 1 fi @@ -129,12 +191,13 @@ else echo " Falling back to --skip-first to exclude warmup-like first calls." fi -# profile_one [env_var=value ...] +# profile_one [env_var=value ...] profile_one() { local mode="$1" local tag="$2" local base="$3" - shift 3 + local cli_extra="$4" + shift 4 local env_args=() local e @@ -145,34 +208,56 @@ profile_one() { echo "" echo "---- Profiling ${mode} ----" rm -f "${base}.nsys-rep" "${base}.sqlite" - nsys profile -t cuda,nvtx --force-overwrite true "${env_args[@]}" -o "${base}" --export=sqlite \ - "${PY}" "${SCRIPT_DIR}/profile_gqa.py" --mode "${mode}" --warmup 5 --repeat 100 ${EXTRA_ARGS} + "${NSYS}" profile -t cuda,nvtx --force-overwrite true "${env_args[@]}" -o "${base}" --export=sqlite \ + "${PY}" "${SCRIPT_DIR}/profile_gqa.py" --mode "${mode}" --warmup "${WARMUP}" --repeat "${REPEAT}" ${EXTRA_ARGS} ${cli_extra} echo "" echo "---- Kernel results (${mode}) ----" + local range_name="benchmark_${mode}" + if [[ "${cli_extra}" == *"--qk-norm"* ]]; then + range_name="benchmark_${mode}_qknorm" + fi if [[ "${HAVE_NVTX}" -eq 1 ]]; then - "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --nvtx-range "benchmark_${mode}" --tag "${tag}" + "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --nvtx-range "${range_name}" --tag "${tag}" \ + --pattern "%onnxruntime%" --pattern "%cudnn%" + else + "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --skip-first "${WARMUP}" --tag "${tag}" \ + --pattern "%onnxruntime%" --pattern "%cudnn%" + fi +} + +profile_mode() { + local mode="$1" + local tag="$2" + local base="$3" + shift 3 + + if [[ "${COMPARE_QK_NORM}" == true ]]; then + profile_one "${mode}" "${tag}" "${base}" "" "$@" + profile_one "${mode}" "${tag}QK" "${base}_qknorm" "--qk-norm" "$@" + elif [[ "${QK_NORM}" == true ]]; then + profile_one "${mode}" "${tag}QK" "${base}_qknorm" "--qk-norm" "$@" else - "${PY}" "${SCRIPT_DIR}/parse_nsys.py" "${base}.sqlite" --skip-first 5 --tag "${tag}" + profile_one "${mode}" "${tag}" "${base}" "" "$@" fi } if [ "$RUN_FP16" = true ]; then - profile_one fp16 Fp16 gqa_fp16 + profile_mode fp16 Fp16 gqa_fp16 fi if [ "$RUN_BF16" = true ]; then - profile_one bf16 Bf16 gqa_bf16 + profile_mode bf16 Bf16 gqa_bf16 fi if [ "$RUN_INT8" = true ]; then - profile_one int8 Int8 gqa_int8 ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=0 + profile_mode int8 Int8 gqa_int8 ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=0 fi if [ "$RUN_INT8_QUANT" = true ]; then - profile_one int8 Int8Q gqa_int8_quant ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=1 + profile_mode int8 Int8Q gqa_int8_quant ORT_FLASH_ATTENTION_QUERY_DYNAMIC_QUANT=1 fi if [ "$RUN_INT4" = true ]; then - profile_one int4 Int4 gqa_int4 + profile_mode int4 Int4 gqa_int4 fi diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 2ff4cdc988a42..f0f6865bb463d 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -104,6 +104,10 @@ class GQAConfig: v_quant_type: str = "NONE" kv_cache_bit_width: int = 0 + # Fused per-head Q/K RMSNorm (QK-Norm) applied before RoPE. Weight shape (head_size,) shared across heads. + has_qk_norm: bool = False + qk_norm_epsilon: float = 1e-6 + # ################################################################################################# # Rotary Embedding Implementations (CPU and CUDA) @@ -205,6 +209,25 @@ def make_head_sink_initializer(head_sink, ort_type, num_heads): return helper.make_tensor(name="head_sink", data_type=ort_type, dims=[num_heads], vals=raw, raw=True) +def make_qk_norm_weights(head_size, device, torch_type, seed=7): + """Generate deterministic per-head Q/K RMSNorm weights of shape (head_size,).""" + gen = torch.Generator(device=device).manual_seed(seed) + q_w = (1.0 + 0.1 * torch.randn(head_size, generator=gen, device=device, dtype=torch.float32)).to(torch_type) + k_w = (1.0 + 0.1 * torch.randn(head_size, generator=gen, device=device, dtype=torch.float32)).to(torch_type) + return q_w, k_w + + +def apply_qk_rmsnorm(x, weight, eps): + """Reference per-head RMSNorm over the last (head_size) dim, computed in float32 then cast back. + + x_norm[c] = x[c] * rsqrt(mean(x^2) + eps) * weight[c] + """ + dtype = x.dtype + xf = x.to(torch.float32) + inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps) + return (xf * inv_rms * weight.to(torch.float32)).to(dtype) + + def create_gqa_node_and_io( config: GQAConfig, ort_type, @@ -264,6 +287,8 @@ def create_gqa_node_and_io( "k_scale" if config.share_kv_scale and config.k_quant_type != "NONE" else ("v_scale" if config.v_quant_type != "NONE" else ""), + "q_norm_weight" if config.has_qk_norm else "", + "k_norm_weight" if config.has_qk_norm else "", ] # Remove trailing empty strings @@ -280,6 +305,8 @@ def create_gqa_node_and_io( else {} ) + qk_norm_attributes = {"qk_norm_epsilon": config.qk_norm_epsilon} if config.has_qk_norm else {} + node = helper.make_node( op_type="GroupQueryAttention", inputs=inputs, @@ -294,6 +321,7 @@ def create_gqa_node_and_io( smooth_softmax=1 if config.use_smooth_softmax else 0, qk_output=output_qk, **quantization_attributes, + **qk_norm_attributes, domain="com.microsoft", ) @@ -372,6 +400,10 @@ def create_gqa_node_and_io( else: graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) + if config.has_qk_norm: + graph_input.append(helper.make_tensor_value_info("q_norm_weight", ort_type, [config.head_size])) + graph_input.append(helper.make_tensor_value_info("k_norm_weight", ort_type, [config.head_size])) + # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] if config.kv_cache_type == "int4": @@ -467,6 +499,8 @@ def gqa_prompt_func( device, share_buffer=True, ort_type=TensorProto.FLOAT16, + q_norm_weight=None, + k_norm_weight=None, ): if not config.kv_cache_type: config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" @@ -479,8 +513,9 @@ def gqa_prompt_func( q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) if new_k is not None: - new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + kv_hidden_size = config.kv_num_heads * config.head_size + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, kv_hidden_size)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, kv_hidden_size)) sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=[resolve_cuda_plugin_ep(ep)]) @@ -537,6 +572,10 @@ def gqa_prompt_func( if config.has_head_sink and head_sink is not None: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) + if config.has_qk_norm and q_norm_weight is not None and k_norm_weight is not None: + bind_tensor(io_binding, "q_norm_weight", q_norm_weight, device, ort_type) + bind_tensor(io_binding, "k_norm_weight", k_norm_weight, device, ort_type) + # 6. Quantization scales if k_scale is not None: k_scale_ort_type = TensorProto.FLOAT @@ -627,6 +666,8 @@ def gqa_past_func( device, share_buffer=True, ort_type=TensorProto.FLOAT16, + q_norm_weight=None, + k_norm_weight=None, ): if not config.kv_cache_type: config.kv_cache_type = "float16" if ort_type == TensorProto.FLOAT16 else "bfloat16" @@ -641,8 +682,9 @@ def gqa_past_func( q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) if new_k is not None: - new_k = torch.reshape(new_k, (config.batch_size, config.q_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.q_sequence_length, -1)) + kv_hidden_size = config.kv_num_heads * config.head_size + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, kv_hidden_size)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, kv_hidden_size)) sess_options = SessionOptions() # sess_options.log_severity_level = 0 @@ -650,7 +692,7 @@ def gqa_past_func( io_binding = ort_session.io_binding() # Common inputs - total_seq_len = config.past_kv_sequence_length + config.q_sequence_length + total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length # 1. Bind 'query' bind_tensor(io_binding, "query", q, device, ort_type) @@ -702,6 +744,10 @@ def gqa_past_func( if config.has_head_sink and head_sink is not None and not head_sink_as_initializer: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) + if config.has_qk_norm and q_norm_weight is not None and k_norm_weight is not None: + bind_tensor(io_binding, "q_norm_weight", q_norm_weight, device, ort_type) + bind_tensor(io_binding, "k_norm_weight", k_norm_weight, device, ort_type) + # 6. Quantization if k_scale is not None: k_scale_ort_type = TensorProto.FLOAT @@ -1000,13 +1046,18 @@ def parity_check_gqa_prompt( rotary_seqlens = torch.zeros(config.batch_size, device=device, dtype=torch.long) cos, sin, q_ro, k_ro = None, None, q, new_k + q_norm_weight, k_norm_weight = None, None + if config.has_qk_norm: + q_norm_weight, k_norm_weight = make_qk_norm_weights(config.head_size, device, torch_type) + q_ro = apply_qk_rmsnorm(q, q_norm_weight, config.qk_norm_epsilon) + k_ro = apply_qk_rmsnorm(new_k, k_norm_weight, config.qk_norm_epsilon) if config.rotary: rotary_dim = math.floor(config.head_size / 16) * 16 angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi cos = torch.cos(angle).to(dtype=torch_type) sin = torch.sin(angle).to(dtype=torch_type) - q_ro = apply_rotary_embedding(q.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) - k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) + q_ro = apply_rotary_embedding(q_ro.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) + k_ro = apply_rotary_embedding(k_ro.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) position_ids, attention_bias = None, None if config.has_position_ids: @@ -1096,6 +1147,8 @@ def parity_check_gqa_prompt( device=device, share_buffer=config.share_buffer, ort_type=ort_type, + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() @@ -1277,7 +1330,7 @@ def parity_check_gqa_past( new_k = ( torch.randn( config.batch_size, - config.q_sequence_length, + config.kv_sequence_length, config.kv_num_heads, config.head_size, device=device, @@ -1318,16 +1371,21 @@ def parity_check_gqa_past( v_cache_ref = v_ref_dequant.clone().transpose(1, 2) cos, sin, q_ro, k_ro = None, None, q, new_k + q_norm_weight, k_norm_weight = None, None + if config.has_qk_norm: + q_norm_weight, k_norm_weight = make_qk_norm_weights(config.head_size, device, torch_type) + q_ro = apply_qk_rmsnorm(q, q_norm_weight, config.qk_norm_epsilon) + k_ro = apply_qk_rmsnorm(new_k, k_norm_weight, config.qk_norm_epsilon) if config.rotary: rotary_dim = math.floor(config.head_size / 16) * 16 angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi cos = torch.cos(angle).to(dtype=torch_type) sin = torch.sin(angle).to(dtype=torch_type) - q_ro = apply_rotary_embedding(q.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) - k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) + q_ro = apply_rotary_embedding(q_ro.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) + k_ro = apply_rotary_embedding(k_ro.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) position_ids, attention_bias = None, None - total_seq_len = config.past_kv_sequence_length + config.q_sequence_length + total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length if config.has_position_ids: position_ids = (cache_seqlens.unsqueeze(1) + torch.arange(config.q_sequence_length, device=device)).long() if config.has_attention_bias: @@ -1341,7 +1399,7 @@ def parity_check_gqa_past( arange = rearrange(torch.arange(config.buffer_sequence_length, device=device), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.q_sequence_length + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.kv_sequence_length ) k_to_cache = k_ro @@ -1369,7 +1427,7 @@ def parity_check_gqa_past( k_cache_ref[update_mask] = rearrange(k_to_cache, "b s ... -> (b s) ...").to(k_cache_ref.dtype) v_cache_ref[update_mask] = rearrange(v_to_cache, "b s ... -> (b s) ...").to(v_cache_ref.dtype) - key_padding_mask = arange < cache_seqlens_expanded + config.q_sequence_length + key_padding_mask = arange < cache_seqlens_expanded + config.kv_sequence_length out_ref, _ = attention_ref( q=q_ro, @@ -1405,7 +1463,7 @@ def parity_check_gqa_past( k_ort = k_ort.contiguous() v_ort = v_ort.contiguous() - ort_seqlens = cache_seqlens + config.q_sequence_length - 1 + ort_seqlens = cache_seqlens + config.kv_sequence_length - 1 out, present_k, present_v = gqa_past_func( q=q_ort, @@ -1426,6 +1484,8 @@ def parity_check_gqa_past( device=device, share_buffer=config.share_buffer, ort_type=ort_type, + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() @@ -2105,6 +2165,214 @@ def test_gqa_past_flash_attention_bf16(self, name, config): ) +def gqa_qk_norm_test_cases(is_past: bool): + """Configs exercising the fused per-head Q/K RMSNorm (QK-Norm) prologue before RoPE.""" + head_sizes = [64, 128] + head_groups = [(8, 2), (4, 4)] + rotary_opts = [(False, False), (True, False), (True, True)] + packed_opts = [False, True] + idx = 0 + for h in head_sizes: + for n, n2 in head_groups: + for rotary, interleaved in rotary_opts: + if rotary and h % 16 != 0: + continue + packed = packed_opts[idx % len(packed_opts)] + idx += 1 + if is_past: + b, s, s2 = 2, 1, 127 + config = GQAConfig( + batch_size=b, + q_sequence_length=s, + kv_sequence_length=s, + past_kv_sequence_length=s2, + buffer_sequence_length=s + s2 + 8, + num_heads=n, + kv_num_heads=n2, + head_size=h, + rotary=rotary, + rotary_interleaved=interleaved, + packed=packed, + share_buffer=True, + has_qk_norm=True, + ) + else: + b, s = 2, 64 + config = GQAConfig( + batch_size=b, + q_sequence_length=s, + kv_sequence_length=s, + buffer_sequence_length=s + 8, + num_heads=n, + kv_num_heads=n2, + head_size=h, + rotary=rotary, + rotary_interleaved=interleaved, + packed=packed, + share_buffer=True, + has_qk_norm=True, + ) + name = f"{'past' if is_past else 'prompt'}_b{b}_nh{n}_{n2}_h{h}_rot{rotary}{interleaved}_pkd{packed}" + yield name, config + + +@unittest.skipIf(not has_cuda_device(80), "CUDA GQA QK-Norm requires Ampere or higher GPU, skipping tests.") +class TestGQAQKNorm(unittest.TestCase): + def tearDown(self): + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + + @parameterized.expand(gqa_qk_norm_test_cases(is_past=False)) + def test_gqa_qk_norm_prompt(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @parameterized.expand(gqa_qk_norm_test_cases(is_past=True)) + def test_gqa_qk_norm_past(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_gqa_qk_norm_past_xqa(self): + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=127, + buffer_sequence_length=136, + num_heads=32, + kv_num_heads=8, + head_size=128, + rotary=True, + rotary_interleaved=False, + packed=False, + share_buffer=True, + has_qk_norm=True, + ) + + with scoped_env_var("ORT_ENABLE_XQA", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_gqa_qk_norm_past_shared_kv(self): + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=0, + past_kv_sequence_length=127, + buffer_sequence_length=135, + num_heads=8, + kv_num_heads=2, + head_size=64, + rotary=False, + packed=False, + share_buffer=True, + has_qk_norm=True, + ) + + with scoped_env_var("ORT_ENABLE_XQA", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_gqa_qk_norm_past_xqa_bf16(self): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=127, + buffer_sequence_length=136, + num_heads=32, + kv_num_heads=8, + head_size=128, + rotary=True, + rotary_interleaved=False, + packed=False, + share_buffer=True, + has_qk_norm=True, + ) + config.kv_cache_type = "bfloat16" + + with scoped_env_var("ORT_ENABLE_XQA", "1"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + @parameterized.expand(gqa_qk_norm_test_cases(is_past=True)) + def test_gqa_qk_norm_past_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + + config.kv_cache_type = "bfloat16" + with scoped_env_var("ORT_DISABLE_FLASH_ATTENTION", "0"): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") @unittest.skipIf(not has_quantized_kv_cache(), "Quantized KV Cache is not available, skipping tests.") class TestFlashGQABF16QuantizedKV(unittest.TestCase):