Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,16 @@ GroupQueryAttention<T, U>::GroupQueryAttention(const OpKernelInfo& info)
v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault<std::string>("v_quant_type", "NONE"));
kv_cache_bit_width_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("kv_cache_bit_width", 0));

bool is_quantized = (k_quant_type_ != KVQuantizationType::NONE || v_quant_type_ != KVQuantizationType::NONE);
// XQA enablement:
// - An explicit ORT_ENABLE_XQA overrides everything (1 = on, 0 = off, including the head_sink default-on path).
// - When unset, XQA defaults on for the quantized KV cache path and off for the non-quantized path
// (the non-quantized head_sink decode path is additionally enabled per-Run in ComputeInternal).
constexpr bool kIsFp16OrBf16 = std::is_same_v<T, MLFloat16> || std::is_same_v<T, BFloat16>;
// XQA defaults on for fp16/bf16; ORT_ENABLE_XQA=0 disables it explicitly.
enable_xqa_ = kIsFp16OrBf16 && (ParseEnvironmentVariableWithDefault<int>("ORT_ENABLE_XQA", 1) != 0);
const int xqa_env = ParseEnvironmentVariableWithDefault<int>("ORT_ENABLE_XQA", -1); // -1 means unset
xqa_force_disabled_ = (xqa_env == 0);
const int effective_enable_xqa = (xqa_env == -1) ? (is_quantized ? 1 : 0) : xqa_env;
enable_xqa_ = kIsFp16OrBf16 && (effective_enable_xqa != 0);

kernel_options_ = this->GetAttentionKernelOptions();

Expand Down Expand Up @@ -391,8 +398,14 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
// 7. No local window attention (global attention only).
const bool use_xqa_attention_sinks = head_sink != nullptr && !is_inputs_quantized;
const bool is_xqa_smooth_softmax_supported = !parameters.use_smooth_softmax || use_xqa_attention_sinks;
// XQA is enabled when enable_xqa_=true; ineligible shapes/group sizes fall back via data.use_xqa below.
if (enable_xqa_ &&
// XQA is opt-in for the non-quantized path (ORT_ENABLE_XQA), but a head_sink (attention sink) input
// signals a GPT-OSS style decode model that benefits from XQA, so enable it by default in that case.
// An explicit ORT_ENABLE_XQA=0 (xqa_force_disabled_) still wins and turns XQA off entirely.
// The dtype guard mirrors enable_xqa_ (XQA only supports fp16/bf16); ineligible cases fall back below.
constexpr bool kIsFp16OrBf16 = std::is_same_v<T, MLFloat16> || std::is_same_v<T, BFloat16>;
const bool xqa_enabled_for_run =
!xqa_force_disabled_ && (enable_xqa_ || (kIsFp16OrBf16 && use_xqa_attention_sinks));
if (xqa_enabled_for_run &&
(device_prop.major >= 8) &&
!parameters.is_first_prompt &&
parameters.sequence_length == 1 &&
Expand Down Expand Up @@ -424,8 +437,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons

bool is_non_quantized_supported = !is_inputs_quantized &&
(parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
(group_size == 1 || group_size == 2 || group_size == 4 || group_size == 5 ||
group_size == 8 || group_size == 16 || group_size == 32);
(64 % group_size == 0);

data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported || is_fp8_quantized_supported);

Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class GroupQueryAttention final : public CudaKernel {
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
bool disable_flash_decode_;
bool enable_xqa_; // True when ORT_ENABLE_XQA != 0 (default: on) and T is fp16/bf16.
bool enable_xqa_;
bool xqa_force_disabled_; // True when ORT_ENABLE_XQA=0 is explicitly set (overrides default-on paths).
bool enable_cudnn_flash_attention_; // cuDNN SDPA explicitly enabled (env / sdpa_kernel)
bool auto_enable_cudnn_flash_attention_; // auto-prefer cuDNN SDPA on SM>=90 when no explicit kernel pinned

Expand Down
12 changes: 1 addition & 11 deletions onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,6 @@ namespace HEAD_DIM_NAMESPACE {
#undef GRP_SIZE
#undef M_TILESIZE

#define NAMESPACE_NAME grp5_bf16
#define GRP_SIZE 5
#define M_TILESIZE 8
#include "xqa_impl_gen.cuh"
#undef NAMESPACE_NAME
#undef GRP_SIZE
#undef M_TILESIZE

#define NAMESPACE_NAME grp8_bf16
#define GRP_SIZE 8
#define M_TILESIZE 8
Expand Down Expand Up @@ -223,16 +215,14 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>(
return grp2_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 4:
return grp4_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 5:
return grp5_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 8:
return grp8_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 16:
return grp16_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 32:
return grp32_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 5, 8, 16, 32. Input has ", group_size);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size);
}
}

Expand Down
12 changes: 1 addition & 11 deletions onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,6 @@ namespace HEAD_DIM_NAMESPACE {
#undef GRP_SIZE
#undef M_TILESIZE

#define NAMESPACE_NAME grp5_fp16
#define GRP_SIZE 5
#define M_TILESIZE 8
#include "xqa_impl_gen.cuh"
#undef NAMESPACE_NAME
#undef GRP_SIZE
#undef M_TILESIZE

#define NAMESPACE_NAME grp8_fp16
#define GRP_SIZE 8
#define M_TILESIZE 8
Expand Down Expand Up @@ -205,16 +197,14 @@ Status LaunchXQAKernelImpl(
return grp2_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 4:
return grp4_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 5:
return grp5_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 8:
return grp8_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 16:
return grp16_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
case 32:
return grp32_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, attention_sinks, kv_cache_scale, workspace, workspace_size);
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 5, 8, 16, 32. Input has ", group_size);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size);
}
}

Expand Down
1 change: 0 additions & 1 deletion onnxruntime/test/python/transformers/benchmark_gqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ def run_performance_test(
# Note: some models use bf16.
# We use fp16/bf16 for all models in this test.
configures = [
(40, 128, 8, 8192, None, "Qwen3-14B"),
(32, 128, 8, 8192, None, "Llama3-8B"),
(64, 128, 8, 8192, None, "Llama3-70B"),
(32, 128, 8, 32768, 4096, "Mistral-7B-v0.1"),
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/test/python/transformers/test_gqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2371,9 +2371,9 @@ def gqa_xqa_head_sink_test_cases():
# Non-quantized global decode with a head_sink (attention sink) input.
# These configs exercise the XQA attention-sink path added for GPT-OSS style models:
# seq_len=1, shared KV buffer, no softcap, no local window, head_size in {64, 128},
# and group_size in {1, 2, 4, 5, 8, 16, 32}.
# and 64 % group_size == 0.
for torch_type, ort_type in [(torch.float16, TensorProto.FLOAT16), (torch.bfloat16, TensorProto.BFLOAT16)]:
for group_size in [1, 4, 5, 8]:
for group_size in [1, 4, 8]:
for head_size in [64, 128]:
for rotary in [False, True]:
kv_num_heads = 4
Expand Down Expand Up @@ -2435,8 +2435,8 @@ class TestXQAHeadSinkParity(unittest.TestCase):
"""Verify the non-quantized XQA attention-sink (head_sink) decode path matches the reference."""

def setUp(self):
# XQA is enabled by default for fp16/bf16 (ORT_ENABLE_XQA defaults to 1).
# Pop any override so we exercise the real default behavior.
# XQA is enabled by default when a head_sink input is present, so this path is exercised
# without ORT_ENABLE_XQA. Clear it (saving the previous value) to test the real default.
self._prev_enable_xqa = os.environ.pop("ORT_ENABLE_XQA", None)

def tearDown(self):
Expand Down
Loading