[LA] Lightning Attention MTP decode + KVBuffer parallel verify / commit#97
[LA] Lightning Attention MTP decode + KVBuffer parallel verify / commit#97fkuner wants to merge 11 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces Lightning Attention MTP (Multi-Token Processing) decode and KVBuffer verify/state-update kernels to optimize speculative-decoding verification scenarios, along with corresponding benchmarks and unit tests. It also adds a Global-to-Register (G2R) prototype to optimize the big-batch decode path. The review feedback is highly constructive and identifies critical correctness and safety issues that must be addressed. Specifically, defensive checks should be added to Python entry points to prevent out-of-bounds memory accesses when the number of tokens T > 8 or head dimension K != 128. Additionally, wrappers must validate that the head dimension V is a multiple of ilp_rows to prevent silent correctness bugs where boundary chunks are skipped. Finally, environment variable lookups in the hot path of linear_attention_decode should be cached at the module level to eliminate unnecessary Python overhead.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
Fused multi-token (MTP) Lightning Attention decode kernel for speculative decoding: a single launch processes T draft tokens, with ILP variants and a work-unit heuristic (get_mtp_config). Includes packed F32x2 FMA on SM100. - cula/lightning/la_decode_mtp.py: kernel + config + shared dot/update helpers - tests/test_la_decode_mtp.py + tests/_la_mtp_ref.py: correctness vs PyTorch ref - benchmarks/bench_la_decode_mtp.py: vs sequential decode and FLA, with SOL model
KVBuffer-backed Lightning Attention for speculative decode verify/commit. Verify computes each draft step's output in closed form (paper Eq. 7) with the two dot-product GEMMs on tensor cores via inline-PTX mma.sync.m16n8k8 (TF32); state-update commits the accepted prefix into the pooled state (paper Eq. 8), bit-equivalent to the baseline T-loop at L == T. - cula/lightning/la_verify_kvbuffer.py: TF32 MMA verify kernel (+ shuffle variant) - cula/lightning/la_update_kvbuffer.py: KV buffer state-update (commit) kernel - tests/test_la_kvbuffer.py: correctness vs PyTorch ref (verify + update) - benchmarks/bench_la_kvbuffer.py: vs SGLang verify+commit (optional), with SOL model
- test_la_kvbuffer.py: add odd-T cases (verify T=1,3,5,7; state-update T=3,7) to guard the BT=8 M/N padding path that handles non-even draft lengths. - la_verify_kvbuffer.py: the shuffle launcher's SMEM byte estimate omitted the 16B per-allocation alignment padding (4 SMEM tensors), so the declared launch size could fall ~12B short of actual usage and trip CUTLASS's size check. Add the 4*16 padding term, matching the main-kernel launcher.
a0729c2 to
3e56972
Compare
Module name now matches the public symbol linear_attention_state_update_kvbuffer. Pure rename plus import-path updates; no behavior change. Co-authored-by: Cursor <cursoragent@cursor.com>
Structural cleanup of the LA decode-MTP kernel (no semantic change), split out of the prior pre-commit chore commit for reviewability. Co-authored-by: Cursor <cursoragent@cursor.com>
Formatting/lint fixes plus inlining the shared _la_mtp_ref helper directly into the test files. Benchmark updates included. Co-authored-by: Cursor <cursoragent@cursor.com>
…nstexpr loop Replace three explicit ilp_rows==2/==4/==8 branches with a single range_constexpr(ilp_rows) path, mirroring the pattern already used in la_state_update_kvbuffer. Cuts ~550 LOC without changing semantics.
b0e2243 to
971d524
Compare
- la_verify_kvbuffer: re-check V % ilp_rows == 0 AFTER the ilp_rows->8 promotion (the pre-promotion assert could let a partial row-block be silently skipped); zero the sH0 M-padding rows before GEMM1 so the MMA fragment is well-defined instead of consuming stale/NaN SMEM. - assert K == 128 in the verify (MMA + shuffle) and state-update entry points, documenting the hardcoded head-dim assumption.
971d524 to
d564e14
Compare
| # match the MMA kernel's ilp_rows=8 override (M=8 fragment fill) | ||
| if ilp_rows_kv < 8 and (tile_v_kv // 4) % 8 == 0: | ||
| ilp_rows_kv = 8 | ||
| verify_buf_cache = _get_compiled_verify_kvbuffer_kernel( |
There was a problem hiding this comment.
TypeError: _get_compiled_verify_kvbuffer_kernel() takes 12 positional arguments but 14 were given
P0: I've runed this benchmark and encountered the above error.
The kernel-only timing path passed shuffle-only args (use_smem_v, use_packed_fma) to _get_compiled_verify_kvbuffer_kernel, causing a TypeError at T >= MMA_MIN_T. Co-authored-by: Cursor <cursoragent@cursor.com>
| # ───────────────────────────────────────────────────────────────────────────── | ||
| # Timing utility | ||
| # ───────────────────────────────────────────────────────────────────────────── | ||
| def benchmark_fn(fn, warmup=30, rep=200): |
There was a problem hiding this comment.
Consider reusing https://github.com/inclusionAI/cuLA/blob/main/benchmarks/utils.py#L129
You could add a new iqr support to the utils.py
📌 Description
Adds target-side Lightning Attention (LA) support for speculative decoding (multi-token prediction), in two complementary pieces:
la_decode_mtp): a single launch advances the LA recurrence over allTdraft tokens per(batch, head), with ILP variants and a work-unit heuristic (get_mtp_config); packed F32×2 FMA on SM100.la_verify_kvbuffer/la_state_update_kvbuffer): the parallel-verification path. Verify computes every draft step's output in closed form straight from(h0, k, v)— without materializing theTintermediate states — and optionally writesk/vinto a compact pooled KV buffer; state-update (commit) then advances the pooled state by the per-request accepted prefix lengthLread from that buffer.Closed form used by verify (per draft step
t):The two dot-product GEMMs run on Blackwell tensor cores via inline-PTX
mma.sync.m16n8k8(TF32, fp32 SMEM staging); everything downstream is plain scalar math. M/N are padded toBT=8, so any draft lengthT ∈ [1, 8](odd or even) is handled.What changed
cula/lightning/la_decode_mtp.py— fused MTP decode kernel + config heuristic + shared dot/update helpers.cula/lightning/la_verify_kvbuffer.py— KVBuffer verify (TF32 MMA, register-shuffle variant) + optional KV-buffer write.cula/lightning/la_state_update_kvbuffer.py— KVBuffer state-update (commit) kernel; per-requestaccepted_len, skips padded slots (h0_indices < 0).tests/test_la_decode_mtp.py,tests/test_la_kvbuffer.py, — correctness vs a PyTorch reference.benchmarks/bench_la_decode_mtp.py,benchmarks/bench_la_kvbuffer.py— MTP-decode bench, and verify+commit chain bench with an optional SGLang baseline (no hard dependency on a local SGLang checkout).🔍 Related Issues
🚀 Pull Request Checklist
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit.pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
All kernels are checked against a PyTorch reference. Coverage:
T(1, 3, 5, 7) and the M/N-padding path; negativeh0_indicesslots left untouched.L=0no-op, and skipped (padded) slots.tests/test_la_decode_mtp.pytests/test_la_kvbuffer.py⚡ Performance
B200 (SM100), K=V=128, bf16 in / fp32 state, accept
m = full(L = T), kernel-only timing.Baseline = SGLang upstream chain (
seg_la_mtp_kernelverify +fused_mamba_state_scatter_with_maskcommit, both Triton). Each cell = verify + commit chain speedup =sglang_total / cuLA_total;>1means cuLA is faster.H == HV(no GQA) for a fair comparison.Chain speedup, HV = H = 32
Chain speedup, HV = H = 64
Takeaways
T ≥ 6and mid/large batch.k,vper draft token) replaces SGLang'sTper-tokend×dintermediate states. At B=128, T=8: 16.8 MB vs 2147 MB (HV=32) and 33.6 MB vs 4295 MB (HV=64) — independent of latency.Reviewer Notes
linear_attention_verify_kvbufferdispatches to a register-shuffle MMA variant; the inline-PTXmma.sync.m16n8k8path is retained alongside it.write_kv) so the verify kernel can run standalone (verify-only path above) or fused with the buffer write.L == Tthe committed state is bit-equivalent to running the baseline withdisable_state_update=False.LA_SGLANG_PYTHONto enable); without it the bench still validates against the PyTorch reference.