[CPU] Add FP32 flash attention (prefill) and GEMV decode kernel for GroupQueryAttention #29216

Draft
tianleiwu wants to merge 6 commits into main from tlwu/20260608/gqa_cpu_decode_gemv

@tianleiwu

Description

Adds CPU FP32 flash attention for GroupQueryAttention, covering both prefill and single-token decode.

This PR builds on the prefill-only flash attention change (#PR1) and adds a dedicated decode kernel. Because it depends on that change, the PR1 commits are included here for completeness. This PR targets main and stays in draft until PR1 merges.

What's included

  • Prefill (tiled) flash attention — A fused QK^T + online-softmax + S·V MLAS kernel (MlasFlashAttentionGQA) that avoids materializing the full attention-score matrix. Falls back to the naive path when an unsupported feature is requested (softcap, smooth softmax, head sink, or QK output).
  • Decode (GEMV) kernel — A dedicated single-token decode kernel (MlasGQADecodeGQAThreaded) for sequence_length == 1, parallelized over (batch, head) with a two-pass softmax, using GEMV (acc[8]-lane dot product / AXPY) helpers instead of per-block M=1 SGEMM calls. This fixes the per-block SGEMM decode regression.
  • The FP32 flash gate (group_query_attention.cc) is enabled for total_sequence_length > 1, routing prefill to the tiled kernel and decode to the GEMV kernel.
  • The quantized KV-cache path is unchanged (FP32-only scope).
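
The online-softmax idea behind the tiled prefill kernel can be illustrated with a minimal pure-Python sketch. This is not the MLAS implementation: the function names, the block size, and the single-query-row simplification are assumptions made for illustration, and causal masking, GQA head grouping, and threading are left out. It only shows how a running max, a running sum, and a rescaled accumulator let the kernel process K/V in blocks without holding the full score row.

```python
import math


def naive_attention_row(q, keys, values, scale):
    """Reference path: build every score first, then softmax and weight V."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def flash_attention_row(q, keys, values, scale, block=4):
    """Online softmax: visit K/V in blocks, keeping only one block of scores."""
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim  # unnormalized output accumulator
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max (0.0 on the first block).
        correction = math.exp(running_max - new_max)
        probs = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(probs)
        for d in range(dim):
            acc[d] = acc[d] * correction + sum(
                p * v[d] for p, v in zip(probs, v_blk))
        running_max = new_max
    return [a / running_sum for a in acc]


# Small deterministic check: 10 keys, head size 3, block of 4 (uneven last block).
q = [0.3, -0.2, 0.5]
keys = [[0.1 * i, 0.05 * (i - 4), -0.02 * i] for i in range(10)]
values = [[float(i), float(i % 3), 1.0] for i in range(10)]
scale = 1.0 / math.sqrt(len(q))
ref = naive_attention_row(q, keys, values, scale)
out = flash_attention_row(q, keys, values, scale)
max_diff = max(abs(a - b) for a, b in zip(ref, out))
```

Both functions compute the same result up to floating-point rounding. The blocked version never stores more than `block` scores at once, which is the property the description attributes to the tiled kernel.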

Results (AMD EPYC 7763, AVX2, 8 threads)

  • Prefill: ~1.4–2.3x vs naive (S = 512–4096).
  • Decode: max abs diff ~1e-8 vs naive; long-context decode ~1.0–1.5x vs naive (T = 4097: ~1.3–1.5x).

Motivation and Context

The naive GQA path materializes the full attention-score matrix, which makes it memory-bound for long sequences. Flash attention reduces memory traffic for prefill, and the GEMV decode kernel avoids SGEMM overhead for the M=1 decode case.
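
As a rough illustration of the M=1 decode case, the sketch below computes one query head against its cached keys and values using a lane-accumulated dot product and an AXPY update, with a two-pass softmax. All names (`dot_lanes`, `axpy`, `decode_one_head`, `kv_head_for`) are hypothetical and are not MLAS functions. The 8-lane width is assumed to mirror the "acc[8]" wording in the description, and the head-to-KV-head mapping is the usual GQA grouping, assumed here rather than taken from the kernel source.

```python
import math

LANES = 8  # assumed accumulator width, echoing "acc[8]" in the description


def dot_lanes(x, y):
    """Dot product with LANES independent partial sums plus a scalar tail."""
    acc = [0.0] * LANES
    n = len(x)
    main = n - n % LANES
    for i in range(0, main, LANES):
        for lane in range(LANES):
            acc[lane] += x[i + lane] * y[i + lane]
    total = sum(acc)
    for i in range(main, n):  # leftover elements that do not fill a lane group
        total += x[i] * y[i]
    return total


def axpy(alpha, x, y):
    """In-place y += alpha * x."""
    for i in range(len(x)):
        y[i] += alpha * x[i]


def decode_one_head(q, k_cache, v_cache, scale):
    """Single-token attention for one head with a two-pass softmax."""
    # Pass 1: one dot product per cached key, then find the max score.
    scores = [scale * dot_lanes(q, k) for k in k_cache]
    top = max(scores)
    # Pass 2: exponentiate, accumulate weighted V rows, normalize once at the end.
    out = [0.0] * len(v_cache[0])
    denom = 0.0
    for s, v in zip(scores, v_cache):
        p = math.exp(s - top)
        denom += p
        axpy(p, v, out)
    return [o / denom for o in out]


def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Map a query head to the KV head it shares under GQA grouping."""
    return q_head // (num_q_heads // num_kv_heads)
```

With identical keys every cached position gets the same weight, so `decode_one_head([1.0, 0.0], [[1.0, 0.0]] * 3, [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 1.0)` returns the plain mean of the value rows. Each (batch, head) pair is independent, which is what makes that pair a natural unit of parallel work.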

Testing

  • Built with --compile_no_warning_as_error.
  • Correctness verified against the naive path for both prefill and decode (max abs diff ~1e-8).
  • Benchmarked via benchmark_gqa_cpu_flash.py.

tianleiwu added 6 commits June 9, 2026 19:36

Add an FP32 tiled online-softmax flash attention kernel for the CPU
GroupQueryAttention contrib op, mirroring the existing quantized-KV flash
path. Avoids materializing the full attention score matrix and adds a
two-phase flash-decoding path for single-token decode.

- New MLAS kernel core/mlas/lib/flashattn_gqa.cpp (MlasFlashAttentionGQA)
  supporting GQA head grouping, causal masking, local window, attention
  bias, ragged/per-batch seqlens, packed QKV, and flash-decoding.
- New ApplyAttentionFlash dispatch in gqa_attention_base.h; wired into
  group_query_attention.cc (float only, gated like the quantized flash
  path: no softcap/smooth softmax/head sink/QK output).
- Reuses ORT_GQA_DISABLE_FLASH_ATTENTION to fall back to the naive path.

Guard the per-thread scratch and flash-decoding partials buffer size
computations against size_t overflow for large or malformed shapes,
matching the SafeInt usage elsewhere in this file.

Single-token decode (sequence_length == 1) falls back to the naive path. A dedicated FP32 decode kernel will be added in a follow-up PR. The quantized path is unchanged.

Adds a dedicated GEMV kernel (MlasGQADecodeGQAThreaded) for single-token decode (sequence_length == 1), and converts the flash-decoding inner M=1 GEMMs to GEMV. Re-enables the FP32 flash gate for decode (total_sequence_length > 1). Verified correctness vs naive (~1e-8); long-context decode ~1.0-1.2x, fixing the prior per-block SGEMM decode regression.
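
The two-phase flash-decoding path mentioned in the first commit can be sketched as follows, assuming the common split-KV formulation: phase 1 computes a partial result per KV chunk, and phase 2 merges the partials by rescaling each to the global max. The function names and the chunking are illustrative assumptions, not the kernel's actual layout or API.

```python
import math


def partial_attention(q, keys, values, scale):
    """Phase 1: one KV chunk -> (chunk max, chunk exp-sum, unnormalized output)."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    probs = [math.exp(s - top) for s in scores]
    dim = len(values[0])
    acc = [sum(p * v[d] for p, v in zip(probs, values)) for d in range(dim)]
    return top, sum(probs), acc


def merge_partials(partials):
    """Phase 2: rescale every chunk to the global max, then normalize once."""
    global_max = max(top for top, _, _ in partials)
    dim = len(partials[0][2])
    denom = 0.0
    out = [0.0] * dim
    for top, exp_sum, acc in partials:
        weight = math.exp(top - global_max)
        denom += weight * exp_sum
        for d in range(dim):
            out[d] += weight * acc[d]
    return [o / denom for o in out]


# Deterministic check: 6 cached tokens split into two chunks of 3.
q = [0.5, -1.0]
keys = [[0.1 * i, 0.2 * (i - 3)] for i in range(6)]
values = [[float(i), float(i * i)] for i in range(6)]
scale = 1.0
full = merge_partials([partial_attention(q, keys, values, scale)])
chunked = merge_partials([
    partial_attention(q, keys[:3], values[:3], scale),
    partial_attention(q, keys[3:], values[3:], scale),
])
max_err = max(abs(a - b) for a, b in zip(full, chunked))
```

Because each chunk's partial depends only on its own keys and values, the chunks can be computed by different threads and merged afterwards. In this sketch the chunked result matches the single-chunk result up to rounding.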