[LA] Lightning Attention MTP decode + KVBuffer parallel verify / commit #97
Open
fkuner wants to merge 11 commits into inclusionAI:main from fkuner:la-decode-kvbuffer
Commits
57e1729 feat: LA decode MTP kernel + tests + benchmark
b73d78d feat: LA KVBuffer verify + state-update kernels + tests + benchmark
776212d test: cover odd T for KVBuffer verify/update; fix shuffle SMEM size
9557706 refactor: rename la_update_kvbuffer.py -> la_state_update_kvbuffer.py
109cc27 refactor: clean up la_decode_mtp kernel structure
f3ee54a chore: fix pre-commit issues and inline _la_mtp_ref into test files
5d52b81 chore: add defensive bounds checks for T and V in LA kernels
5d37d85 chore: remove vestigial use_smem_v and use_packed_fma from MMA verify…
a20187d refactor: collapse ilp_rows branches in la_decode_mtp into generic co…
d564e14 harden LA verify/state-update kernels
265b465 fix: pass correct cache key to MMA verify kernel in benchmark
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,313 @@ | ||
| # Copyright 2025-2026 Ant Group Co., Ltd. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Lightning Attention KVBuffer state-update kernel (paper Eq. 8 for LA). | ||
|
|
||
| After a parallel-verify cycle, advances the pooled state from h_init to | ||
| h_state_L for a per-batch accepted prefix length L = accepted_len[b]: | ||
|
|
||
| h_running = h_init | ||
| for i in 0..L-1: | ||
| h_running = exp(-decay_scales[h]) * h_running + k_i ⊗ v_i | ||
| s[cache_idx] = h_running | ||
|
|
||
| The loop body is bit-identical to the baseline T-loop body, so at L == T the | ||
| result is bit-equivalent to running the baseline with disable_state_update=False. | ||
|
|
||
| Reads s, k, v; writes s. Never touches q or o. | ||
|
|
||
| Grid: (B * HV * num_v_tiles, 1, 1), 128 threads/block; identical layout to the | ||
| baseline verify kernel, so the state write aligns with the verify kernel's h0 read. | ||
| """ | ||
|
|
||
| import functools | ||
|
|
||
| import cuda.bindings.driver as cuda | ||
| import cutlass | ||
| import cutlass.cute as cute | ||
| import torch | ||
| from cutlass.cute.runtime import from_dlpack | ||
|
|
||
| from cula.lightning.la_decode_mtp import ( | ||
| NUM_THREADS_MTP, | ||
| get_mtp_config, | ||
| la_update_pair, | ||
| ) | ||
| from cula.utils import USE_FAST_MATH, get_device_sm_version | ||
|
|
||
|
|
||
| @cute.kernel | ||
| def la_state_update_kernel( | ||
| h0_source: cute.Tensor, # [pool_size * HV, V, K] fp32 (read + written in place) | ||
| decay_scales: cute.Tensor, # [H] fp32 | ||
| k: cute.Tensor, # [B, T, H, K] bf16 | ||
| v: cute.Tensor, # [B, T, HV, V] bf16 | ||
| h0_indices: cute.Tensor, # [B] int32 | ||
| accepted_len: cute.Tensor, # [B] int32 | ||
| k_buf: cute.Tensor, # [pool_size, T, H, K] bf16 (READ when read_from_buf) | ||
| v_buf: cute.Tensor, # [pool_size, T, HV, V] bf16 (READ when read_from_buf) | ||
| vec_size: cutlass.Constexpr[int], | ||
| num_v_tiles: cutlass.Constexpr[int], | ||
| tile_v: cutlass.Constexpr[int], | ||
| B: cutlass.Constexpr[int], | ||
| T: cutlass.Constexpr[int], | ||
| H: cutlass.Constexpr[int], | ||
| HV: cutlass.Constexpr[int], | ||
| K: cutlass.Constexpr[int], | ||
| V: cutlass.Constexpr[int], | ||
| ilp_rows: cutlass.Constexpr[int], | ||
| use_packed_fma: cutlass.Constexpr[bool], | ||
| read_from_buf: cutlass.Constexpr[bool], | ||
| ): | ||
| tidx, _, _ = cute.arch.thread_idx() | ||
| lane_id = tidx % 32 | ||
| warp_idx = cute.arch.warp_idx() | ||
| warp_idx = cute.arch.make_warp_uniform(warp_idx) | ||
|
|
||
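| threads_per_group: cutlass.Constexpr[int] = K // vec_size # lanes covering one K row; 32 when K=128, vec_size=4 | ||
| groups_per_warp: cutlass.Constexpr[int] = 32 // threads_per_group # 1 in that configuration | ||
| num_groups: cutlass.Constexpr[int] = 4 * groups_per_warp # 4 warps per 128-thread block, so 4 groups | ||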
|
|
||
| lane_in_group = lane_id % threads_per_group | ||
| group_in_warp = lane_id // threads_per_group | ||
| group_idx = warp_idx * groups_per_warp + group_in_warp | ||
|
|
||
| block_idx, _, _ = cute.arch.block_idx() | ||
| i_v = block_idx % num_v_tiles | ||
| tmp = block_idx // num_v_tiles | ||
| i_hv = tmp % HV | ||
| i_n = tmp // HV | ||
| i_h = i_hv // (HV // H) | ||
|
|
||
| cache_idx = h0_indices[i_n] | ||
| L = accepted_len[i_n] | ||
|
|
||
| r_k = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) | ||
| r_k_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) | ||
| r_h = cute.make_rmem_tensor(cute.make_layout((8, vec_size), stride=(vec_size, 1)), cutlass.Float32) | ||
|
|
||
| if cache_idx >= 0 and L > 0: | ||
| r_decay = cute.exp(-cutlass.Float32(decay_scales[i_h]), fastmath=USE_FAST_MATH) | ||
| rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups | ||
| flat_state_idx = cache_idx * HV + i_hv | ||
|
|
||
| # Process `ilp_rows` V-rows per iteration. ilp_rows is a compile-time | ||
| # constant, so range_constexpr fully unrolls the slot loops below; the | ||
| # generated SASS is identical to hand-unrolling each ilp_rows value, but | ||
| # one loop covers ilp_rows in {2, 4, 8}. | ||
| num_chunks: cutlass.Constexpr[int] = rows_per_group // ilp_rows | ||
| for chunk in cutlass.range_constexpr(num_chunks): | ||
| v_idx_0 = i_v * tile_v + group_idx * rows_per_group + chunk * ilp_rows | ||
| if v_idx_0 + (ilp_rows - 1) < V: | ||
| # Load the ilp_rows h-state rows this thread owns into registers. | ||
| for slot in cutlass.range_constexpr(ilp_rows): | ||
| h_tile = cute.local_tile(h0_source, (1, 1, vec_size), (flat_state_idx, v_idx_0 + slot, lane_in_group)) | ||
| cute.autovec_copy(h_tile, cute.slice_(r_h, (slot, None))) | ||
|
|
||
| # Recurrence: h = decay * h + k_i (x) v_i, for i in 0..L-1. | ||
| for i in cutlass.range(0, L, unroll=0): | ||
| if cutlass.const_expr(read_from_buf): | ||
| k_tile = cute.local_tile(k_buf, (1, 1, 1, vec_size), (cache_idx, i, i_h, lane_in_group)) | ||
| else: | ||
| k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i, i_h, lane_in_group)) | ||
| cute.autovec_copy(k_tile, r_k_bf16) | ||
| for j in cutlass.range_constexpr(vec_size): | ||
| r_k[j] = cutlass.Float32(r_k_bf16[j]) | ||
| for slot in cutlass.range_constexpr(ilp_rows): | ||
| if cutlass.const_expr(read_from_buf): | ||
| r_v_s = cutlass.Float32(v_buf[cache_idx, i, i_hv, v_idx_0 + slot]) | ||
| else: | ||
| r_v_s = cutlass.Float32(v[i_n, i, i_hv, v_idx_0 + slot]) | ||
| for j in cutlass.range_constexpr(0, vec_size, 2): | ||
| r_h[slot, j], r_h[slot, j + 1] = la_update_pair( | ||
| r_h[slot, j], r_h[slot, j + 1], r_k[j], r_k[j + 1], r_v_s, r_decay, use_packed_fma | ||
| ) | ||
|
|
||
| # Write the advanced state back in place. | ||
| for slot in cutlass.range_constexpr(ilp_rows): | ||
| h_out = cute.local_tile(h0_source, (1, 1, vec_size), (flat_state_idx, v_idx_0 + slot, lane_in_group)) | ||
| cute.autovec_copy(cute.slice_(r_h, (slot, None)), h_out) | ||
|
|
||
|
|
||
| @cute.jit | ||
| def run_la_state_update_kernel( | ||
| h0_source: cute.Tensor, | ||
| decay_scales: cute.Tensor, | ||
| k: cute.Tensor, | ||
| v: cute.Tensor, | ||
| h0_indices: cute.Tensor, | ||
| accepted_len: cute.Tensor, | ||
| k_buf: cute.Tensor, | ||
| v_buf: cute.Tensor, | ||
| B: cutlass.Constexpr[int], | ||
| T: cutlass.Constexpr[int], | ||
| H: cutlass.Constexpr[int], | ||
| HV: cutlass.Constexpr[int], | ||
| K: cutlass.Constexpr[int], | ||
| V: cutlass.Constexpr[int], | ||
| tile_v: cutlass.Constexpr[int], | ||
| vec_size: cutlass.Constexpr[int], | ||
| ilp_rows: cutlass.Constexpr[int], | ||
| use_packed_fma: cutlass.Constexpr[bool], | ||
| read_from_buf: cutlass.Constexpr[bool], | ||
| stream: cuda.CUstream, | ||
| ): | ||
| num_v_tiles: cutlass.Constexpr[int] = (V + tile_v - 1) // tile_v | ||
| grid_size = B * HV * num_v_tiles | ||
|
|
||
| la_state_update_kernel( | ||
| h0_source, | ||
| decay_scales, | ||
| k, | ||
| v, | ||
| h0_indices, | ||
| accepted_len, | ||
| k_buf, | ||
| v_buf, | ||
| vec_size, | ||
| num_v_tiles, | ||
| tile_v, | ||
| B, | ||
| T, | ||
| H, | ||
| HV, | ||
| K, | ||
| V, | ||
| ilp_rows, | ||
| use_packed_fma, | ||
| read_from_buf, | ||
| ).launch( | ||
| grid=(grid_size, 1, 1), | ||
| block=[NUM_THREADS_MTP, 1, 1], | ||
| stream=stream, | ||
| ) | ||
|
|
||
|
|
||
| @functools.cache | ||
| def _get_compiled_state_update_kernel( | ||
| B: int, | ||
| T: int, | ||
| H: int, | ||
| HV: int, | ||
| K: int, | ||
| V: int, | ||
| pool_size: int, | ||
| tile_v: int, | ||
| vec_size: int, | ||
| ilp_rows: int, | ||
| use_packed_fma: bool, | ||
| read_from_buf: bool, | ||
| ): | ||
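| # Intentionally a fresh mutable dict: functools.cache hands back the same dict for a | ||
| # given key, so the caller stores the compiled kernel in it under "compiled". | ||
| return {} | ||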
|
|
||
|
|
||
| def linear_attention_state_update_kvbuffer( | ||
| k: torch.Tensor, # [B, T, H, K] bf16, read when k_buf is None | ||
| v: torch.Tensor, # [B, T, HV, V] bf16, read when v_buf is None | ||
| s: torch.Tensor, # [pool_size, HV, V, K] fp32, WRITTEN IN PLACE | ||
| decay_scales: torch.Tensor, # [H] fp32 | ||
| h0_indices: torch.Tensor, # [B] int32, -1 to skip | ||
| accepted_len: torch.Tensor, # [B] int32, in [0, T] | ||
| T: int, | ||
| k_buf: torch.Tensor | None = None, # [pool_size, T, H, K] bf16 | ||
| v_buf: torch.Tensor | None = None, # [pool_size, T, HV, V] bf16 | ||
| ) -> None: | ||
| """ | ||
| Advance pooled state from h_init to h_state_L per batch (KVBuffer Eq. 8). | ||
|
|
||
| When k_buf and v_buf are provided, reads k,v from pool-indexed buffers | ||
| instead of batch-indexed input tensors. | ||
| """ | ||
| B, T_k, H, K = k.shape | ||
| assert T_k == T, f"k.shape[1]={T_k} doesn't match T={T}" | ||
| assert K == 128, f"K={K} != 128: kernel hardcodes K=128 (threads_per_group, lane K-coverage)" | ||
| _, _, HV, V = v.shape | ||
| pool_size = s.shape[0] | ||
|
|
||
| read_from_buf = k_buf is not None and v_buf is not None | ||
| if (k_buf is None) != (v_buf is None): | ||
| raise ValueError("k_buf and v_buf must both be None or both be provided") | ||
|
|
||
|
|
||
| tile_v, vec_size, ilp_rows, _use_smem_v = get_mtp_config(B, T, HV, V, False) | ||
| assert V % ilp_rows == 0, f"V={V} % ilp_rows={ilp_rows} ≠ 0: partial row-blocks would be silently skipped" | ||
| major, _ = get_device_sm_version(k.device) | ||
| use_packed_fma = major >= 10 | ||
|
|
||
| cache_key = ( | ||
| B, | ||
| T, | ||
| H, | ||
| HV, | ||
| K, | ||
| V, | ||
| pool_size, | ||
| tile_v, | ||
| vec_size, | ||
| ilp_rows, | ||
| use_packed_fma, | ||
| read_from_buf, | ||
| ) | ||
| cache = _get_compiled_state_update_kernel(*cache_key) | ||
|
|
||
| h0_view = s.view(pool_size * HV, V, K) | ||
|
|
||
| # The kernel signature always takes k_buf/v_buf; pass 1-element placeholders when unused. | ||
| if not read_from_buf: | ||
| k_buf_t = torch.empty(1, 1, 1, 1, device=k.device, dtype=torch.bfloat16) | ||
| v_buf_t = torch.empty(1, 1, 1, 1, device=k.device, dtype=torch.bfloat16) | ||
| else: | ||
| k_buf_t = k_buf | ||
| v_buf_t = v_buf | ||
|
|
||
| if "compiled" not in cache: | ||
| stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) | ||
| compiled = cute.compile( | ||
| run_la_state_update_kernel, | ||
| from_dlpack(h0_view, assumed_align=16), | ||
| from_dlpack(decay_scales, assumed_align=16), | ||
| from_dlpack(k, assumed_align=16), | ||
| from_dlpack(v, assumed_align=16), | ||
| from_dlpack(h0_indices, assumed_align=16), | ||
| from_dlpack(accepted_len, assumed_align=16), | ||
| from_dlpack(k_buf_t, assumed_align=16), | ||
| from_dlpack(v_buf_t, assumed_align=16), | ||
| B=B, | ||
| T=T, | ||
| H=H, | ||
| HV=HV, | ||
| K=K, | ||
| V=V, | ||
| tile_v=tile_v, | ||
| vec_size=vec_size, | ||
| ilp_rows=ilp_rows, | ||
| use_packed_fma=use_packed_fma, | ||
| read_from_buf=read_from_buf, | ||
| stream=stream, | ||
| options="--enable-tvm-ffi", | ||
| ) | ||
| cache["compiled"] = compiled | ||
|
|
||
| compiled = cache["compiled"] | ||
| stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) | ||
| compiled( | ||
| h0_view, | ||
| decay_scales, | ||
| k, | ||
| v, | ||
| h0_indices, | ||
| accepted_len, | ||
| k_buf_t, | ||
| v_buf_t, | ||
| stream, | ||
| ) | ||
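
For reviewers who want a slow oracle to compare the kernel against, the recurrence in the module docstring can be sketched in NumPy. This is a minimal sketch under stated assumptions, not the test helper used in this PR: the name `la_state_update_reference` is hypothetical, inputs are assumed to be float32 already (NumPy has no bf16), the pool-indexed `k_buf`/`v_buf` path is omitted, and the outer product is written as `outer(v_i, k_i)` on the assumption that each state slice is laid out `[V, K]`, as the shape comment on `s` states.

```python
import numpy as np


def la_state_update_reference(k, v, s, decay_scales, h0_indices, accepted_len):
    """Hypothetical CPU reference: advance s[cache_idx] in place.

    k: [B, T, H, K], v: [B, T, HV, V], s: [pool_size, HV, V, K], all float32 here.
    decay_scales: [H]; h0_indices and accepted_len: [B] integer arrays.
    """
    B, T, H, K = k.shape
    HV = v.shape[2]
    group = HV // H  # value heads that share one key head (i_h = i_hv // group)
    for b in range(B):
        cache_idx = int(h0_indices[b])
        L = int(accepted_len[b])
        # Same guard as the kernel: a negative slot or an empty prefix is a no-op.
        if cache_idx < 0 or L <= 0:
            continue
        for hv in range(HV):
            h = hv // group
            decay = np.exp(-decay_scales[h])
            state = s[cache_idx, hv]  # view of shape [V, K], updated in place
            for i in range(L):
                # h_running = exp(-decay_scales[h]) * h_running + k_i (x) v_i
                state *= decay
                state += np.outer(v[b, i, hv], k[b, i, h])
```

Because the sketch accumulates in float32 with a different operation order than the kernel (which may use packed FMA), a comparison against GPU output would need a tolerance rather than bitwise equality.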