From 07a73474ffb7693e7f8b38f529dcad301ddb4439 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 22 Jun 2026 13:53:29 -0700 Subject: [PATCH] Revert "[CUDA] Enable XQA by default for FP16/BF16 GQA (#29046)" This reverts commit 6be94ded1919ab3193c9f30b0ecc5a229429a0f8. --- .../cuda/bert/group_query_attention.cc | 24 ++++++++++++++----- .../cuda/bert/group_query_attention.h | 3 ++- .../cuda/bert/xqa/xqa_loader_bf16_impl.cuh | 12 +--------- .../cuda/bert/xqa/xqa_loader_fp16_impl.cuh | 12 +--------- .../test/python/transformers/benchmark_gqa.py | 1 - .../test/python/transformers/test_gqa.py | 8 +++---- 6 files changed, 26 insertions(+), 34 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 60b806019cc57..60235024b9118 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -110,9 +110,16 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("v_quant_type", "NONE")); kv_cache_bit_width_ = static_cast(info.GetAttrOrDefault("kv_cache_bit_width", 0)); + bool is_quantized = (k_quant_type_ != KVQuantizationType::NONE || v_quant_type_ != KVQuantizationType::NONE); + // XQA enablement: + // - An explicit ORT_ENABLE_XQA overrides everything (1 = on, 0 = off, including the head_sink default-on path). + // - When unset, XQA defaults on for the quantized KV cache path and off for the non-quantized path + // (the non-quantized head_sink decode path is additionally enabled per-Run in ComputeInternal). constexpr bool kIsFp16OrBf16 = std::is_same_v || std::is_same_v; - // XQA defaults on for fp16/bf16; ORT_ENABLE_XQA=0 disables it explicitly. - enable_xqa_ = kIsFp16OrBf16 && (ParseEnvironmentVariableWithDefault("ORT_ENABLE_XQA", 1) != 0); + const int xqa_env = ParseEnvironmentVariableWithDefault("ORT_ENABLE_XQA", -1); // -1 means unset + xqa_force_disabled_ = (xqa_env == 0); + const int effective_enable_xqa = (xqa_env == -1) ? (is_quantized ? 1 : 0) : xqa_env; + enable_xqa_ = kIsFp16OrBf16 && (effective_enable_xqa != 0); kernel_options_ = this->GetAttentionKernelOptions(); @@ -391,8 +398,14 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // 7. No local window attention (global attention only). const bool use_xqa_attention_sinks = head_sink != nullptr && !is_inputs_quantized; const bool is_xqa_smooth_softmax_supported = !parameters.use_smooth_softmax || use_xqa_attention_sinks; - // XQA is enabled when enable_xqa_=true; ineligible shapes/group sizes fall back via data.use_xqa below. - if (enable_xqa_ && + // XQA is opt-in for the non-quantized path (ORT_ENABLE_XQA), but a head_sink (attention sink) input + // signals a GPT-OSS style decode model that benefits from XQA, so enable it by default in that case. + // An explicit ORT_ENABLE_XQA=0 (xqa_force_disabled_) still wins and turns XQA off entirely. + // The dtype guard mirrors enable_xqa_ (XQA only supports fp16/bf16); ineligible cases fall back below. + constexpr bool kIsFp16OrBf16 = std::is_same_v || std::is_same_v; + const bool xqa_enabled_for_run = + !xqa_force_disabled_ && (enable_xqa_ || (kIsFp16OrBf16 && use_xqa_attention_sinks)); + if (xqa_enabled_for_run && (device_prop.major >= 8) && !parameters.is_first_prompt && parameters.sequence_length == 1 && @@ -424,8 +437,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons bool is_non_quantized_supported = !is_inputs_quantized && (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) && - (group_size == 1 || group_size == 2 || group_size == 4 || group_size == 5 || - group_size == 8 || group_size == 16 || group_size == 32); + (64 % group_size == 0); data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported || is_fp8_quantized_supported); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 2fe6268e53c68..d5b980bdca290 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -38,7 +38,8 @@ class GroupQueryAttention final : public CudaKernel { bool disable_flash_attention_; bool disable_memory_efficient_attention_; bool disable_flash_decode_; - bool enable_xqa_; // True when ORT_ENABLE_XQA != 0 (default: on) and T is fp16/bf16. + bool enable_xqa_; + bool xqa_force_disabled_; // True when ORT_ENABLE_XQA=0 is explicitly set (overrides default-on paths). bool enable_cudnn_flash_attention_; // cuDNN SDPA explicitly enabled (env / sdpa_kernel) bool auto_enable_cudnn_flash_attention_; // auto-prefer cuDNN SDPA on SM>=90 when no explicit kernel pinned diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh index 3a9093f634bf0..6a84d452f1384 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh @@ -72,14 +72,6 @@ namespace HEAD_DIM_NAMESPACE { #undef GRP_SIZE #undef M_TILESIZE -#define NAMESPACE_NAME grp5_bf16 -#define GRP_SIZE 5 -#define M_TILESIZE 8 -#include "xqa_impl_gen.cuh" -#undef NAMESPACE_NAME -#undef GRP_SIZE -#undef M_TILESIZE - #define NAMESPACE_NAME grp8_bf16 #define GRP_SIZE 8 #define M_TILESIZE 8 @@ -223,8 +215,6 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( return grp2_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 4: return grp4_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); - case 5: - return grp5_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 8: return grp8_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 16: @@ -232,7 +222,7 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( case 32: return grp32_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 5, 8, 16, 32. Input has ", group_size); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh index 11b272a77732d..269b7956c0999 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh @@ -72,14 +72,6 @@ namespace HEAD_DIM_NAMESPACE { #undef GRP_SIZE #undef M_TILESIZE -#define NAMESPACE_NAME grp5_fp16 -#define GRP_SIZE 5 -#define M_TILESIZE 8 -#include "xqa_impl_gen.cuh" -#undef NAMESPACE_NAME -#undef GRP_SIZE -#undef M_TILESIZE - #define NAMESPACE_NAME grp8_fp16 #define GRP_SIZE 8 #define M_TILESIZE 8 @@ -205,8 +197,6 @@ Status LaunchXQAKernelImpl( return grp2_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 4: return grp4_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); - case 5: - return grp5_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 8: return grp8_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); case 16: @@ -214,7 +204,7 @@ Status LaunchXQAKernelImpl( case 32: return grp32_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size); default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 5, 8, 16, 32. Input has ", group_size); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size); } } diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index b6a7361b6cb84..10e7ea953a503 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -259,7 +259,6 @@ def run_performance_test( # Note: some models use bf16. # We use fp16/bf16 for all models in this test. configures = [ - (40, 128, 8, 8192, None, "Qwen3-14B"), (32, 128, 8, 8192, None, "Llama3-8B"), (64, 128, 8, 8192, None, "Llama3-70B"), (32, 128, 8, 32768, 4096, "Mistral-7B-v0.1"), diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 2ff4cdc988a42..529eae1494e94 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -2371,9 +2371,9 @@ def gqa_xqa_head_sink_test_cases(): # Non-quantized global decode with a head_sink (attention sink) input. # These configs exercise the XQA attention-sink path added for GPT-OSS style models: # seq_len=1, shared KV buffer, no softcap, no local window, head_size in {64, 128}, - # and group_size in {1, 2, 4, 5, 8, 16, 32}. + # and 64 % group_size == 0. for torch_type, ort_type in [(torch.float16, TensorProto.FLOAT16), (torch.bfloat16, TensorProto.BFLOAT16)]: - for group_size in [1, 4, 5, 8]: + for group_size in [1, 4, 8]: for head_size in [64, 128]: for rotary in [False, True]: kv_num_heads = 4 @@ -2435,8 +2435,8 @@ class TestXQAHeadSinkParity(unittest.TestCase): """Verify the non-quantized XQA attention-sink (head_sink) decode path matches the reference.""" def setUp(self): - # XQA is enabled by default for fp16/bf16 (ORT_ENABLE_XQA defaults to 1). - # Pop any override so we exercise the real default behavior. + # XQA is enabled by default when a head_sink input is present, so this path is exercised + # without ORT_ENABLE_XQA. Clear it (saving the previous value) to test the real default. self._prev_enable_xqa = os.environ.pop("ORT_ENABLE_XQA", None) def tearDown(self):