Skip to content

[LA] Lightning Attention MTP decode + KVBuffer parallel verify / commit#97

Open
fkuner wants to merge 11 commits into
inclusionAI:mainfrom
fkuner:la-decode-kvbuffer
Open

[LA] Lightning Attention MTP decode + KVBuffer parallel verify / commit#97
fkuner wants to merge 11 commits into
inclusionAI:mainfrom
fkuner:la-decode-kvbuffer

Conversation

@fkuner

@fkuner fkuner commented Jun 21, 2026

Copy link
Copy Markdown
Collaborator

📌 Description

Adds target-side Lightning Attention (LA) support for speculative decoding (multi-token prediction), in two complementary pieces:

  • Fused MTP decode (la_decode_mtp): a single launch advances the LA recurrence over all T draft tokens per (batch, head), with ILP variants and a work-unit heuristic (get_mtp_config); packed F32×2 FMA on SM100.
  • KVBuffer parallel verify + state-update (la_verify_kvbuffer / la_state_update_kvbuffer): the parallel-verification path. Verify computes every draft step's output in closed form straight from (h0, k, v) — without materializing the T intermediate states — and optionally writes k/v into a compact pooled KV buffer; state-update (commit) then advances the pooled state by the per-request accepted prefix length L read from that buffer.

Closed form used by verify (per draft step t):

o_t = α^{t+1} · (h0 · q_t · scale)                  # term1  (h0–Q GEMM)
    + Σ_{i=0..t} α^{t-i} · (q_t · k_i · scale) · v_i  # term2  (Q–K GEMM, then ·V)

The two dot-product GEMMs run on Blackwell tensor cores via inline-PTX mma.sync.m16n8k8 (TF32, fp32 SMEM staging); everything downstream is plain scalar math. M/N are padded to BT=8, so any draft length T ∈ [1, 8] (odd or even) is handled.

What changed

  • cula/lightning/la_decode_mtp.py — fused MTP decode kernel + config heuristic + shared dot/update helpers.
  • cula/lightning/la_verify_kvbuffer.py — KVBuffer verify (TF32 MMA, register-shuffle variant) + optional KV-buffer write.
  • cula/lightning/la_state_update_kvbuffer.py — KVBuffer state-update (commit) kernel; per-request accepted_len, skips padded slots (h0_indices < 0).
  • tests/test_la_decode_mtp.py, tests/test_la_kvbuffer.py, — correctness vs a PyTorch reference.
  • benchmarks/bench_la_decode_mtp.py, benchmarks/bench_la_kvbuffer.py — MTP-decode bench, and verify+commit chain bench with an optional SGLang baseline (no hard dependency on a local SGLang checkout).

🔍 Related Issues

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit.
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

All kernels are checked against a PyTorch reference. Coverage:

  • MTP decode — output vs reference across batch/T/head shapes.
  • KVBuffer verify — output vs reference, including odd T (1, 3, 5, 7) and the M/N-padding path; negative h0_indices slots left untouched.
  • State-update (commit) — full / partial / per-request accept length, L=0 no-op, and skipped (padded) slots.
$ pytest tests/test_la_decode_mtp.py tests/test_la_kvbuffer.py -q
Test file Result
tests/test_la_decode_mtp.py 15 passed
tests/test_la_kvbuffer.py 35 passed
Total 50 passed in 51.68s

⚡ Performance

B200 (SM100), K=V=128, bf16 in / fp32 state, accept m = full (L = T), kernel-only timing.
Baseline = SGLang upstream chain (seg_la_mtp_kernel verify + fused_mamba_state_scatter_with_mask commit, both Triton). Each cell = verify + commit chain speedup = sglang_total / cuLA_total; >1 means cuLA is faster. H == HV (no GQA) for a fair comparison.

Chain speedup, HV = H = 32

B T=2 T=3 T=4 T=6 T=8
1 2.38 2.41 2.44 2.43 2.44
2 2.44 2.34 2.41 2.40 2.39
4 2.40 2.39 2.41 2.48 2.54
8 2.39 2.34 2.51 2.55 2.55
16 2.27 2.16 2.28 2.57 2.98
32 2.46 2.35 2.70 2.90 3.01
64 1.68 1.76 1.80 2.23 2.44
128 1.68 1.80 1.96 2.36 2.64

Chain speedup, HV = H = 64

B T=2 T=3 T=4 T=6 T=8
1 2.17 2.41 2.46 2.53 2.50
2 2.43 2.46 2.48 2.54 2.50
4 2.43 2.42 2.62 2.48 2.57
8 2.22 2.21 2.34 2.69 3.12
16 2.49 2.33 2.66 3.02 3.04
32 1.70 1.77 1.82 2.19 2.46
64 1.68 1.80 1.99 2.38 2.65
128 1.75 1.80 1.98 2.43 2.77

Takeaways

  • Chain vs SGLang: 1.68× – 3.12× across every shape (both HV), strongest at T ≥ 6 and mid/large batch.
  • Verify scales flat in T. At B=128 the cuLA verify kernel grows only +81% (HV=32) / +90% (HV=64) from T=2→8, while SGLang's Triton verify grows +193% / +192%. The verify kernel alone (accept-independent) reaches up to ~4.6× (B=128, T=8) — the tensor-core, closed-form parallel verification paying off as T grows.
  • Memory: 128× less rollback storage. The pooled KV buffer (k,v per draft token) replaces SGLang's T per-token d×d intermediate states. At B=128, T=8: 16.8 MB vs 2147 MB (HV=32) and 33.6 MB vs 4295 MB (HV=64) — independent of latency.
  • Correctness: cuLA RMSE ≤ 2.7e-3 vs the PyTorch reference (bf16 in / fp32 state), at or below SGLang's own RMSE on the same inputs.

Reviewer Notes

  • The public linear_attention_verify_kvbuffer dispatches to a register-shuffle MMA variant; the inline-PTX mma.sync.m16n8k8 path is retained alongside it.
  • Verify's KV-buffer write is optional (write_kv) so the verify kernel can run standalone (verify-only path above) or fused with the buffer write.
  • The commit kernel's recurrence body is bit-identical to the baseline T-loop, so at L == T the committed state is bit-equivalent to running the baseline with disable_state_update=False.
  • The benchmark treats SGLang as an optional baseline (set LA_SGLANG_PYTHON to enable); without it the bench still validates against the PyTorch reference.

@fkuner fkuner requested a review from icavan June 21, 2026 05:39

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces Lightning Attention MTP (Multi-Token Processing) decode and KVBuffer verify/state-update kernels to optimize speculative-decoding verification scenarios, along with corresponding benchmarks and unit tests. It also adds a Global-to-Register (G2R) prototype to optimize the big-batch decode path. The review feedback is highly constructive and identifies critical correctness and safety issues that must be addressed. Specifically, defensive checks should be added to Python entry points to prevent out-of-bounds memory accesses when the number of tokens T > 8 or head dimension K != 128. Additionally, wrappers must validate that the head dimension V is a multiple of ilp_rows to prevent silent correctness bugs where boundary chunks are skipped. Finally, environment variable lookups in the hot path of linear_attention_decode should be cached at the module level to eliminate unnecessary Python overhead.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread cula/lightning/la_decode_mtp.py
Comment thread cula/lightning/la_decode_mtp.py
Comment thread cula/lightning/la_verify_kvbuffer.py
Comment thread cula/lightning/la_verify_kvbuffer.py Outdated
Comment thread cula/lightning/la_verify_kvbuffer.py
Comment thread cula/lightning/la_verify_kvbuffer.py
Comment thread cula/lightning/la_state_update_kvbuffer.py
Comment thread cula/lightning/la_state_update_kvbuffer.py
Comment thread cula/ops/la_decode.py Outdated
Comment thread cula/ops/la_decode.py Outdated
fankun.fan added 3 commits June 21, 2026 13:44
Fused multi-token (MTP) Lightning Attention decode kernel for speculative
decoding: a single launch processes T draft tokens, with ILP variants and a
work-unit heuristic (get_mtp_config). Includes packed F32x2 FMA on SM100.

- cula/lightning/la_decode_mtp.py: kernel + config + shared dot/update helpers
- tests/test_la_decode_mtp.py + tests/_la_mtp_ref.py: correctness vs PyTorch ref
- benchmarks/bench_la_decode_mtp.py: vs sequential decode and FLA, with SOL model
KVBuffer-backed Lightning Attention for speculative decode verify/commit.
Verify computes each draft step's output in closed form (paper Eq. 7) with
the two dot-product GEMMs on tensor cores via inline-PTX mma.sync.m16n8k8
(TF32); state-update commits the accepted prefix into the pooled state
(paper Eq. 8), bit-equivalent to the baseline T-loop at L == T.

- cula/lightning/la_verify_kvbuffer.py: TF32 MMA verify kernel (+ shuffle variant)
- cula/lightning/la_update_kvbuffer.py: KV buffer state-update (commit) kernel
- tests/test_la_kvbuffer.py: correctness vs PyTorch ref (verify + update)
- benchmarks/bench_la_kvbuffer.py: vs SGLang verify+commit (optional), with SOL model
- test_la_kvbuffer.py: add odd-T cases (verify T=1,3,5,7; state-update T=3,7)
  to guard the BT=8 M/N padding path that handles non-even draft lengths.
- la_verify_kvbuffer.py: the shuffle launcher's SMEM byte estimate omitted the
  16B per-allocation alignment padding (4 SMEM tensors), so the declared launch
  size could fall ~12B short of actual usage and trip CUTLASS's size check. Add
  the 4*16 padding term, matching the main-kernel launcher.
@fkuner fkuner force-pushed the la-decode-kvbuffer branch 2 times, most recently from a0729c2 to 3e56972 Compare June 21, 2026 06:22
范坤 and others added 6 commits June 21, 2026 23:18
Module name now matches the public symbol linear_attention_state_update_kvbuffer.
Pure rename plus import-path updates; no behavior change.

Co-authored-by: Cursor <cursoragent@cursor.com>
Structural cleanup of the LA decode-MTP kernel (no semantic change),
split out of the prior pre-commit chore commit for reviewability.

Co-authored-by: Cursor <cursoragent@cursor.com>
Formatting/lint fixes plus inlining the shared _la_mtp_ref helper directly
into the test files. Benchmark updates included.

Co-authored-by: Cursor <cursoragent@cursor.com>
…nstexpr loop

Replace three explicit ilp_rows==2/==4/==8 branches with a single
range_constexpr(ilp_rows) path, mirroring the pattern already used in
la_state_update_kvbuffer.  Cuts ~550 LOC without changing semantics.
@fkuner fkuner force-pushed the la-decode-kvbuffer branch from b0e2243 to 971d524 Compare June 21, 2026 15:20
- la_verify_kvbuffer: re-check V % ilp_rows == 0 AFTER the ilp_rows->8
  promotion (the pre-promotion assert could let a partial row-block be
  silently skipped); zero the sH0 M-padding rows before GEMM1 so the MMA
  fragment is well-defined instead of consuming stale/NaN SMEM.
- assert K == 128 in the verify (MMA + shuffle) and state-update entry
  points, documenting the hardcoded head-dim assumption.
@fkuner fkuner force-pushed the la-decode-kvbuffer branch from 971d524 to d564e14 Compare June 21, 2026 16:55
# match the MMA kernel's ilp_rows=8 override (M=8 fragment fill)
if ilp_rows_kv < 8 and (tile_v_kv // 4) % 8 == 0:
ilp_rows_kv = 8
verify_buf_cache = _get_compiled_verify_kvbuffer_kernel(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TypeError: _get_compiled_verify_kvbuffer_kernel() takes 12 positional arguments but 14 were given

P0: I've runed this benchmark and encountered the above error.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

The kernel-only timing path passed shuffle-only args (use_smem_v, use_packed_fma)
to _get_compiled_verify_kvbuffer_kernel, causing a TypeError at T >= MMA_MIN_T.

Co-authored-by: Cursor <cursoragent@cursor.com>
# ─────────────────────────────────────────────────────────────────────────────
# Timing utility
# ─────────────────────────────────────────────────────────────────────────────
def benchmark_fn(fn, warmup=30, rep=200):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider reusing https://github.com/inclusionAI/cuLA/blob/main/benchmarks/utils.py#L129

You could add a new iqr support to the utils.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants