From b2613d9a81ae3404fc55f437ddc8f3deef83d91e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 9 Jun 2026 19:35:40 +0000 Subject: [PATCH 1/5] feat(cpu): add flash attention for non-quantized GQA Add an FP32 tiled online-softmax flash attention kernel for the CPU GroupQueryAttention contrib op, mirroring the existing quantized-KV flash path. Avoids materializing the full attention score matrix and adds a two-phase flash-decoding path for single-token decode. - New MLAS kernel core/mlas/lib/flashattn_gqa.cpp (MlasFlashAttentionGQA) supporting GQA head grouping, causal masking, local window, attention bias, ragged/per-batch seqlens, packed QKV, and flash-decoding. - New ApplyAttentionFlash dispatch in gqa_attention_base.h; wired into group_query_attention.cc (float only, gated like the quantized flash path: no softcap/smooth softmax/head sink/QK output). - Reuses ORT_GQA_DISABLE_FLASH_ATTENTION to fall back to the naive path. --- cmake/onnxruntime_mlas.cmake | 1 + docs/contrib_ops/cpu/gqa.md | 58 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 310 +++++++++ .../cpu/bert/group_query_attention.cc | 21 + onnxruntime/core/mlas/inc/mlas.h | 57 ++ onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 610 ++++++++++++++++++ 6 files changed, 1052 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/flashattn_gqa.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b254b40f88e76..1c47ac4ef4569 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -56,6 +56,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/flashattn_qkv.cpp + ${MLAS_SRC_DIR}/flashattn_gqa.cpp ${MLAS_SRC_DIR}/qkv_quant.cpp ${MLAS_SRC_DIR}/cast.cpp ${MLAS_SRC_DIR}/layernorm.cpp diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index e5a211c9fd11a..d3b7f25c6fdba 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -17,7 +17,12 @@ Quantized KV-cache GEMM helpers are implemented in MLAS: - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp` -- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (flash attention tiled kernel) +- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (quantized-KV flash attention tiled kernel) + +The non-quantized flash attention tiled kernel is implemented in MLAS: + +- `onnxruntime/core/mlas/lib/flashattn_gqa.cpp` (FP32-KV flash attention tiled kernel) +- `onnxruntime/core/mlas/inc/mlas.h` (`MlasFlashAttentionGQA` declaration and `MlasFlashAttentionGQAArgs`) The operator schema itself is defined in: @@ -48,12 +53,14 @@ At a high level, the CPU kernel executes GroupQueryAttention in these stages: The non-quantized and quantized paths share the surrounding validation, masking, softmax, and output flow. Their main difference is how the K/V cache is stored and read during QK and SV GEMMs. -The quantized path has two execution strategies: +Both the non-quantized and quantized paths have two execution strategies: - **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. - **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. -The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path. +The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, online-softmax, and flash-decoding structure. + +The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths). ## Supported Cache Modes @@ -144,9 +151,9 @@ For quantized V cache, the CPU path calls `MlasSVGemm` with: As with QK GEMM, the default MLAS contract preserves the FP32 left-hand operand and dequantizes only the cached V values on the fly. -## Flash Attention Path +## Quantized Flash Attention Path -The flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. +The quantized flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. ### Algorithm @@ -204,6 +211,47 @@ The partials buffer is allocated alongside the per-thread scratch in a single al - Per-thread scratch: `scores[Bc]` (one float per KV block element) - Partials: `batch × num_heads × kv_chunks × (2 + H)` floats (m, l, and partial output per chunk) +## Non-Quantized Flash Attention Path + +The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, masking, and flash-decoding structure. + +### Differences from the Quantized Path + +- **Cache element type**: The present K/V cache is FP32, laid out as BNSH (`[batch, kv_num_heads, seqlen_present, head_size]`). There is no quantize-on-write or dequantize-on-read step. +- **QK GEMM**: Uses the single-threaded SGEMM primitive `MlasSgemmOperation(CblasNoTrans, CblasTrans, ...)` on an FP32 K block instead of `MlasQKGemm`. +- **SV accumulate**: Uses `MlasSgemmOperation(CblasNoTrans, CblasNoTrans, ..., beta)` with `beta = 0` for the first KV block and `beta = 1` afterwards (accumulate) instead of `MlasSVGemm`. +- **Cache concat**: New K/V tokens are appended into the FP32 present cache with `ConcatStateChunkGQA` before the tiled loop runs. + +### Algorithm + +For each (batch, head, q_block) tile: + +1. **QK GEMM** — `MlasSgemmOperation` of the query tile against a block slice of the FP32 K cache (Bc rows at a time) +1b. **Attention bias** — Add the corresponding tile of the bias tensor (if present) to QK scores +2. **Causal + local window masking** — Set masked positions to −∞ before softmax +3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old − m_new)` +4. **SV accumulate** — `MlasSgemmOperation(..., beta)` accumulates `softmax(QK_block) × V_block` into the output tile +5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed + +### Activation Conditions + +The non-quantized flash path is selected when ALL of the following hold: + +- The kernel specialization is `float` (FP16 uses the naive path) +- `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) +- `total_sequence_length > 1` +- No softcap +- No smooth softmax +- No head sink +- No output QK capture +- `present_key` and `present_value` are provided + +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. + +### Block Sizes, Threading, and Flash Decoding + +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. + ## MLAS Dispatch Paths MLAS selects the best available quantized KV-cache GEMM implementation through the platform dispatch table. diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 12f61cddea18c..d66ed2cb0fb7d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -903,6 +903,316 @@ class GQAAttentionBase { return Status::OK(); } + // Non-quantized flash attention path. Only supports T = float. + // Concatenates new K/V into the FP32 present cache, then runs the tiled + // online-softmax kernel MlasFlashAttentionGQA (QK^T + softmax + S*V fused). + Status ApplyAttentionFlash( + const float* Q, // Q data [B, N, S, H] BNSH + const float* K, // K data [B, N_kv, L, H] or nullptr for packed_qkv + const float* V, // V data [B, N_kv, L, H] or nullptr for packed_qkv + const Tensor* attention_bias, // additive bias [B|1, N|1, S, T] or nullptr + const Tensor* past_key, // past K (float) + const Tensor* past_value, // past V (float) + Tensor* output, // output [B, S, N*H] float + Tensor* present_key, // present K (float) + Tensor* present_value, // present V (float) + const Tensor* seqlens_k, + GroupQueryAttentionParameters& parameters, + AllocatorPtr allocator, + OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int head_size = parameters.head_size; + const int hidden_size = parameters.hidden_size; + const bool packed_qkv = parameters.is_packed_qkv; + + auto* tp = context->GetOperatorThreadPool(); + + int seqlen_past_kv_cache = 0; + if (past_key != nullptr && past_value != nullptr) { + seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]); + } + int seqlen_present_kv_cache = present_key != nullptr + ? static_cast(present_key->Shape().GetDims()[2]) + : parameters.total_sequence_length; + + if (kv_sequence_length == 0) { + ORT_ENFORCE(parameters.total_sequence_length <= seqlen_past_kv_cache, + "total_seqlen (", parameters.total_sequence_length, ") exceeds past buffer size (", + seqlen_past_kv_cache, ") in shared KV mode"); + } + + ORT_RETURN_IF(present_key == nullptr || present_value == nullptr, + "present_key and present_value must be provided for flash attention"); + + const float* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + float* present_key_data = present_key->MutableData(); + const float* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + float* present_value_data = present_value->MutableData(); + + bool past_present_share_buffer = (past_key_data == present_key_data) && + (past_value_data == present_value_data); + + const int32_t* seqlens_k_data = seqlens_k->Data(); + + // Attention bias setup + const float* attention_bias_data = nullptr; + int attention_bias_seqlen_stride = 0; + bool attention_bias_broadcast_batch = true; + bool attention_bias_broadcast_head = true; + if (attention_bias != nullptr) { + attention_bias_data = attention_bias->Data(); + auto bias_shape = attention_bias->Shape().GetDims(); + attention_bias_seqlen_stride = static_cast(bias_shape[3]); + attention_bias_broadcast_batch = (bias_shape[0] == 1); + attention_bias_broadcast_head = (bias_shape[1] == 1); + } + + // K/V base pointers (FP32, new tokens) + const float* k_base = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + const float* v_base = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const size_t kv_input_chunk_length = SafeInt(kv_sequence_length) * head_size; + const size_t past_buff_chunk_length = SafeInt(seqlen_past_kv_cache) * head_size; + const size_t present_buff_chunk_length = SafeInt(seqlen_present_kv_cache) * head_size; + + // ---- Phase 1: Concat new K/V into present cache ---- + // We must do this first so the flash attention kernel can read the full present cache. + if (present_key_data && !past_present_share_buffer) { + memset(present_key_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + memset(present_value_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + } + + // Concat K and V caches (parallelize over batch * kv_num_heads) + { + const size_t concat_loop_len = batch_size * kv_num_heads_; + TensorOpCost concat_cost; + concat_cost.compute_cycles = static_cast(kv_sequence_length * head_size); + concat_cost.bytes_loaded = static_cast((past_buff_chunk_length + kv_input_chunk_length) * sizeof(float)); + concat_cost.bytes_stored = static_cast(present_buff_chunk_length * sizeof(float)); + + ThreadPool::TryParallelFor(tp, concat_loop_len, concat_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t kv_idx = begin; kv_idx != end; ++kv_idx) { + const size_t batch_index = kv_idx / kv_num_heads_; + const size_t kv_head_index = kv_idx % kv_num_heads_; + const size_t total_seqlen = SafeInt(seqlens_k_data[batch_index]) + 1; + + size_t past_seqlen; + if (past_key == nullptr) { + past_seqlen = 0; + } else if (kv_sequence_length == 0) { + past_seqlen = total_seqlen; + } else if (is_prompt) { + past_seqlen = 0; + } else { + past_seqlen = total_seqlen - sequence_length; + } + const size_t past_chunk_length = past_seqlen * head_size; + + // Concat K + const float* k_new; + if (packed_qkv) { + k_new = k_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + k_new = k_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_key_data, k_new, present_key_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + + // Concat V + const float* v_new; + if (packed_qkv) { + v_new = v_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + v_new = v_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_value_data, v_new, present_value_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + } + }); + } + + // ---- Phase 2: Flash Attention with FP32 KV cache ---- + // Compute L2-aware block sizes (same formula as MHA flash attention). + const auto& env = Env::Default(); + int l2_cache_size = env.GetL2CacheSize(); + + int kv_block_size = l2_cache_size / (static_cast(sizeof(float)) * 4 * (head_size + head_size)); + kv_block_size = std::max(kv_block_size, 1); + int q_block_size = std::min(kv_block_size, 2 * head_size); + + // The flash kernel uses a single (past_seqlen, total_seqlen) pair for all batch items. + // When batch items have different seqlens_k (ragged), fall back to per-batch invocation + // so each batch item gets its own correct causal offset. + int max_total_seqlen = 0; + int min_total_seqlen = std::numeric_limits::max(); + int common_past_seqlen = 0; + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + max_total_seqlen = std::max(max_total_seqlen, total_sl); + min_total_seqlen = std::min(min_total_seqlen, total_sl); + } + const bool ragged_seqlens = (max_total_seqlen != min_total_seqlen); + + if (ragged_seqlens) { + common_past_seqlen = -1; // sentinel: per-batch + } else if (past_key == nullptr || is_prompt) { + common_past_seqlen = 0; + } else if (kv_sequence_length == 0) { + // Shared buffer mode: each batch item has its own past_seqlen. + common_past_seqlen = -1; // sentinel: per-batch + } else { + common_past_seqlen = max_total_seqlen - sequence_length; + } + + // Cap block sizes + kv_block_size = std::min(kv_block_size, max_total_seqlen); + q_block_size = std::min(q_block_size, sequence_length); + + int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); + thread_count = std::max(thread_count, 1); + + // Flash decoding: for decode (sequence_length==1), partition KV across threads + // to improve parallelism when batch*heads < thread_count. + const int kv_chunk_count = (max_total_seqlen + kv_block_size - 1) / kv_block_size; + const bool use_flash_decoding = (sequence_length == 1 && + batch_size * num_heads_ < thread_count && + kv_chunk_count > 1); + + size_t buffer_size_per_thread; + size_t partials_buffer_bytes = 0; + if (use_flash_decoding) { + // Flash decoding: per-thread scratch only needs scores[kv_block_size] + buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats + partials_buffer_bytes = static_cast(batch_size) * num_heads_ * + kv_chunk_count * (2 + head_size) * sizeof(float); + } else { + buffer_size_per_thread = + (static_cast(q_block_size) * 2 + // l + m + static_cast(q_block_size) * static_cast(kv_block_size) + // scores + static_cast(q_block_size) * static_cast(head_size)) * // temp_output + sizeof(float); + } + size_t total_buffer_bytes = buffer_size_per_thread * thread_count + partials_buffer_bytes; + auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); + BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); + + // Partials buffer is placed after per-thread scratch + float* partials_ptr = use_flash_decoding + ? reinterpret_cast(reinterpret_cast(flash_buffer_alloc) + + buffer_size_per_thread * thread_count) + : nullptr; + + const float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + // If all batch items share the same past_seqlen, use the unified flash kernel. + // Otherwise, fall back to per-batch invocation. + if (common_past_seqlen >= 0) { + MlasFlashAttentionGQAArgs args; + args.batch_size = batch_size; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = max_total_seqlen; + args.head_size = head_size; + args.past_seqlen = common_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.q_batch_stride = packed_qkv + ? static_cast(packed_batch_stride) + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.k_cache = present_key_data; + args.v_cache = present_value_data; + args.output = output->MutableData(); + args.attention_bias = attention_bias_data; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = partials_ptr; + args.kv_chunk_count = kv_chunk_count; + + MlasFlashAttentionGQA(&args, tp); + } else { + // Per-batch handling for variable past_seqlen (shared KV buffer mode or ragged seqlens) + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + int batch_past_seqlen = (past_key == nullptr || is_prompt) + ? 0 + : std::max(0, total_sl - sequence_length); + + MlasFlashAttentionGQAArgs args; + args.batch_size = 1; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = total_sl; + args.head_size = head_size; + args.past_seqlen = batch_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = std::min(kv_block_size, total_sl); + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + + // Offset Q and output for this batch + const ptrdiff_t q_batch_stride_elems = packed_batch_stride > 0 + ? packed_batch_stride + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.query = Q + static_cast(b) * static_cast(q_batch_stride_elems); + args.q_batch_stride = SafeInt(num_heads_) * sequence_length * head_size; + args.k_cache = present_key_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.v_cache = present_value_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.output = output->MutableData() + + static_cast(b) * sequence_length * hidden_size; + + // Slice attention bias for this batch (the kernel sees batch_size=1, so batch_idx=0 inside). + // Bias shape is [batch|1, num_heads|1, S, T]; the batch stride uses the actual head + // extent (1 when the head dim is broadcast). + const float* batch_bias = attention_bias_data; + if (attention_bias_data != nullptr && !attention_bias_broadcast_batch) { + const size_t bias_head_extent = attention_bias_broadcast_head ? 1 : static_cast(num_heads_); + batch_bias += static_cast(b) * bias_head_extent * sequence_length * attention_bias_seqlen_stride; + } + args.attention_bias = batch_bias; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = true; // batch offset handled above + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = nullptr; // per-batch doesn't use flash decoding + args.kv_chunk_count = 0; + + MlasFlashAttentionGQA(&args, tp); + } + } + + return Status::OK(); + } + private: // Helper function to compute the attention probs. It does 2 things: // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 61ae474703213..29d372eb7a4bb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -343,6 +343,27 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V const T* k_data = packed_qkv ? nullptr : k_rotary; const T* v_data = packed_qkv ? nullptr : V.Get().Data(); + + // Non-quantized flash attention path (float only). Uses the tiled online-softmax + // kernel to avoid materializing the full attention score matrix. Falls back to the + // naive path when an unsupported feature is requested (softcap, smooth softmax, + // head sink, or QK output). + if constexpr (std::is_same_v) { + const bool use_flash = !disable_gqa_flash_ && + parameters.total_sequence_length > 1 && + softcap_ == 0.0f && + !use_smooth_softmax_ && + head_sink_data == nullptr && + output_qk == nullptr && + present_k != nullptr && present_v != nullptr; + if (use_flash) { + return ApplyAttentionFlash(q_rotary, k_data, v_data, + attention_bias, past_key, past_value, + output, present_k, present_v, seqlens_k, + parameters, allocator, context); + } + } + return ApplyAttention(q_rotary, k_data, v_data, head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, output_qk, seqlens_k, parameters, allocator, context); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 99b72dc756663..ec2398dd1ee0f 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2281,6 +2281,63 @@ MlasFlashAttention( MLAS_THREADPOOL* ThreadPool ); +// +// Flash Attention for non-quantized (FP32) GroupQueryAttention KV cache. +// +// Adapts the online-softmax tiled algorithm to operate on an FP32 present +// K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]). +// Supports GQA head grouping, causal masking, local window attention, +// additive attention bias, and an optional flash-decoding split over the KV +// sequence dimension for the single-token decode case. +// +struct MlasFlashAttentionGQAArgs { + int batch_size; + int num_heads; // number of query heads + int kv_num_heads; // number of key/value heads (num_heads % kv_num_heads == 0) + int sequence_length; // number of new query tokens (S) + int total_seqlen; // total tokens (past + new) for this invocation (T) + int head_size; // per-head size (H) + int past_seqlen; // causal offset (number of cached tokens before the new ones) + int local_window_size; // -1 disables local window masking + int seqlen_present_kv; // sequence dimension of the present K/V buffer + int q_block_size; // query tile size (Br) + int kv_block_size; // key/value tile size (Bc) + float scale; // QK scaling factor + int thread_count; // number of partitions / threads + float* buffer; // per-thread scratch (+ optional flash-decoding partials) + size_t buffer_size_per_thread; + + const float* query; // [batch, num_heads, sequence_length, head_size] BNSH + size_t q_batch_stride; // element stride between consecutive batches in `query` + // (num_heads*S*H for unpacked, (num_heads+2*kv_num_heads)*S*H for packed QKV) + const float* k_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + const float* v_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + float* output; // [batch, sequence_length, num_heads, head_size] BSNH + + const float* attention_bias; // [batch|1, num_heads|1, S, T] additive bias, or nullptr + int attention_bias_seqlen_stride; + bool attention_bias_broadcast_batch; + bool attention_bias_broadcast_head; + + // Flash decoding (sequence_length == 1): partition KV across threads. + // Set flash_decoding_partials != nullptr to enable; otherwise the standard + // per-(batch, head, q_block) partitioning is used. + float* flash_decoding_partials; + int kv_chunk_count; +}; + +/** + * @brief FP32 Flash Attention for GroupQueryAttention with an FP32 KV cache. + * @param args Arguments + * @param ThreadPool Thread pool + */ +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +); + /** * @brief Enumeration of supported GELU algorithm variants. * diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp new file mode 100644 index 0000000000000..2f62b9a8b0735 --- /dev/null +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -0,0 +1,610 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + flashattn_gqa.cpp + +Abstract: + + Flash Attention kernel for the non-quantized (FP32) GroupQueryAttention + KV cache. + + Adapts the online-softmax tiled algorithm from flashattn.cpp to operate on + an FP32 present K/V cache laid out as BNSH + ([batch, kv_num_heads, seqlen_present, head_size]) and to support GQA head + grouping (num_heads % kv_num_heads == 0), causal masking, local window + attention, additive attention bias, and an optional flash-decoding split + over the KV sequence dimension for single-token decode. + + QK^T and S*V use the single-threaded SGEMM primitive MlasSgemmOperation; + the outer parallelism is provided by MlasExecuteThreaded. + +--*/ + +#include +#include +#include +#include + +#include "mlasi.h" + +void +MlasFlashAttentionGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t q_block_size = static_cast(args->q_block_size); + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t sequence_length = static_cast(args->sequence_length); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: one per (batch, head, q_block) + const ptrdiff_t q_chunk_count = (sequence_length + q_block_size - 1) / q_block_size; + const ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t batch_idx = task_index; + ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; + batch_idx /= q_chunk_count; + ptrdiff_t head_idx = batch_idx % num_heads; + batch_idx /= num_heads; + + // Per-thread buffer layout: + // l[q_block_size] - running sum for online softmax + // m[q_block_size] - running max for online softmax + // scores[q_block_size * kv_block_size] - QK scores (S) + // temp_output[q_block_size * head_size] - accumulated output + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* l = reinterpret_cast(buffer_ptr); + float* m = l + q_block_size; + float* scores = m + q_block_size; + float* temp_output = scores + q_block_size * kv_block_size; + + // Initialize running state + for (ptrdiff_t t = 0; t < q_block_size; ++t) { + m[t] = std::numeric_limits::lowest(); + l[t] = 0.0f; + } + memset(temp_output, 0, static_cast(q_block_size * head_size) * sizeof(float)); + + const size_t row_size_q = static_cast(std::min(q_block_size, sequence_length - q_idx)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers. Layout: [batch, kv_num_heads, seqlen_present, head_size] + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, seq, head_size]. The batch stride is + // supplied separately (args->q_batch_stride) so the kernel works with both the + // standard BNSH layout and packed-QKV input where Q/K/V are interleaved per batch. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(sequence_length) * static_cast(head_size) + + static_cast(q_idx) * static_cast(head_size); + + // Iterate over KV blocks + for (ptrdiff_t ir = 0; ir < total_seqlen; ir += kv_block_size) { + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Step 1: QK^T GEMM with FP32 K block + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasTrans, + row_size_q, // M + row_size_kv, // N + static_cast(head_size), // K + scale, // alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (FP32 K block) + static_cast(head_size), // ldb + 0.0f, // beta + scores, // C (output scores) + row_size_kv // ldc + ); + + // Step 1b: Apply attention bias (additive) if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = + static_cast(sequence_length) * bias_seqlen_stride; + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch + // stride uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + // Add bias tile: bias[q_idx + irow, ir + jcol] + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + const float* bias_row = args->attention_bias + bias_offset + + (q_idx + irow) * bias_seqlen_stride + ir; + float* s_row = scores + irow * static_cast(row_size_kv); + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + s_row[jcol] += bias_row[jcol]; + } + } + } + + // Step 2: Apply causal mask and Step 3: Online softmax update + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float* p = scores + irow * static_cast(row_size_kv); + const ptrdiff_t global_q_pos = past_seqlen + q_idx + irow; + const ptrdiff_t causal_limit = global_q_pos + 1; // can attend to positions [0, causal_limit) + + // Apply causal masking + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + p[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + p[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Online softmax: find row max, update running max +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv); +#endif + + // If the entire row is masked (all scores are -inf), zero the scores + // so the S*V GEMM contributes nothing and skip the softmax state update. + if (rowmax == std::numeric_limits::lowest()) { + memset(p, 0, row_size_kv * sizeof(float)); + continue; + } + + float m_old = m[irow]; + m[irow] = std::max(m[irow], rowmax); + float m_diff = m_old - m[irow]; // <= 0 + + // Compute exp(score - m_new) for each element + float negmax = -m[irow]; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#endif + + // Rescale previous state + if (ir != 0) { + float exp_diff = std::exp(m_diff); + l[irow] = exp_diff * l[irow] + rowsum; + + // Rescale accumulated output + float* out_row = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + out_row[icol] *= exp_diff; + } + } else { + l[irow] = rowsum; + } + } + + // Step 4: Accumulate O += S_exp * V_block + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasNoTrans, + row_size_q, // M + static_cast(head_size), // N + row_size_kv, // K + 1.0f, // alpha + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (FP32 V block) + static_cast(head_size), // ldb + ir == 0 ? 0.0f : 1.0f, // beta (accumulate after first block) + temp_output, // C (accumulated output) + static_cast(head_size) // ldc + ); + } + + // Final: normalize output by l (softmax denominator) + // Output layout: [batch, sequence_length, num_heads, head_size] + float* output_row = args->output + + (static_cast(batch_idx) * static_cast(sequence_length) + + static_cast(q_idx)) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + const ptrdiff_t output_row_stride = num_heads * head_size; + + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float inv_l = (l[irow] > 0.0f) ? (1.0f / l[irow]) : 0.0f; + float* src = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + output_row[icol] = src[icol] * inv_l; + } + output_row += output_row_stride; + } + } +} + +// +// Flash Decoding: Phase 1 - parallel partial attention over (batch, head, kv_chunk). +// Each task computes attention for one KV chunk and stores (m, l, partial_output) +// into the partials buffer. +// +void +MlasFlashDecodingGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + // Partials layout per entry: [m, l, output[head_size]] + const ptrdiff_t partial_stride = 2 + head_size; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: (batch, head, kv_chunk) + const ptrdiff_t total_task_count = batch_size * num_heads * kv_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + // Decompose task_index into (batch_idx, head_idx, kv_chunk_idx) + ptrdiff_t tmp = task_index; + ptrdiff_t kv_chunk_idx = tmp % kv_chunk_count; + tmp /= kv_chunk_count; + ptrdiff_t head_idx = tmp % num_heads; + ptrdiff_t batch_idx = tmp / num_heads; + + // Per-thread scratch buffer: just scores[kv_block_size] + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + + // KV block range for this chunk + const ptrdiff_t ir = kv_chunk_idx * kv_block_size; + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1). + // The batch stride is supplied separately to support packed-QKV input. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(head_size); + + // Step 1: QK^T GEMM for this KV chunk + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasTrans, + 1, // M (single query row) + row_size_kv, // N + static_cast(head_size), // K + scale, // alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (FP32 K block) + static_cast(head_size), // ldb + 0.0f, // beta + scores, // C (output scores) + row_size_kv // ldc + ); + + // Step 1b: Apply attention bias if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = bias_seqlen_stride; // S=1 + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch stride + // uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset + ir; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + scores[jcol] += bias_row[jcol]; + } + } + + // Step 2: Apply causal mask + const ptrdiff_t global_q_pos = past_seqlen; // sequence_length=1, q_idx=0 + const ptrdiff_t causal_limit = global_q_pos + 1; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Step 3: Compute local softmax statistics (m, l) and exp scores +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, row_size_kv); +#endif + + // Pointer to this task's partial in the partials buffer + const ptrdiff_t partial_index = + (batch_idx * num_heads + head_idx) * kv_chunk_count + kv_chunk_idx; + float* partial = args->flash_decoding_partials + partial_index * partial_stride; + float* partial_m = partial; + float* partial_l = partial + 1; + float* partial_output = partial + 2; + + if (rowmax == std::numeric_limits::lowest()) { + // Entire chunk is masked: store sentinel + *partial_m = std::numeric_limits::lowest(); + *partial_l = 0.0f; + memset(partial_output, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + *partial_m = rowmax; + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#endif + *partial_l = rowsum; + + // Step 4: S_exp * V_block -> partial_output + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasNoTrans, + 1, // M + static_cast(head_size), // N + row_size_kv, // K + 1.0f, // alpha + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (FP32 V block) + static_cast(head_size), // ldb + 0.0f, // beta (overwrite) + partial_output, // C (output for this chunk) + static_cast(head_size) // ldc + ); + } +} + +// +// Flash Decoding: Phase 2 - reduce partials for each (batch, head) into final output. +// +void +MlasFlashDecodingGQAReduceThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + const ptrdiff_t thread_count = static_cast(args->thread_count); + const ptrdiff_t partial_stride = 2 + head_size; + + // Total reduction tasks: one per (batch, head) + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t head_idx = task_index % num_heads; + ptrdiff_t batch_idx = task_index / num_heads; + + // Pointer to this (batch, head)'s partials: kv_chunk_count entries + const float* partials_base = args->flash_decoding_partials + + task_index * kv_chunk_count * partial_stride; + + // Find global max across all chunks + float global_m = std::numeric_limits::lowest(); + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + float chunk_m = partials_base[c * partial_stride]; + global_m = std::max(global_m, chunk_m); + } + + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + + // If all chunks are masked, output zeros + if (global_m == std::numeric_limits::lowest()) { + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + // Accumulate rescaled outputs and l values + float global_l = 0.0f; + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + const float* partial = partials_base + c * partial_stride; + float chunk_m = partial[0]; + float chunk_l = partial[1]; + const float* chunk_output = partial + 2; + + if (chunk_l <= 0.0f) { + continue; // masked chunk contributes nothing + } + + float rescale = std::exp(chunk_m - global_m); + global_l += rescale * chunk_l; + + // partial_output = S_exp * V where sum(S_exp) = l_c (unnormalized). + // Rescale by exp(m_c - global_m) to align all chunks to the same max. + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] += rescale * chunk_output[i]; + } + } + + // output = sum_c(rescale_c * partial_output_c) / global_l + float inv_l = (global_l > 0.0f) ? (1.0f / global_l) : 0.0f; + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] *= inv_l; + } + } +} + +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +) +{ + if (args->flash_decoding_partials != nullptr && args->sequence_length == 1) { + // Flash decoding: two-phase approach. + // Phase 1: parallel partial computation over (batch, head, kv_chunk). + MlasExecuteThreaded( + MlasFlashDecodingGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + // Phase 2: reduce partials into final output (parallel over batch*heads). + MlasExecuteThreaded( + MlasFlashDecodingGQAReduceThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } else { + MlasExecuteThreaded( + MlasFlashAttentionGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } +} From 980b497e418bfc7af4d3cbdfeb150197bf415c1b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jun 2026 02:51:24 +0000 Subject: [PATCH 2/5] benchmark and doc --- docs/contrib_ops/cpu/gqa.md | 45 ++++ onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 12 ++ .../transformers/benchmark_gqa_cpu_flash.py | 197 +++++++++++++----- 3 files changed, 201 insertions(+), 53 deletions(-) diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index d3b7f25c6fdba..840dcea5b0cfd 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -233,6 +233,18 @@ For each (batch, head, q_block) tile: 4. **SV accumulate** — `MlasSgemmOperation(..., beta)` accumulates `softmax(QK_block) × V_block` into the output tile 5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed +#### Causal early-termination + +During prefill, every KV block whose start index is at or beyond the largest global query +position in the current q_block is fully causally masked and contributes nothing. The kernel +computes a per-q_block bound +`kv_causal_limit = past_seqlen + q_idx + row_size_q` and breaks out of the KV loop once +`ir >= kv_causal_limit`, instead of computing and then discarding the masked upper-triangle +QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the +main reason the FP32 flash path is faster than naive even at short sequence lengths +(see the benchmark results below). Decode (q_block of size 1 at the cache tail) attends to all +KV positions, so the bound equals `total_seqlen` and nothing is skipped. + ### Activation Conditions The non-quantized flash path is selected when ALL of the following hold: @@ -476,7 +488,40 @@ Flash decoding IS active (batch×heads=4 < threads=8, KV partitioned across idle | 4096 (N=32) | +2131 | +87 | 24.5x | **Summary**: The flash path's primary benefit for prefill is **memory reduction** — avoiding the full O(N×S×T) attention matrix. For S=4096 with 16 heads, the naive path allocates ~1 GB for attention scores while the flash path uses ~80 MB regardless of sequence length. The prefill latency speedup (1.2–2.7x at kernel level, 1.2–1.9x at operator level) comes from improved cache locality. For decode, the tiled kernel provides 1.2–1.8x kernel-level speedup from fused single-pass KV access; at operator level the gain is visible for T≥1024 but masked by KV concat overhead at shorter sequences. When flash decoding is active (batch×heads < threads), KV partitioning across idle threads yields an additional 2–5x speedup for long sequences. +### Non-Quantized (FP32) Flash Attention vs Naive benchmark results + +Measured on an AMD EPYC 7763 (32 logical / 16 physical cores), threads=8, FP32 KV cache, +`B=1, num_heads=16, kv_num_heads=8, head_size=128`. Operator-level, measured with: + +```bash +python onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py \ + --fp32 --prompt_only --warmup 10 --repeats 30 +``` + +#### Latency — Prefill (S = T, prompt phase) + +| Seq Length | Naive (ms) | Flash (ms) | Speedup | +|---:|---:|---:|---:| +| 512 | 5.8\u20138.4 | 4.2\u20135.3 | 1.4\u20131.6x | +| 1024 | 25\u201329 | 13\u201318 | 1.6\u20132.0x | +| 2048 | 87\u2013118 | 52\u201365 | 1.5\u20132.0x | +| 4096 | 365\u2013380 | 213\u2013234 | 1.6\u20131.7x | + +The FP32 flash path is faster than naive across all measured prefill lengths. With the causal +early-termination described above, roughly half of the QK/SV work (the causally masked +upper triangle of the square prefill attention matrix) is skipped entirely, which more than +offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/output rescale). +The same advantage holds single-threaded (1.4\u20131.8x at threads=1), confirming the gain is +algorithmic rather than purely from threading. + +#### Latency — Decode (S = 1, token generation) +For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so +flash decoding KV-partitioning is not active), the workload per `Run` is tiny and dominated +by KV-cache concatenation overhead. Operator-level decode latency is therefore noisy and +roughly at parity between the two paths, with longer total sequence lengths (T\u22652049) +tending to favor flash. The FP32 decode path is not the target of the prefill-oriented +causal early-termination optimization. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp index 2f62b9a8b0735..25f3733f59cca 100644 --- a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -126,8 +126,20 @@ MlasFlashAttentionGQAThreaded( static_cast(head_idx) * static_cast(sequence_length) * static_cast(head_size) + static_cast(q_idx) * static_cast(head_size); + // Causal early-termination bound: the largest global query position in this + // q_block is (past_seqlen + q_idx + row_size_q - 1), so it can attend to KV + // positions up to that index inclusive. Any KV block starting at or beyond + // (past_seqlen + q_idx + row_size_q) is fully causally masked for every row in + // the block, so it contributes nothing and can be skipped. This avoids the + // wasted QK/SV GEMMs over the causal upper triangle during prefill. + const ptrdiff_t kv_causal_limit = + past_seqlen + q_idx + static_cast(row_size_q); + // Iterate over KV blocks for (ptrdiff_t ir = 0; ir < total_seqlen; ir += kv_block_size) { + if (ir >= kv_causal_limit) { + break; + } const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); // Step 1: QK^T GEMM with FP32 K block diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py index 77ac08cf50d6c..7dbcb16a75973 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py @@ -106,6 +106,70 @@ def create_quantized_gqa_graph( return model.SerializeToString() +def create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=None, +): + """Create an ONNX graph for GroupQueryAttention with a non-quantized FP32 KV cache.""" + if buffer_seq_len is None: + buffer_seq_len = seq_len + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + + inputs = [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + ] + + node = helper.make_node( + op_type="GroupQueryAttention", + inputs=inputs, + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=num_heads, + kv_num_heads=kv_num_heads, + domain="com.microsoft", + ) + + graph_input = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info( + "past_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "past_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + ] + + graph_output = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info( + "present_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "present_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + ] + + graph = helper.make_graph([node], "BenchGQA", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + def benchmark_gqa( batch_size, seq_len, @@ -117,6 +181,7 @@ def benchmark_gqa( past_seq_len=0, warmup=5, repeats=20, + non_quantized=False, ): """Benchmark a single GQA configuration. Returns elapsed time in ms.""" hidden_size = num_heads * head_size @@ -126,54 +191,76 @@ def benchmark_gqa( total_seqlen = past_seq_len + seq_len buffer_seq_len = total_seqlen - onnx_model_str = create_quantized_gqa_graph( - batch_size, - seq_len, - num_heads, - kv_num_heads, - head_size, - quant_type, - bit_width, - buffer_seq_len=buffer_seq_len, - ) - sess_options = SessionOptions() sess_options.intra_op_num_threads = 8 - sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - # Generate inputs np.random.seed(42) query = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, hidden_size)).astype(np.float32) key = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) value = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) - - cache_dtype = np.uint8 if bit_width == 4 else np.int8 - past_k = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - past_v = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - seqlens_k = np.array([total_seqlen - 1] * batch_size, dtype=np.int32) total_seq = np.array([total_seqlen], dtype=np.int32) - per_channel = quant_type == "PER_CHANNEL" - scale_size = kv_num_heads * head_size if per_channel else 1 - k_scale = np.full(scale_size, 0.01, dtype=np.float32) - v_scale = np.full(scale_size, 0.01, dtype=np.float32) - - feeds = { - "query": query, - "key": key, - "value": value, - "past_key": past_k, - "past_value": past_v, - "seqlens_k": seqlens_k, - "total_sequence_length": total_seq, - "k_scale": k_scale, - "v_scale": v_scale, - } + if non_quantized: + onnx_model_str = create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + past_k = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + past_v = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + } + else: + onnx_model_str = create_quantized_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + cache_dtype = np.uint8 if bit_width == 4 else np.int8 + past_k = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + past_v = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + + per_channel = quant_type == "PER_CHANNEL" + scale_size = kv_num_heads * head_size if per_channel else 1 + k_scale = np.full(scale_size, 0.01, dtype=np.float32) + v_scale = np.full(scale_size, 0.01, dtype=np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + "k_scale": k_scale, + "v_scale": v_scale, + } # Warmup for _ in range(warmup): @@ -242,20 +329,21 @@ def run_benchmarks(args): "past_seq_len": 2048, } ) - # INT4 prefill - configs.append( - { - "label": "Prefill S=2048 INT4", - "batch_size": 1, - "seq_len": 2048, - "num_heads": 16, - "kv_num_heads": 8, - "head_size": 128, - "quant_type": "PER_TENSOR", - "bit_width": 4, - "past_seq_len": 0, - } - ) + # INT4 prefill (quantized mode only) + if not args.fp32: + configs.append( + { + "label": "Prefill S=2048 INT4", + "batch_size": 1, + "seq_len": 2048, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 4, + "past_seq_len": 0, + } + ) warmup = args.warmup repeats = args.repeats @@ -263,13 +351,15 @@ def run_benchmarks(args): # Save and restore env var to avoid side effects on callers saved_env = os.environ.get("ORT_GQA_DISABLE_FLASH_ATTENTION") + kv_mode = "FP32 (non-quantized)" if args.fp32 else "INT8/INT4 quantized" print("\nBenchmark: CPU GroupQueryAttention — Flash vs Naive") - print(f"Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") + print(f"KV cache: {kv_mode}, Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") print(f"{'Config':<25} {'Naive (ms)':>12} {'Flash (ms)':>12} {'Speedup':>10}") print("-" * 62) for cfg in configs: label = cfg.pop("label") + cfg["non_quantized"] = args.fp32 # Flash path (default) os.environ.pop("ORT_GQA_DISABLE_FLASH_ATTENTION", None) @@ -296,5 +386,6 @@ def run_benchmarks(args): parser.add_argument("--repeats", type=int, default=20, help="Measurement iterations") parser.add_argument("--decode_only", action="store_true", help="Only run decode benchmarks") parser.add_argument("--prompt_only", action="store_true", help="Only run prompt benchmarks") + parser.add_argument("--fp32", action="store_true", help="Use non-quantized FP32 KV cache instead of quantized") args = parser.parse_args() run_benchmarks(args) From 125172e29a843b49f6b037f3aef20c42d0998112 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jun 2026 02:56:59 +0000 Subject: [PATCH 3/5] Use SafeInt for FP32 flash attention scratch buffer sizing Guard the per-thread scratch and flash-decoding partials buffer size computations against size_t overflow for large or malformed shapes, matching the SafeInt usage elsewhere in this file. --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index d66ed2cb0fb7d..caa6a89f02fd7 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -1096,18 +1096,18 @@ class GQAAttentionBase { size_t partials_buffer_bytes = 0; if (use_flash_decoding) { // Flash decoding: per-thread scratch only needs scores[kv_block_size] - buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + buffer_size_per_thread = SafeInt(kv_block_size) * sizeof(float); // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats - partials_buffer_bytes = static_cast(batch_size) * num_heads_ * + partials_buffer_bytes = SafeInt(batch_size) * num_heads_ * kv_chunk_count * (2 + head_size) * sizeof(float); } else { buffer_size_per_thread = - (static_cast(q_block_size) * 2 + // l + m - static_cast(q_block_size) * static_cast(kv_block_size) + // scores - static_cast(q_block_size) * static_cast(head_size)) * // temp_output + (SafeInt(q_block_size) * 2 + // l + m + SafeInt(q_block_size) * kv_block_size + // scores + SafeInt(q_block_size) * head_size) * // temp_output sizeof(float); } - size_t total_buffer_bytes = buffer_size_per_thread * thread_count + partials_buffer_bytes; + size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count + partials_buffer_bytes; auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); From 4a4e8450105c3431bfa5a542b425667fe786ffa2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jun 2026 17:02:45 +0000 Subject: [PATCH 4/5] Restrict CPU FP32 GQA flash attention to prefill (sequence_length > 1) Single-token decode (sequence_length == 1) falls back to the naive path. A dedicated FP32 decode kernel will be added in a follow-up PR. The quantized path is unchanged. --- docs/contrib_ops/cpu/gqa.md | 15 +++++++-------- .../contrib_ops/cpu/bert/group_query_attention.cc | 6 +++++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index 840dcea5b0cfd..6fc042f9d9ad5 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -258,11 +258,11 @@ The non-quantized flash path is selected when ALL of the following hold: - No output QK capture - `present_key` and `present_value` are provided -Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported for prefill, mirroring the quantized flash path. The non-quantized flash path is currently selected for prefill only (`sequence_length > 1`); single-token decode falls back to the naive full-materialization path (a dedicated decode kernel is added in a follow-up change). When any supported condition is not met, the kernel also falls back to the naive path. ### Block Sizes, Threading, and Flash Decoding -Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) for prefill are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. The two-phase flash-decoding strategy for single-token decode is gated off for the non-quantized path in this PR (decode falls back to naive); it is enabled together with the dedicated decode kernel in a follow-up change. ## MLAS Dispatch Paths @@ -516,12 +516,11 @@ algorithmic rather than purely from threading. #### Latency — Decode (S = 1, token generation) -For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so -flash decoding KV-partitioning is not active), the workload per `Run` is tiny and dominated -by KV-cache concatenation overhead. Operator-level decode latency is therefore noisy and -roughly at parity between the two paths, with longer total sequence lengths (T\u22652049) -tending to favor flash. The FP32 decode path is not the target of the prefill-oriented -causal early-termination optimization. +Single-token decode (`sequence_length == 1`) currently falls back to the naive path for the +non-quantized FP32 cache: the flash path is gated on `sequence_length > 1` (prefill only), +because routing the tiny `1 × T × H` decode work through the tiled SGEMM kernel pays +per-block GEMM setup overhead with no tiling reuse benefit. A dedicated FP32 decode kernel +is added in a follow-up change. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 29d372eb7a4bb..e345cbfd3aef6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -348,9 +348,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // kernel to avoid materializing the full attention score matrix. Falls back to the // naive path when an unsupported feature is requested (softcap, smooth softmax, // head sink, or QK output). + // + // The flash path is currently used for prefill only (sequence_length > 1). Single-token + // decode (sequence_length == 1) falls back to the naive path; a dedicated decode kernel + // is added in a follow-up change. if constexpr (std::is_same_v) { const bool use_flash = !disable_gqa_flash_ && - parameters.total_sequence_length > 1 && + parameters.sequence_length > 1 && softcap_ == 0.0f && !use_smooth_softmax_ && head_sink_data == nullptr && From 49ffc40d7b636319f7c2e80bd08bf5cd45f4ce46 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jun 2026 17:10:04 +0000 Subject: [PATCH 5/5] Add CPU FP32 GQA GEMV decode kernel Adds a dedicated GEMV kernel (MlasGQADecodeGQAThreaded) for single-token decode (sequence_length == 1), and converts the flash-decoding inner M=1 GEMMs to GEMV. Re-enables the FP32 flash gate for decode (total_sequence_length > 1). Verified correctness vs naive (~1e-8); long-context decode ~1.0-1.2x, fixing the prior per-block SGEMM decode regression. --- docs/contrib_ops/cpu/gqa.md | 66 +++- .../contrib_ops/cpu/bert/gqa_attention_base.h | 5 + .../cpu/bert/group_query_attention.cc | 8 +- onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 290 +++++++++++++++--- 4 files changed, 310 insertions(+), 59 deletions(-) diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index 6fc042f9d9ad5..ffce29682ca42 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -258,11 +258,45 @@ The non-quantized flash path is selected when ALL of the following hold: - No output QK capture - `present_key` and `present_value` are provided -Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported for prefill, mirroring the quantized flash path. The non-quantized flash path is currently selected for prefill only (`sequence_length > 1`); single-token decode falls back to the naive full-materialization path (a dedicated decode kernel is added in a follow-up change). When any supported condition is not met, the kernel also falls back to the naive path. +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. ### Block Sizes, Threading, and Flash Decoding -Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) for prefill are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. The two-phase flash-decoding strategy for single-token decode is gated off for the non-quantized path in this PR (decode falls back to naive); it is enabled together with the dedicated decode kernel in a follow-up change. +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. + +#### Decode uses a dedicated GEMV kernel (`sequence_length == 1`) + +The tiled online-softmax SGEMM kernel (`MlasFlashAttentionGQAThreaded`) is used **only for +prefill** (`sequence_length > 1`), where each KV tile is reused across the `q_block_size` +query rows and tiling delivers real cache-locality and SGEMM packing benefits. + +For single-token decode the query tile has `M = 1`, so every K/V element is streamed +exactly once with no reuse across query rows. Tiling provides **no** cache-locality +benefit, and routing the `1 × T × H` work through `MlasSgemmOperation` pays the SGEMM +B-packing/setup cost on every call — which previously made the flash decode path *slower* +than the naive path (≈0.4–0.6x) for short-to-medium total sequence lengths. + +Decode is therefore handled by a dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`), +dispatched whenever `sequence_length == 1` and flash decoding is not active. It +parallelizes over `(batch, head)` and, per head, computes the attention directly with two +matrix-vector products and a two-pass softmax: + +- **QK GEMV** — `scores[t] = scale · dot(q, K[t])` for `t ∈ [0, total_seqlen)`. +- two-pass softmax over `scores` using the dispatched `ReduceMaximumF32Kernel` / + `ComputeSumExpF32Kernel` helpers. +- **SV GEMV** — `out[h] = Σ_t probs[t] · V[t][h]`, then normalize by `1/Σ probs`. + +Both GEMV helpers (`MlasGQADecodeQK`, `MlasGQADecodeSV`) live in the baseline-ISA MLAS +translation unit, so their inner loops use independent accumulator lanes / map-style +updates that vectorize under SSE2 without `-ffast-math`. Decode needs no causal mask (the +single new token is the most recent position and attends to every cached token); only +optional local-window masking and additive attention bias are applied. The kernel streams +K and V exactly once each, so it is memory-bandwidth bound. + +The two-phase flash-decoding path (active when `batch × heads < threads`, KV partitioned +across idle threads) now also uses these GEMV helpers for its per-chunk QK and SV products +instead of `M = 1` SGEMM calls, removing the same packing overhead. + ## MLAS Dispatch Paths @@ -516,11 +550,29 @@ algorithmic rather than purely from threading. #### Latency — Decode (S = 1, token generation) -Single-token decode (`sequence_length == 1`) currently falls back to the naive path for the -non-quantized FP32 cache: the flash path is gated on `sequence_length > 1` (prefill only), -because routing the tiny `1 × T × H` decode work through the tiled SGEMM kernel pays -per-block GEMM setup overhead with no tiling reuse benefit. A dedicated FP32 decode kernel -is added in a follow-up change. +For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so +flash decoding KV-partitioning is not active), the workload per `Run` is tiny (a `1 × T × H` +GEMV pair per head) and operator-level latency is dominated by fixed per-`Run` overhead +(session dispatch, KV-cache concatenation), so operator-level measurements on the EPYC dev +box are extremely noisy. The numbers below come from a min-of-many-repeats MLAS-path harness +to suppress that jitter. + +| Total Seqlen | Naive (ms) | Flash (ms) | Speedup | +|---:|---:|---:|---:| +| 513 | 0.50 | 0.42 | ~1.0\u20131.2x (noisy) | +| 1025 | 0.78 | 0.69 | ~1.0\u20131.1x (noisy) | +| 2049 | 1.89 | 1.73 | ~1.0\u20131.1x (noisy) | +| 4097 | 6.1 | 4.5 | 1.35\u20131.5x | + +Decode is now handled by the dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`) instead of +the prefill tiling kernel; see *Decode uses a dedicated GEMV kernel* above. Replacing the +per-head `M = 1` `MlasSgemmOperation` QK/SV calls with direct GEMVs removes the SGEMM +B-packing overhead that previously made flash decode noticeably **slower** than naive +(measured ≈0.4\u20130.6x across all lengths before the change). Flash decode is now at parity +for short/medium sequences (where the work is memory-bandwidth bound and overhead-dominated) +and consistently ahead for long contexts (T≥4097, ~1.4\u20131.5x) where the streamed +single-pass KV access wins. Short decode remains overhead-bound rather than algorithm-bound, +so it is not the target of the prefill-oriented causal early-termination optimization. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index caa6a89f02fd7..e56e6eb1720da 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -1100,6 +1100,11 @@ class GQAAttentionBase { // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats partials_buffer_bytes = SafeInt(batch_size) * num_heads_ * kv_chunk_count * (2 + head_size) * sizeof(float); + } else if (sequence_length == 1) { + // Decode (GEMV kernel, no Q/KV tiling): per-thread scratch holds the full + // score row scores[total_seqlen] plus a temp output accumulator[head_size]. + buffer_size_per_thread = + (SafeInt(max_total_seqlen) + head_size) * sizeof(float); } else { buffer_size_per_thread = (SafeInt(q_block_size) * 2 + // l + m diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index e345cbfd3aef6..d9739ae02aa7b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -349,12 +349,12 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // naive path when an unsupported feature is requested (softcap, smooth softmax, // head sink, or QK output). // - // The flash path is currently used for prefill only (sequence_length > 1). Single-token - // decode (sequence_length == 1) falls back to the naive path; a dedicated decode kernel - // is added in a follow-up change. + // Prefill (sequence_length > 1) uses the tiled kernel; single-token decode + // (sequence_length == 1 with total_sequence_length > 1) uses the dedicated GEMV + // decode kernel. Both are reached when total_sequence_length > 1. if constexpr (std::is_same_v) { const bool use_flash = !disable_gqa_flash_ && - parameters.sequence_length > 1 && + parameters.total_sequence_length > 1 && softcap_ == 0.0f && !use_smooth_softmax_ && head_sink_data == nullptr && diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp index 25f3733f59cca..99a66e16c02fa 100644 --- a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -32,6 +32,71 @@ Module Name: #include "mlasi.h" +// +// Decode (sequence_length == 1) GEMV helpers. +// +// With a single query token the QK^T and S*V products degenerate into +// matrix-vector products. Computing them directly streams the K and V cache +// exactly once and avoids the SGEMM B-packing overhead that otherwise dominates +// the tiny M = 1 GEMMs. These helpers live in the baseline-ISA MLAS translation +// unit, so the inner loops are written with independent accumulator lanes and a +// map-style update so the compiler can vectorize them without -ffast-math +// (which would be required to reassociate a plain scalar float reduction). +// + +// QK^T GEMV: scores[t] = scale * dot(q[0..H), K[t*H .. t*H+H)) for t in [0, n_kv). +static void +MlasGQADecodeQK( + const float* q, + const float* k_cache, + std::ptrdiff_t n_kv, + std::ptrdiff_t head_size, + float scale, + float* scores +) +{ + constexpr std::ptrdiff_t kLanes = 8; + for (std::ptrdiff_t t = 0; t < n_kv; ++t) { + const float* krow = k_cache + t * head_size; + float acc[kLanes] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + std::ptrdiff_t h = 0; + for (; h + kLanes <= head_size; h += kLanes) { + for (std::ptrdiff_t j = 0; j < kLanes; ++j) { + acc[j] += q[h + j] * krow[h + j]; + } + } + float sum = ((acc[0] + acc[1]) + (acc[2] + acc[3])) + + ((acc[4] + acc[5]) + (acc[6] + acc[7])); + for (; h < head_size; ++h) { + sum += q[h] * krow[h]; + } + scores[t] = sum * scale; + } +} + +// S*V GEMV (accumulate): out[h] = sum_t probs[t] * V[t*H + h] for h in [0, head_size). +// `out` is overwritten (initialized to zero) before accumulation. +static void +MlasGQADecodeSV( + const float* probs, + const float* v_cache, + std::ptrdiff_t n_kv, + std::ptrdiff_t head_size, + float* out +) +{ + for (std::ptrdiff_t h = 0; h < head_size; ++h) { + out[h] = 0.0f; + } + for (std::ptrdiff_t t = 0; t < n_kv; ++t) { + const float p = probs[t]; + const float* vrow = v_cache + t * head_size; + for (std::ptrdiff_t h = 0; h < head_size; ++h) { + out[h] += p * vrow[h]; + } + } +} + void MlasFlashAttentionGQAThreaded( void* argptr, @@ -381,23 +446,9 @@ MlasFlashDecodingGQAThreaded( static_cast(batch_idx) * args->q_batch_stride + static_cast(head_idx) * static_cast(head_size); - // Step 1: QK^T GEMM for this KV chunk + // Step 1: QK^T GEMV for this KV chunk (M = 1) const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); - MlasSgemmOperation( - CblasNoTrans, - CblasTrans, - 1, // M (single query row) - row_size_kv, // N - static_cast(head_size), // K - scale, // alpha - q_ptr, // A (FP32 query) - static_cast(head_size), // lda - k_block, // B (FP32 K block) - static_cast(head_size), // ldb - 0.0f, // beta - scores, // C (output scores) - row_size_kv // ldc - ); + MlasGQADecodeQK(q_ptr, k_block, static_cast(row_size_kv), head_size, scale, scores); // Step 1b: Apply attention bias if present if (args->attention_bias != nullptr) { @@ -476,23 +527,9 @@ MlasFlashDecodingGQAThreaded( #endif *partial_l = rowsum; - // Step 4: S_exp * V_block -> partial_output + // Step 4: S_exp * V_block -> partial_output (M = 1) const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); - MlasSgemmOperation( - CblasNoTrans, - CblasNoTrans, - 1, // M - static_cast(head_size), // N - row_size_kv, // K - 1.0f, // alpha - scores, // A (exp softmax scores) - row_size_kv, // lda - v_block, // B (FP32 V block) - static_cast(head_size), // ldb - 0.0f, // beta (overwrite) - partial_output, // C (output for this chunk) - static_cast(head_size) // ldc - ); + MlasGQADecodeSV(scores, v_block, static_cast(row_size_kv), head_size, partial_output); } } @@ -588,6 +625,150 @@ MlasFlashDecodingGQAReduceThreaded( } } +// +// Decode kernel for sequence_length == 1 without KV-split (batch * heads >= +// thread_count). Parallelizes over (batch, head); each task attends the single +// query token to the whole KV cache with a pair of GEMVs and a two-pass softmax. +// Decode needs no causal masking (the single new token is the most recent +// position and attends to every cached token); only optional local-window +// masking and additive bias are applied. +// +void +MlasGQADecodeGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // One task per (batch, head). + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + // Local-window low bound: decode can attend to KV positions [window_start, total_seqlen). + // causal_limit == past_seqlen + 1 == total_seqlen for the single new token. + const ptrdiff_t window_start = + (local_window_size >= 0 && total_seqlen > local_window_size) ? (total_seqlen - local_window_size) : 0; + + // Per-thread scratch: scores[total_seqlen] followed by temp_output[head_size]. + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + float* temp_output = scores + total_seqlen; + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + const ptrdiff_t head_idx = task_index % num_heads; + const ptrdiff_t batch_idx = task_index / num_heads; + + // KV head index for GQA head sharing. + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, 1, head_size]; batch stride supplied + // separately to support packed-QKV input. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(head_size); + + // Step 1: QK^T GEMV -> scores[0..T) + MlasGQADecodeQK(q_ptr, k_cache_head, total_seqlen, head_size, scale, scores); + + // Step 1b: additive attention bias (shape [batch|1, num_heads|1, S=1, T]). + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_matrix_size = + static_cast(args->attention_bias_seqlen_stride); // S == 1 + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset; + for (ptrdiff_t t = 0; t < total_seqlen; ++t) { + scores[t] += bias_row[t]; + } + } + + // Step 2: local-window masking (no causal mask needed for decode). + if (window_start > 0) { + for (ptrdiff_t t = 0; t < window_start; ++t) { + scores[t] = std::numeric_limits::lowest(); + } + } + + // Step 3: softmax over scores[0..T). +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, total_seqlen); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, total_seqlen); +#endif + + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + (static_cast(batch_idx) * static_cast(num_heads) + + static_cast(head_idx)) * static_cast(head_size); + + if (rowmax == std::numeric_limits::lowest()) { + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, total_seqlen, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, total_seqlen, &negmax); +#endif + + // Step 4: S_exp * V GEMV -> temp_output, then normalize by 1/l. + MlasGQADecodeSV(scores, v_cache_head, total_seqlen, head_size, temp_output); + + const float inv_l = (rowsum > 0.0f) ? (1.0f / rowsum) : 0.0f; + for (ptrdiff_t h = 0; h < head_size; ++h) { + output_ptr[h] = temp_output[h] * inv_l; + } + } +} + void MLASCALL MlasFlashAttentionGQA( @@ -595,23 +776,35 @@ MlasFlashAttentionGQA( MLAS_THREADPOOL* ThreadPool ) { - if (args->flash_decoding_partials != nullptr && args->sequence_length == 1) { - // Flash decoding: two-phase approach. - // Phase 1: parallel partial computation over (batch, head, kv_chunk). - MlasExecuteThreaded( - MlasFlashDecodingGQAThreaded, - static_cast(args), - static_cast(args->thread_count), - ThreadPool - ); - // Phase 2: reduce partials into final output (parallel over batch*heads). - MlasExecuteThreaded( - MlasFlashDecodingGQAReduceThreaded, - static_cast(args), - static_cast(args->thread_count), - ThreadPool - ); + if (args->sequence_length == 1) { + // Decode: M = 1, use the GEMV kernels (no SGEMM packing overhead). + if (args->flash_decoding_partials != nullptr) { + // Flash decoding: two-phase approach when KV is partitioned across threads. + // Phase 1: parallel partial computation over (batch, head, kv_chunk). + MlasExecuteThreaded( + MlasFlashDecodingGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + // Phase 2: reduce partials into final output (parallel over batch*heads). + MlasExecuteThreaded( + MlasFlashDecodingGQAReduceThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } else { + // Single-pass decode parallelized over (batch, head). + MlasExecuteThreaded( + MlasGQADecodeGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } } else { + // Prefill (sequence_length > 1): tiled online-softmax SGEMM kernel. MlasExecuteThreaded( MlasFlashAttentionGQAThreaded, static_cast(args), @@ -620,3 +813,4 @@ MlasFlashAttentionGQA( ); } } +