[CPU] Add FP32 flash attention (prefill) and GEMV decode kernel for GroupQueryAttention #29216
Draft
tianleiwu wants to merge 6 commits into
Add an FP32 tiled online-softmax flash attention kernel for the CPU GroupQueryAttention contrib op, mirroring the existing quantized-KV flash path. Avoids materializing the full attention score matrix and adds a two-phase flash-decoding path for single-token decode.
- New MLAS kernel core/mlas/lib/flashattn_gqa.cpp (MlasFlashAttentionGQA) supporting GQA head grouping, causal masking, local window, attention bias, ragged/per-batch seqlens, packed QKV, and flash-decoding.
- New ApplyAttentionFlash dispatch in gqa_attention_base.h; wired into group_query_attention.cc (float only, gated like the quantized flash path: no softcap/smooth softmax/head sink/QK output).
- Reuses ORT_GQA_DISABLE_FLASH_ATTENTION to fall back to the naive path.
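A minimal sketch of the tiled online-softmax idea for a single query row, compared against a reference that materializes every score. The function names, the row-major layout, and the 1/sqrt(d) scale are assumptions made for illustration; this is not the MlasFlashAttentionGQA signature and it omits masking, bias, head grouping, and threading.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helpers, not the MLAS API.
// q holds d floats; k and v are row-major [n, d]; scores use scale = 1/sqrt(d).

// Reference: materializes all n scores before normalizing.
std::vector<float> NaiveAttentionRow(const std::vector<float>& q,
                                     const std::vector<float>& k,
                                     const std::vector<float>& v,
                                     std::size_t n, std::size_t d) {
  assert(n > 0 && q.size() == d && k.size() == n * d && v.size() == n * d);
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> s(n);
  float m = -std::numeric_limits<float>::infinity();
  for (std::size_t j = 0; j < n; ++j) {
    float dot = 0.0f;
    for (std::size_t c = 0; c < d; ++c) dot += q[c] * k[j * d + c];
    s[j] = scale * dot;
    m = std::max(m, s[j]);
  }
  float l = 0.0f;
  std::vector<float> out(d, 0.0f);
  for (std::size_t j = 0; j < n; ++j) {
    const float p = std::exp(s[j] - m);
    l += p;
    for (std::size_t c = 0; c < d; ++c) out[c] += p * v[j * d + c];
  }
  for (float& x : out) x /= l;
  return out;
}

// Tiled: keeps only `tile` scores plus a running (max, sum, accumulator).
std::vector<float> OnlineSoftmaxAttentionRow(const std::vector<float>& q,
                                             const std::vector<float>& k,
                                             const std::vector<float>& v,
                                             std::size_t n, std::size_t d,
                                             std::size_t tile) {
  assert(n > 0 && tile > 0);
  assert(q.size() == d && k.size() == n * d && v.size() == n * d);
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  float m = -std::numeric_limits<float>::infinity();  // running max score
  float l = 0.0f;                                     // running sum of exp(s - m)
  std::vector<float> acc(d, 0.0f);                    // running sum of exp(s - m) * v
  std::vector<float> s(tile);                         // scores of the current tile only

  for (std::size_t start = 0; start < n; start += tile) {
    const std::size_t len = std::min(tile, n - start);
    float tile_max = -std::numeric_limits<float>::infinity();
    for (std::size_t j = 0; j < len; ++j) {
      float dot = 0.0f;
      for (std::size_t c = 0; c < d; ++c) dot += q[c] * k[(start + j) * d + c];
      s[j] = scale * dot;
      tile_max = std::max(tile_max, s[j]);
    }
    const float m_new = std::max(m, tile_max);
    // Rescale earlier partial results to the new max; exp(-inf) is 0 on the first tile.
    const float correction = std::exp(m - m_new);
    l *= correction;
    for (float& x : acc) x *= correction;
    for (std::size_t j = 0; j < len; ++j) {
      const float p = std::exp(s[j] - m_new);
      l += p;
      for (std::size_t c = 0; c < d; ++c) acc[c] += p * v[(start + j) * d + c];
    }
    m = m_new;
  }
  for (float& x : acc) x /= l;
  return acc;
}
```

Both functions should agree to within FP32 rounding for any tile size; the tiled version only ever holds one tile of scores, which is the memory saving the kernel is after.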
Guard the per-thread scratch and flash-decoding partials buffer size computations against size_t overflow for large or malformed shapes, matching the SafeInt usage elsewhere in this file.
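A rough illustration of the kind of guard described here, written with a plain checked multiply so it stays self-contained. CheckedMul and ScratchBytes are hypothetical names standing in for the SafeInt-based computation; the actual buffer formula in the PR is not shown in this description and is not reproduced here.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>

// Hypothetical stand-in for a SafeInt-style checked multiply.
// Returns false instead of wrapping when a * b does not fit in size_t.
inline bool CheckedMul(std::size_t a, std::size_t b, std::size_t* out) {
  assert(out != nullptr);
  if (a != 0 && b > std::numeric_limits<std::size_t>::max() / a) return false;
  *out = a * b;
  return true;
}

// Illustrative size: thread_count * floats_per_thread * sizeof(float).
// Each multiplication is checked, so a malformed shape is rejected
// rather than producing a small wrapped allocation size.
inline bool ScratchBytes(std::size_t thread_count,
                         std::size_t floats_per_thread,
                         std::size_t* bytes) {
  std::size_t elems = 0;
  if (!CheckedMul(thread_count, floats_per_thread, &elems)) return false;
  return CheckedMul(elems, sizeof(float), bytes);
}
```

A caller would treat a false return as an invalid-shape error before allocating anything.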
Single-token decode (sequence_length == 1) falls back to the naive path. A dedicated FP32 decode kernel will be added in a follow-up PR. The quantized path is unchanged.
Adds a dedicated GEMV kernel (MlasGQADecodeGQAThreaded) for single-token decode (sequence_length == 1), and converts the flash-decoding inner M=1 GEMMs to GEMV. Re-enables the FP32 flash gate for decode (total_sequence_length > 1). Verified correctness vs naive (~1e-8); long-context decode ~1.0-1.2x, fixing the prior per-block SGEMM decode regression.
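A sketch of what a GEMV-style decode step for one (batch, head) pair might look like, assuming a row-major KV cache and reading "two-pass softmax" as one pass for scores and their max and a second pass for the exponentials and the weighted sum. Dot8, Axpy, DecodeOneHead, and KvHeadFor are hypothetical names; the real MlasGQADecodeGQAThreaded helpers and threading are not shown in this description.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Dot product with eight independent accumulators, which gives the
// compiler room to vectorize; the scalar loop handles the tail.
inline float Dot8(const float* a, const float* b, std::size_t n) {
  float acc[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
  std::size_t i = 0;
  for (; i + 8 <= n; i += 8) {
    for (std::size_t lane = 0; lane < 8; ++lane) acc[lane] += a[i + lane] * b[i + lane];
  }
  float sum = 0.0f;
  for (; i < n; ++i) sum += a[i] * b[i];
  for (std::size_t lane = 0; lane < 8; ++lane) sum += acc[lane];
  return sum;
}

// y += alpha * x
inline void Axpy(float alpha, const float* x, float* y, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) y[i] += alpha * x[i];
}

// Standard GQA mapping: consecutive query heads share one KV head.
inline std::size_t KvHeadFor(std::size_t q_head, std::size_t num_heads,
                             std::size_t kv_num_heads) {
  assert(kv_num_heads > 0 && num_heads % kv_num_heads == 0);
  return q_head / (num_heads / kv_num_heads);
}

// One decode step for one head: q has d floats, k and v are row-major [t, d].
void DecodeOneHead(const float* q, const float* k, const float* v,
                   std::size_t t, std::size_t d, float scale, float* out) {
  assert(t > 0);
  std::vector<float> p(t);
  // Pass 1: scores via GEMV (one dot product per cached key) and their max.
  float m = -std::numeric_limits<float>::infinity();
  for (std::size_t j = 0; j < t; ++j) {
    p[j] = scale * Dot8(q, k + j * d, d);
    m = std::max(m, p[j]);
  }
  // Pass 2: exponentials, their sum, and the weighted sum of V rows via AXPY.
  float sum = 0.0f;
  std::fill(out, out + d, 0.0f);
  for (std::size_t j = 0; j < t; ++j) {
    const float w = std::exp(p[j] - m);
    sum += w;
    Axpy(w, v + j * d, out, d);
  }
  const float inv = 1.0f / sum;
  for (std::size_t i = 0; i < d; ++i) out[i] *= inv;
}
```

Because the query is a single row, every matrix product collapses to a dot product or an AXPY, so no SGEMM call (and none of its setup cost) is needed per block.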
Description
Adds CPU FP32 flash attention for GroupQueryAttention, covering both prefill and single-token decode.
This PR builds on the prefill-only flash attention change (#PR1) and additionally introduces a dedicated decode kernel. Since it depends on that change, the PR1 commits are included here for completeness; it targets main and is marked draft until PR1 merges.
What's included
- Tiled online-softmax flash attention kernel for prefill (MlasFlashAttentionGQA) that avoids materializing the full attention-score matrix. Falls back to the naive path when an unsupported feature is requested (softcap, smooth softmax, head sink, or QK output).
- Dedicated GEMV decode kernel (MlasGQADecodeGQAThreaded) for sequence_length == 1, parallelized over (batch, head) with a two-pass softmax, using GEMV (acc[8]-lane dot product / AXPY) helpers instead of per-block M=1 SGEMM calls. This fixes the per-block SGEMM decode regression.
- The FP32 flash gate (group_query_attention.cc) is enabled for total_sequence_length > 1, routing prefill to the tiled kernel and decode to the GEMV kernel.
Results (AMD EPYC 7763, AVX2, 8 threads)
Motivation and Context
The naive GQA path materializes the full score matrix, which is memory-bound for long sequences. Flash attention reduces memory traffic for prefill, and the GEMV decode kernel avoids SGEMM overhead for the M=1 decode case.
Testing
- --compile_no_warning_as_error.
- benchmark_gqa_cpu_flash.py.