Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/flashattn_qkv.cpp
${MLAS_SRC_DIR}/flashattn_gqa.cpp
${MLAS_SRC_DIR}/qkv_quant.cpp
${MLAS_SRC_DIR}/cast.cpp
${MLAS_SRC_DIR}/layernorm.cpp
Expand Down
154 changes: 149 additions & 5 deletions docs/contrib_ops/cpu/gqa.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ Quantized KV-cache GEMM helpers are implemented in MLAS:
- `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp`
- `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp`
- `onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp`
- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (flash attention tiled kernel)
- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (quantized-KV flash attention tiled kernel)

The non-quantized flash attention tiled kernel is implemented in MLAS:

- `onnxruntime/core/mlas/lib/flashattn_gqa.cpp` (FP32-KV flash attention tiled kernel)
- `onnxruntime/core/mlas/inc/mlas.h` (`MlasFlashAttentionGQA` declaration and `MlasFlashAttentionGQAArgs`)

The operator schema itself is defined in:

Expand Down Expand Up @@ -48,12 +53,14 @@ At a high level, the CPU kernel executes GroupQueryAttention in these stages:

The non-quantized and quantized paths share the surrounding validation, masking, softmax, and output flow. Their main difference is how the K/V cache is stored and read during QK and SV GEMMs.

The quantized path has two execution strategies:
Both the non-quantized and quantized paths have two execution strategies:

- **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences.
- **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool.

The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path.
The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, online-softmax, and flash-decoding structure.

The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths).

## Supported Cache Modes

Expand Down Expand Up @@ -144,9 +151,9 @@ For quantized V cache, the CPU path calls `MlasSVGemm` with:

As with QK GEMM, the default MLAS contract preserves the FP32 left-hand operand and dequantizes only the cached V values on the fly.

## Flash Attention Path
## Quantized Flash Attention Path

The flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix.
The quantized flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix.

### Algorithm

Expand Down Expand Up @@ -204,6 +211,93 @@ The partials buffer is allocated alongside the per-thread scratch in a single al
- Per-thread scratch: `scores[Bc]` (one float per KV block element)
- Partials: `batch × num_heads × kv_chunks × (2 + H)` floats (m, l, and partial output per chunk)

## Non-Quantized Flash Attention Path

The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, masking, and flash-decoding structure.

### Differences from the Quantized Path

- **Cache element type**: The present K/V cache is FP32, laid out as BNSH (`[batch, kv_num_heads, seqlen_present, head_size]`). There is no quantize-on-write or dequantize-on-read step.
- **QK GEMM**: Uses the single-threaded SGEMM primitive `MlasSgemmOperation(CblasNoTrans, CblasTrans, ...)` on an FP32 K block instead of `MlasQKGemm`.
- **SV accumulate**: Uses `MlasSgemmOperation(CblasNoTrans, CblasNoTrans, ..., beta)` with `beta = 0` for the first KV block and `beta = 1` afterwards (accumulate) instead of `MlasSVGemm`.
- **Cache concat**: New K/V tokens are appended into the FP32 present cache with `ConcatStateChunkGQA<float>` before the tiled loop runs.

### Algorithm

For each (batch, head, q_block) tile:

1. **QK GEMM** — `MlasSgemmOperation` of the query tile against a block slice of the FP32 K cache (Bc rows at a time)
1b. **Attention bias** — Add the corresponding tile of the bias tensor (if present) to QK scores
2. **Causal + local window masking** — Set masked positions to −∞ before softmax
3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old − m_new)`
4. **SV accumulate** — `MlasSgemmOperation(..., beta)` accumulates `softmax(QK_block) × V_block` into the output tile
5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed

#### Causal early-termination

During prefill, every KV block whose start index is at or beyond the largest global query
position in the current q_block is fully causally masked and contributes nothing. The kernel
computes a per-q_block bound
`kv_causal_limit = past_seqlen + q_idx + row_size_q` and breaks out of the KV loop once
`ir >= kv_causal_limit`, instead of computing and then discarding the masked upper-triangle
QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the
main reason the FP32 flash path is faster than naive even at short sequence lengths
(see the benchmark results below). Decode (q_block of size 1 at the cache tail) attends to all
KV positions, so the bound equals `total_seqlen` and nothing is skipped.

### Activation Conditions

The non-quantized flash path is selected when ALL of the following hold:

- The kernel specialization is `float` (FP16 uses the naive path)
- `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`)
- `total_sequence_length > 1`
- No softcap
- No smooth softmax
- No head sink
- No output QK capture
- `present_key` and `present_value` are provided

Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path.

### Block Sizes, Threading, and Flash Decoding

Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization.

#### Decode uses a dedicated GEMV kernel (`sequence_length == 1`)

The tiled online-softmax SGEMM kernel (`MlasFlashAttentionGQAThreaded`) is used **only for
prefill** (`sequence_length > 1`), where each KV tile is reused across the `q_block_size`
query rows and tiling delivers real cache-locality and SGEMM packing benefits.

For single-token decode the query tile has `M = 1`, so every K/V element is streamed
exactly once with no reuse across query rows. Tiling provides **no** cache-locality
benefit, and routing the `1 × T × H` work through `MlasSgemmOperation` pays the SGEMM
B-packing/setup cost on every call — which previously made the flash decode path *slower*
than the naive path (≈0.4–0.6x) for short-to-medium total sequence lengths.

Decode is therefore handled by a dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`),
dispatched whenever `sequence_length == 1` and flash decoding is not active. It
parallelizes over `(batch, head)` and, per head, computes the attention directly with two
matrix-vector products and a two-pass softmax:

- **QK GEMV** — `scores[t] = scale · dot(q, K[t])` for `t ∈ [0, total_seqlen)`.
- two-pass softmax over `scores` using the dispatched `ReduceMaximumF32Kernel` /
`ComputeSumExpF32Kernel` helpers.
- **SV GEMV** — `out[h] = Σ_t probs[t] · V[t][h]`, then normalize by `1/Σ probs`.

Both GEMV helpers (`MlasGQADecodeQK`, `MlasGQADecodeSV`) live in the baseline-ISA MLAS
translation unit, so their inner loops use independent accumulator lanes / map-style
updates that vectorize under SSE2 without `-ffast-math`. Decode needs no causal mask (the
single new token is the most recent position and attends to every cached token); only
optional local-window masking and additive attention bias are applied. The kernel streams
K and V exactly once each, so it is memory-bandwidth bound.

The two-phase flash-decoding path (active when `batch × heads < threads`, KV partitioned
across idle threads) now also uses these GEMV helpers for its per-chunk QK and SV products
instead of `M = 1` SGEMM calls, removing the same packing overhead.


## MLAS Dispatch Paths

MLAS selects the best available quantized KV-cache GEMM implementation through the platform dispatch table.
Expand Down Expand Up @@ -428,7 +522,57 @@ Flash decoding IS active (batch×heads=4 < threads=8, KV partitioned across idle
| 4096 (N=32) | +2131 | +87 | 24.5x |

**Summary**: The flash path's primary benefit for prefill is **memory reduction** — avoiding the full O(N×S×T) attention matrix. For S=4096 with 16 heads, the naive path allocates ~1 GB for attention scores while the flash path uses ~80 MB regardless of sequence length. The prefill latency speedup (1.2–2.7x at kernel level, 1.2–1.9x at operator level) comes from improved cache locality. For decode, the tiled kernel provides 1.2–1.8x kernel-level speedup from fused single-pass KV access; at operator level the gain is visible for T≥1024 but masked by KV concat overhead at shorter sequences. When flash decoding is active (batch×heads < threads), KV partitioning across idle threads yields an additional 2–5x speedup for long sequences.
### Non-Quantized (FP32) Flash Attention vs Naive benchmark results

Measured on an AMD EPYC 7763 (32 logical / 16 physical cores), threads=8, FP32 KV cache,
`B=1, num_heads=16, kv_num_heads=8, head_size=128`. Operator-level, measured with:

```bash
python onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py \
--fp32 --prompt_only --warmup 10 --repeats 30
```

#### Latency — Prefill (S = T, prompt phase)

| Seq Length | Naive (ms) | Flash (ms) | Speedup |
|---:|---:|---:|---:|
| 512 | 5.8\u20138.4 | 4.2\u20135.3 | 1.4\u20131.6x |
| 1024 | 25\u201329 | 13\u201318 | 1.6\u20132.0x |
| 2048 | 87\u2013118 | 52\u201365 | 1.5\u20132.0x |
| 4096 | 365\u2013380 | 213\u2013234 | 1.6\u20131.7x |

The FP32 flash path is faster than naive across all measured prefill lengths. With the causal
early-termination described above, roughly half of the QK/SV work (the causally masked
upper triangle of the square prefill attention matrix) is skipped entirely, which more than
offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/output rescale).
The same advantage holds single-threaded (1.4\u20131.8x at threads=1), confirming the gain is
algorithmic rather than purely from threading.

#### Latency — Decode (S = 1, token generation)

For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so
flash decoding KV-partitioning is not active), the workload per `Run` is tiny (a `1 × T × H`
GEMV pair per head) and operator-level latency is dominated by fixed per-`Run` overhead
(session dispatch, KV-cache concatenation), so operator-level measurements on the EPYC dev
box are extremely noisy. The numbers below come from a min-of-many-repeats MLAS-path harness
to suppress that jitter.

| Total Seqlen | Naive (ms) | Flash (ms) | Speedup |
|---:|---:|---:|---:|
| 513 | 0.50 | 0.42 | ~1.0\u20131.2x (noisy) |
| 1025 | 0.78 | 0.69 | ~1.0\u20131.1x (noisy) |
| 2049 | 1.89 | 1.73 | ~1.0\u20131.1x (noisy) |
| 4097 | 6.1 | 4.5 | 1.35\u20131.5x |

Decode is now handled by the dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`) instead of
the prefill tiling kernel; see *Decode uses a dedicated GEMV kernel* above. Replacing the
per-head `M = 1` `MlasSgemmOperation` QK/SV calls with direct GEMVs removes the SGEMM
B-packing overhead that previously made flash decode noticeably **slower** than naive
(measured ≈0.4\u20130.6x across all lengths before the change). Flash decode is now at parity
for short/medium sequences (where the work is memory-bandwidth bound and overhead-dominated)
and consistently ahead for long contexts (T≥4097, ~1.4\u20131.5x) where the streamed
single-pass KV access wins. Short decode remains overhead-bound rather than algorithm-bound,
so it is not the target of the prefill-oriented causal early-termination optimization.
## Current CPU Limitations

The current CPU GroupQueryAttention implementation has a few important limitations:
Expand Down
Loading
Loading