diff --git a/benchmarks/bench_la_decode_mtp.py b/benchmarks/bench_la_decode_mtp.py new file mode 100644 index 00000000..d673e909 --- /dev/null +++ b/benchmarks/bench_la_decode_mtp.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark: la_decode_mtp (CuTe DSL) vs alternatives on Lightning Attention MTP. + +Compares three implementations of T > 1 Lightning Attention decode: + 1. cula `linear_attention_decode_mtp` (this work — fused single-launch) + 2. fla `fused_recurrent_fwd` (Triton, T-aware) + 3. cula `linear_attention_decode` × T (cula self-comparison; T sequential calls) + +Two timing modes (mirroring bench_la_decode_vs_fla.py): + - kernel-only: pre-allocated buffers, pre-compiled kernel handle, pre-built stream + - wrapper: full Python entry point per call (cache lookup, CUstream, ...) + +Bandwidth analysis (SOL% against B200 HBM3e peak ~8 TB/s) printed alongside. 
def benchmark_fn(fn, warmup=30, rep=200):
    """Time ``fn`` with CUDA events and return the interquartile mean in ms.

    ``fn`` is called ``warmup`` times untimed, then ``rep`` times, each call
    bracketed by a pair of CUDA events. The samples are sorted and only the
    middle half (``[n // 4, 3 * n // 4)``) is averaged, dropping outliers on
    both ends.

    Args:
        fn: Zero-argument callable that enqueues the work to measure.
        warmup: Number of untimed warm-up calls.
        rep: Number of timed calls; must be at least 1.

    Returns:
        Mean elapsed time in milliseconds over the retained samples.

    Raises:
        ValueError: If ``rep`` is smaller than 1.
    """
    # Validate before touching CUDA so a bad argument fails fast and clearly.
    if rep < 1:
        raise ValueError(f"rep must be >= 1, got {rep}")

    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    starts = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    for start, end in zip(starts, ends):
        start.record()
        fn()
        end.record()
    torch.cuda.synchronize()

    times = sorted(s.elapsed_time(e) for s, e in zip(starts, ends))
    n = len(times)
    # With a single sample the interquartile window is empty; fall back to all
    # samples instead of dividing by zero.
    window = times[n // 4 : 3 * n // 4] or times
    return sum(window) / len(window)


def la_mtp_bytes(B, T, H, HV, K, V, cache_intermediate_states, disable_state_update):
    """Return the modelled number of bytes moved by one MTP decode launch.

    The model counts bf16 reads of q/k (``H`` heads) and v (``HV`` heads), bf16
    writes of the output, one fp32 read of the initial state, an fp32
    write-back of the state unless ``disable_state_update`` is set, and one
    fp32 state snapshot per timestep when ``cache_intermediate_states`` is set.

    Args:
        B: Batch size.
        T: Number of timesteps per sequence.
        H: Number of q/k heads.
        HV: Number of v/output/state heads.
        K: Key head dimension.
        V: Value head dimension.
        cache_intermediate_states: Whether per-step states are written out.
        disable_state_update: Whether the final state write-back is skipped.

    Returns:
        Total byte count as an int.
    """
    bf16, fp32 = 2, 4
    state_elems = B * HV * V * K

    qk_read = B * T * H * K * bf16 * 2  # q and k
    v_read = B * T * HV * V * bf16
    out_write = B * T * HV * V * bf16
    h0_read = state_elems * fp32
    h0_write = 0 if disable_state_update else state_elems * fp32
    inter_write = T * state_elems * fp32 if cache_intermediate_states else 0
    return qk_read + v_read + out_write + h0_read + h0_write + inter_write


def sol_pct(byte_count: int, kernel_ms: float, peak_bps: float) -> float:
    """Return achieved bandwidth as a percentage of the peak bandwidth.

    Args:
        byte_count: Bytes moved by the kernel.
        kernel_ms: Kernel time in milliseconds.
        peak_bps: Peak memory bandwidth in bytes per second.

    Returns:
        ``(byte_count / seconds) / peak_bps * 100``, or NaN when either
        ``kernel_ms`` or ``peak_bps`` is not a positive number.
    """
    # A zero, negative or NaN timing/peak has no meaningful percentage; report
    # NaN (the convention used for unavailable numbers in this benchmark)
    # rather than raising ZeroDivisionError mid-run.
    if not (kernel_ms > 0 and peak_bps > 0):
        return float("nan")
    return (byte_count / (kernel_ms * 1e-3)) / peak_bps * 100.0
def run_config(
    B, T, H, HV, K, V, layer_idx, num_layers, peak_bps, cache_intermediate_states=False, disable_state_update=False
):
    """Benchmark one (B, T) configuration of the MTP decode kernel.

    Runs the cula MTP kernel once for a correctness check against fla (when
    fla is importable and ``HV == H``), then times:
      * the pre-compiled MTP kernel handle (kernel-only mode),
      * T sequential ``linear_attention_decode`` calls (kernel-only baseline),
      * the ``linear_attention_decode_mtp`` Python entry point (wrapper mode),
      * fla ``fused_recurrent_fwd`` through its wrapper (when usable).

    Args:
        B: Batch size.
        T: Number of timesteps per sequence.
        H: Number of q/k heads.
        HV: Number of v/state heads.
        K: Key head dimension.
        V: Value head dimension.
        layer_idx: Layer index used in the per-head decay formula.
        num_layers: Total layer count used in the per-head decay formula.
        peak_bps: Peak memory bandwidth (bytes/s) for the SOL% figure.
        cache_intermediate_states: Forwarded to the MTP kernel.
        disable_state_update: Forwarded to the MTP kernel.

    Returns:
        Dict with timings (ms), speedups, RMSE / relative max diff versus fla
        (NaN when fla was not run), SOL% and modelled gigabytes moved.
    """
    device = "cuda"
    dtype = torch.bfloat16
    scale = K**-0.5
    nan = float("nan")

    # fla assumes HV == H (see the warning printed by main); running it on a
    # GQA shape would either fail or time a configuration that is reported as
    # disabled, so gate every fla call on this single flag.
    use_fla = HAS_FLA and HV == H

    # Per-head log decay (Lightning Attention formula)
    g_gamma = -(8 / H * (1 - layer_idx / num_layers)) * torch.arange(H, device=device, dtype=torch.float32)
    decay_scales = -g_gamma  # la_decode_mtp convention: exp(-decay_scales)

    # ── Random inputs ──────────────────────────────────────────────────────
    torch.manual_seed(42)
    q_4d = torch.randn(B, T, H, K, device=device, dtype=dtype)
    k_4d = torch.randn(B, T, H, K, device=device, dtype=dtype)
    v_4d = torch.randn(B, T, HV, V, device=device, dtype=dtype)
    state_init = torch.randn(B, HV, K, V, device=device, dtype=torch.float32) * 0.01  # K-major

    def fresh_state_vk():
        """Return a private [B, HV, V, K] copy of the initial state."""
        return state_init.clone().permute(0, 1, 3, 2).contiguous()

    def dummy_inter():
        """Return the placeholder passed when intermediate states are not cached."""
        return torch.empty(1, 1, 1, device=device, dtype=torch.float32)

    # ── fla reference output ───────────────────────────────────────────────
    o_fla = None
    if use_fla:
        state_fla = state_init.clone()
        with torch.no_grad():
            o_fla_fp32, _ = fused_recurrent_fwd(
                q_4d,
                k_4d,
                v_4d,
                g_gamma=g_gamma,
                scale=scale,
                initial_state=state_fla,
                output_final_state=True,
            )
        o_fla = o_fla_fp32.to(dtype)  # [B, T, H, V]

    # ── cula MTP (single call for the correctness check) ───────────────────
    s_cute = fresh_state_vk()  # [B, HV, V, K]
    out_cute = torch.zeros(B, T, HV, V, device=device, dtype=dtype)
    s_offsets = torch.arange(B, device=device, dtype=torch.int32)
    cu_seqlens_dummy = torch.empty(1, device=device, dtype=torch.int32)
    if cache_intermediate_states:
        inter = torch.zeros(B * T * HV, V, K, device=device, dtype=torch.float32)
    else:
        inter = dummy_inter()

    with torch.no_grad():
        linear_attention_decode_mtp(
            q_4d,
            k_4d,
            v_4d,
            s_cute,
            inter,
            out_cute,
            decay_scales=decay_scales,
            s_offsets=s_offsets,
            cu_seqlens=cu_seqlens_dummy,
            softmax_scale=scale,
            T=T,
            cache_intermediate_states=cache_intermediate_states,
            disable_state_update=disable_state_update,
            is_varlen=False,
        )

    # ── Correctness vs fla ─────────────────────────────────────────────────
    rmse, rel_maxdiff = nan, nan
    if o_fla is not None:
        out_cmp = out_cute.float()
        ref_cmp = o_fla.float()
        rmse = torch.sqrt(torch.mean((out_cmp - ref_cmp) ** 2)).item()
        max_ref = torch.abs(ref_cmp).max().item()
        rel_maxdiff = torch.abs(out_cmp - ref_cmp).max().item() / (max_ref + 1e-8)

    # ==================================================================
    # Mode 1: KERNEL-ONLY — pre-allocated, pre-compiled, pre-built stream
    # ==================================================================
    pool_size = B
    cache_key = (
        B,
        T,
        H,
        HV,
        K,
        V,
        pool_size,
        scale,
        disable_state_update,
        cache_intermediate_states,
        False,
        *get_mtp_config(B, T, HV, V, disable_state_update),
        get_device_sm_version(q_4d.device)[0] >= 10,
    )
    cute_cache = _get_compiled_la_mtp_kernel(*cache_key)
    compiled_cute = cute_cache["compiled"]
    stream_handle = cuda_drv.CUstream(torch.cuda.current_stream().cuda_stream)

    state_kk = fresh_state_vk().view(pool_size * HV, V, K)
    out_kk = torch.empty(B, T, HV, V, device=device, dtype=dtype)
    inter_kk = inter if cache_intermediate_states else dummy_inter()

    def kernel_cute_mtp():
        compiled_cute(
            state_kk,
            inter_kk,
            decay_scales,
            q_4d,
            k_4d,
            v_4d,
            out_kk,
            s_offsets,
            cu_seqlens_dummy,
            stream_handle,
        )

    # cula T-sequential baseline: T calls to la_decode (T=1 each)
    state_seq = fresh_state_vk().view(B * HV, V, K)
    out_seq_buf = torch.empty(B, HV, V, device=device, dtype=dtype)
    q_slices = [q_4d[:, t].contiguous() for t in range(T)]
    k_slices = [k_4d[:, t].contiguous() for t in range(T)]
    v_slices = [v_4d[:, t].contiguous() for t in range(T)]

    def kernel_cute_seq():
        for q_t, k_t, v_t in zip(q_slices, k_slices, v_slices):
            linear_attention_decode(
                q_t,
                k_t,
                v_t,
                state_seq,
                out_seq_buf,
                softmax_scale=scale,
                stride_q=0,
                stride_k=0,
                stride_v=0,
                stride_s=0,
                stride_o=0,
                s_offsets=s_offsets,
                decay_scales=decay_scales,
                HEAD_DIM=K,
                K_SPLIT_DIM=K,
                V_SPLIT_DIM=V,
            )

    # fla kernel-only mode would require careful pre-allocation; use wrapper for fla.
    with torch.no_grad():
        cute_mtp_ms = benchmark_fn(kernel_cute_mtp)
        cute_seq_ms = benchmark_fn(kernel_cute_seq)

    # ==================================================================
    # Mode 2: WRAPPER — full Python entry path (cache lookup + CUstream per call)
    # ==================================================================
    s_wrap = fresh_state_vk()
    out_wrap = torch.empty(B, T, HV, V, device=device, dtype=dtype)
    if cache_intermediate_states:
        inter_wrap = torch.zeros(B * T * HV, V, K, device=device, dtype=torch.float32)
    else:
        inter_wrap = dummy_inter()

    def wrapper_cute_mtp():
        linear_attention_decode_mtp(
            q_4d,
            k_4d,
            v_4d,
            s_wrap,
            inter_wrap,
            out_wrap,
            decay_scales=decay_scales,
            s_offsets=s_offsets,
            cu_seqlens=cu_seqlens_dummy,
            softmax_scale=scale,
            T=T,
            cache_intermediate_states=cache_intermediate_states,
            disable_state_update=disable_state_update,
            is_varlen=False,
        )

    with torch.no_grad():
        wrap_cute_ms = benchmark_fn(wrapper_cute_mtp)

    # fla wrapper (skipped, and reported as NaN, when fla is unusable here)
    fla_ms = nan
    if use_fla:
        state_fla_bench = state_init.clone()

        def wrapper_fla():
            fused_recurrent_fwd(
                q_4d,
                k_4d,
                v_4d,
                g_gamma=g_gamma,
                scale=scale,
                initial_state=state_fla_bench,
                output_final_state=True,
            )

        with torch.no_grad():
            fla_ms = benchmark_fn(wrapper_fla)

    # ── Roofline ────────────────────────────────────────────────────────
    bytes_moved = la_mtp_bytes(
        B,
        T,
        H,
        HV,
        K,
        V,
        cache_intermediate_states=cache_intermediate_states,
        disable_state_update=disable_state_update,
    )
    sol = sol_pct(bytes_moved, cute_mtp_ms, peak_bps)

    speedup_seq = cute_seq_ms / cute_mtp_ms
    speedup_fla = fla_ms / cute_mtp_ms if use_fla else nan

    return {
        "B": B,
        "T": T,
        "cute_mtp_ms": cute_mtp_ms,
        "cute_seq_ms": cute_seq_ms,
        "fla_ms": fla_ms,
        "wrap_cute_ms": wrap_cute_ms,
        "speedup_seq": speedup_seq,
        "speedup_fla": speedup_fla,
        "rmse": rmse,
        "rel_maxdiff": rel_maxdiff,
        "sol_pct": sol,
        "bytes_GB": bytes_moved / 1e9,
    }
f"{'fla(ms)':>9} | {'spd_seq':>7} | {'spd_fla':>7} | " + f"{'wrap(ms)':>9} | {'SOL%':>5} | {'GB':>6} | {'RMSE':>9}" + ) + print(f"\n{cols}") + print("─" * len(cols)) + + for T in args.T: + for B in args.batch_sizes: + r = run_config( + B, + T, + H, + HV, + K, + V, + args.layer_idx, + args.num_layers, + args.peak_bps, + cache_intermediate_states=args.cache_intermediate, + disable_state_update=args.disable_state_update, + ) + print( + f"{r['B']:>4} | {r['T']:>3} | {r['cute_mtp_ms']:>12.4f} | " + f"{r['cute_seq_ms']:>10.4f} | " + f"{(r['fla_ms'] if fla_avail else float('nan')):>9.4f} | " + f"{r['speedup_seq']:>6.2f}x | " + f"{(r['speedup_fla'] if fla_avail else float('nan')):>6.2f}x | " + f"{r['wrap_cute_ms']:>9.4f} | {r['sol_pct']:>5.1f} | " + f"{r['bytes_GB']:>6.3f} | {r['rmse']:>9.6f}" + ) + print() + + print("Notes:") + print(" cute_mtp : linear_attention_decode_mtp (fused single launch, T tokens)") + print(" cute×T : T sequential linear_attention_decode (T=1) calls — cula self-baseline") + print(" fla : fused_recurrent_fwd (Triton); kernel still re-launched per T internally") + print(" spd_seq : cute×T / cute_mtp (fusion benefit within cula)") + print(" spd_fla : fla / cute_mtp (vs industry reference)") + print(" wrap(ms) : cute_mtp full Python entry (cache lookup + CUstream + kernel)") + print(f" SOL% : (bytes / kernel_ms) / peak_bps × 100 (peak = {args.peak_bps:.2e} B/s)") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_la_kvbuffer.py b/benchmarks/bench_la_kvbuffer.py new file mode 100644 index 00000000..e8dc0471 --- /dev/null +++ b/benchmarks/bench_la_kvbuffer.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +""" +Benchmark: cuLA LA KVBuffer verify + state-update kernels. 
+ +Times the KVBuffer path (verify writes k/v to a pool buffer; state-update advances +the pooled state from it) and validates it against a shared PyTorch reference. + +An optional SGLang baseline (seg_la_mtp_kernel + fused_mamba_state_scatter_with_mask) +is compared when available — SGLang is not required. If it cannot be imported the +sg_* columns show nan and only the cuLA path is benchmarked. Set +LA_SGLANG_PYTHON=/path/to/sglang/python to point at a custom checkout. + +Usage: + python benchmarks/bench_la_kvbuffer.py + python benchmarks/bench_la_kvbuffer.py --batch-sizes 1 4 16 64 --T 2 4 8 + LA_SGLANG_PYTHON=~/sglang/python python benchmarks/bench_la_kvbuffer.py --T 4 +""" + +import argparse +import os +import sys + +import torch + +os.environ.setdefault("FLA_USE_FAST_OPS", os.getenv("CULA_USE_FAST_MATH", "1")) + +import cuda.bindings.driver as cuda_drv + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +# SGLang is an OPTIONAL comparison baseline — not every developer has it checked +# out. We try to import it (honoring LA_SGLANG_PYTHON for a custom path); if it is +# unavailable, the benchmark still runs against the PyTorch reference and simply +# skips the sg_* columns. Mirrors the pattern in bench_kda_decode_mtp.py. 
def torch_la_mtp_ref(q, k, v, state, decay_scales, softmax_scale):
    """
    Pure PyTorch reference for MTP decode.

    Per timestep: ``state = state * exp(-decay_scales) + k_t ⊗ v_t`` followed by
    ``out_t = (q_t * softmax_scale) · state`` contracted over K. All arithmetic
    is done in fp32 and the caller's ``state`` tensor is left untouched.

    Args:
        q, k: [B, T, H, K] bf16
        v: [B, T, H, V] bf16 (H == HV for SGLang compat)
        state: [B, H, K, V] fp32 (K-major, SGLang convention)
        decay_scales: [H] fp32
        softmax_scale: float

    Returns:
        out: [B, T, H, V] fp32
        state: [B, H, K, V] fp32 (updated)

    Raises:
        ValueError: If ``v`` does not have the same number of heads as ``q``.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    # Without this check a single-head v would broadcast silently across all
    # q/k heads and produce a plausible-looking but wrong reference.
    if v.shape[2] != H:
        raise ValueError(f"v must have {H} heads to match q/k, got {v.shape[2]}")

    # Single private fp32 copy (never aliases the caller's tensor).
    state = state.to(torch.float32, copy=True)
    out = torch.zeros(B, T, H, V, device=q.device, dtype=torch.float32)

    # Upcast before exp so a low-precision decay_scales does not lose bits,
    # and build the broadcast view once instead of once per timestep.
    decay = torch.exp(-decay_scales.float())[None, :, None, None]  # [1, H, 1, 1]

    for t in range(T):
        qt = q[:, t].float() * softmax_scale  # [B, H, K]
        kt = k[:, t].float()  # [B, H, K]
        vt = v[:, t].float()  # [B, H, V]
        state = state * decay + kt.unsqueeze(-1) * vt.unsqueeze(-2)
        out[:, t] = torch.einsum("bhk,bhkv->bhv", qt, state)

    return out, state


# ─────────────────────────────────────────────────────────────────────────────
# SGLang seg_la MTP wrapper (matches seg_la_fwd MTP path)
# ─────────────────────────────────────────────────────────────────────────────
def run_sglang_mtp(
    q_3d,
    k_3d,
    v_3d,
    s_sglang,
    caches_sglang,
    s_offsets,
    cache_indices,
    decay_scales,
    meta,
    softmax_scale,
    HEAD_DIM,
    step,
    K_SPLIT_DIM=32,
    V_SPLIT_DIM=64,
):
    """
    Invoke seg_la_mtp_kernel the same way seg_la_fwd does for the MTP path.

    q_3d, k_3d, v_3d: [length, qo_heads, HEAD_DIM] (contiguous, length = B*step)
    s_sglang: [pool_size, qo_heads, HEAD_DIM, HEAD_DIM] fp32
    caches_sglang: [pool_size * step, qo_heads, HEAD_DIM, HEAD_DIM] fp32

    The kernel writes one partial output per K split into a scratch tensor; the
    partials are then reduced over the split axis (``tmp.sum(0)`` for short
    inputs, ``seg_la_sum_kernel`` otherwise).

    Returns:
        Output tensor of shape [length, qo_heads, HEAD_DIM] in q_3d's dtype.
    """
    length = q_3d.shape[0]
    qo_heads = q_3d.shape[1]
    bs = meta.batch_size

    # NOTE(review): integer division — HEAD_DIM is presumably a multiple of both
    # split sizes; confirm against seg_la_fwd.
    k_dim_block = HEAD_DIM // K_SPLIT_DIM
    v_dim_block = HEAD_DIM // V_SPLIT_DIM
    # Scratch for the per-K-split partial outputs.
    tmp = torch.empty((k_dim_block, length, qo_heads, HEAD_DIM), device=q_3d.device, dtype=q_3d.dtype)
    grid = (bs, qo_heads, k_dim_block * v_dim_block)
    num_warps = 2
    num_stages = 3

    seg_la_mtp_kernel[grid](
        q_3d,
        k_3d,
        v_3d,
        s_sglang,
        caches_sglang,
        tmp,
        softmax_scale,
        q_3d.stride(0),
        k_3d.stride(0),
        v_3d.stride(0),
        s_sglang.stride(0),
        caches_sglang.stride(0),
        tmp.stride(0),
        s_offsets,
        cache_indices,
        decay_scales,
        step,
        HEAD_DIM=HEAD_DIM,
        K_SPLIT_DIM=K_SPLIT_DIM,
        V_SPLIT_DIM=V_SPLIT_DIM,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    if k_dim_block > 1:
        if length < 2048:
            o = tmp.sum(0)
        else:
            o = torch.empty((length, qo_heads, HEAD_DIM), device=q_3d.device, dtype=q_3d.dtype)
            seg_la_sum_kernel[(length,)](
                tmp,
                o,
                DIM=qo_heads * HEAD_DIM,
                NUM_BLOCK=k_dim_block,
                num_warps=2,
                num_stages=3,
            )
    else:
        # A single K split needs no reduction.
        o = tmp[0]
    return o
def run_sglang_commit(s_sglang, caches_sglang, s_offsets, step_indices, B, H, K, V, T):
    """
    Invoke fused_mamba_state_scatter_with_mask the way hybrid_linear_attn_backend does.

    dst: [1, pool_size, H*K*V] — state pool (1 layer)
    src: [1, B, T, H*K*V] — intermediate caches (1 layer)
    """
    elem_per_entry = H * K * V
    # dst must alias the pool: `view` never copies and raises if the layout
    # would require one, whereas `reshape` could silently hand the kernel a
    # temporary. NOTE(review): assumes the scatter writes dst in place.
    dst = s_sglang.view(1, -1, elem_per_entry)
    src = caches_sglang.reshape(1, B, T, elem_per_entry)
    fused_mamba_state_scatter_with_mask(dst, src, s_offsets, step_indices)


# ─────────────────────────────────────────────────────────────────────────────
# Timing utility
# ─────────────────────────────────────────────────────────────────────────────
def benchmark_fn(fn, warmup=30, rep=200):
    """Time ``fn`` with CUDA events and return the interquartile mean in ms.

    ``fn`` is called ``warmup`` times untimed, then ``rep`` times, each call
    bracketed by a pair of CUDA events. The samples are sorted and only the
    middle half (``[n // 4, 3 * n // 4)``) is averaged.

    Args:
        fn: Zero-argument callable that enqueues the work to measure.
        warmup: Number of untimed warm-up calls.
        rep: Number of timed calls; must be at least 1.

    Returns:
        Mean elapsed time in milliseconds over the retained samples.

    Raises:
        ValueError: If ``rep`` is smaller than 1.
    """
    # Validate before touching CUDA so a bad argument fails fast and clearly.
    if rep < 1:
        raise ValueError(f"rep must be >= 1, got {rep}")

    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    starts = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    for start, end in zip(starts, ends):
        start.record()
        fn()
        end.record()
    torch.cuda.synchronize()

    times = sorted(s.elapsed_time(e) for s, e in zip(starts, ends))
    n = len(times)
    # With a single sample the interquartile window is empty; fall back to all
    # samples instead of dividing by zero.
    window = times[n // 4 : 3 * n // 4] or times
    return sum(window) / len(window)
def run_config(B, T, H, K, V, layer_idx, num_layers):
    """Benchmark the KVBuffer verify + state-update kernels for one (B, T).

    Validates the cuLA MTP kernel, the KVBuffer verify kernel and (when
    importable) the SGLang seg_la kernel against ``torch_la_mtp_ref``, then
    times the pre-compiled KVBuffer verify/update handles and the optional
    SGLang verify/commit kernels.

    Args:
        B: Batch size (also used as the state pool size).
        T: Number of timesteps per sequence.
        H: Number of heads (HV is forced equal to H).
        K: Key head dimension.
        V: Value head dimension.
        layer_idx: Layer index used in the per-head decay formula.
        num_layers: Total layer count used in the per-head decay formula.

    Returns:
        Dict with per-kernel timings (ms), totals, the SGLang/cuLA speedup and
        RMSE values versus the PyTorch reference. SGLang entries are NaN when
        SGLang is unavailable.
    """
    device = "cuda"
    dtype = torch.bfloat16
    scale = K**-0.5
    HV = H  # SGLang seg_la does not support GQA
    nan = float("nan")

    g_gamma = -(8 / H * (1 - layer_idx / num_layers)) * torch.arange(H, device=device, dtype=torch.float32)
    decay_scales = -g_gamma  # cuLA convention: exp(-decay_scales)

    torch.manual_seed(42)
    q_4d = torch.randn(B, T, H, K, device=device, dtype=dtype)
    k_4d = torch.randn(B, T, H, K, device=device, dtype=dtype)
    v_4d = torch.randn(B, T, HV, V, device=device, dtype=dtype)
    state_init_kmaj = torch.randn(B, H, K, V, device=device, dtype=torch.float32) * 0.01

    def fresh_state_vk():
        """Return a private [B, HV, V, K] copy of the initial state."""
        return state_init_kmaj.permute(0, 1, 3, 2).contiguous()

    def rmse_vs_ref(out):
        """Return the RMSE of ``out`` (cast to fp32) against the reference."""
        return torch.sqrt(torch.mean((out.float() - o_ref) ** 2)).item()

    # ── PyTorch reference ──────────────────────────────────────────────────
    with torch.no_grad():
        o_ref, _ = torch_la_mtp_ref(q_4d, k_4d, v_4d, state_init_kmaj, decay_scales, scale)

    # ── SGLang setup ───────────────────────────────────────────────────────
    length = B * T
    q_3d = q_4d.reshape(length, H, K).contiguous()
    k_3d = k_4d.reshape(length, H, K).contiguous()
    v_3d = v_4d.reshape(length, HV, V).contiguous()

    pool_size = B

    # ── SGLang baseline (optional) ─────────────────────────────────────────
    rmse_sg = nan
    s_sglang = caches_sglang = s_offsets_sg = cache_indices_sg = meta = None
    s_offsets_sg_i32 = None
    K_SPLIT_DIM = 32
    V_SPLIT_DIM = 32 if B <= 2 else 64
    if _HAVE_SGLANG:
        # clone(): a same-shape reshape + contiguous() is a no-op, so without
        # the copy this pool would alias state_init_kmaj and the timed commit
        # kernel would overwrite the shared initial state.
        s_sglang = state_init_kmaj.reshape(pool_size, H, K, V).clone()
        caches_sglang = torch.zeros(pool_size * T, H, K, V, device=device, dtype=torch.float32)

        s_offsets_sg = torch.arange(B, device=device, dtype=torch.int64)
        # int32 copy made once here so the timed commit closure does not
        # launch a dtype-conversion kernel on every call.
        s_offsets_sg_i32 = s_offsets_sg.int()
        cache_indices_sg = torch.arange(B, device=device, dtype=torch.int64) * T

        q_offsets = torch.arange(B + 1, device=device, dtype=torch.int64) * T
        q_lengths = torch.full((B,), T, device=device, dtype=torch.int64)
        s_scales = torch.ones(B, device=device, dtype=torch.int64)

        meta = SegLaMeta(
            batch_size=B,
            max_q_length=T,
            q_offsets=q_offsets,
            s_offsets=s_offsets_sg,
            q_lengths=q_lengths,
            s_scales=s_scales,
        )

        # warmup sglang (Triton JIT compile)
        with torch.no_grad():
            s_sg_run = s_sglang.clone()
            c_sg_run = caches_sglang.clone()
            o_sg = run_sglang_mtp(
                q_3d,
                k_3d,
                v_3d,
                s_sg_run,
                c_sg_run,
                s_offsets_sg,
                cache_indices_sg,
                decay_scales,
                meta,
                scale,
                K,
                T,
                K_SPLIT_DIM,
                V_SPLIT_DIM,
            )
            rmse_sg = rmse_vs_ref(o_sg.reshape(B, T, HV, V))

    # ── cuLA MTP setup ─────────────────────────────────────────────────────
    # SGLang seg_la_mtp writes intermediate caches but does NOT write back S,
    # so the fair comparison is cache_intermediate_states=True, disable_state_update=True.
    cache_inter = True
    disable_su = True

    s_cute = fresh_state_vk()  # [B, HV, V, K]
    out_cute = torch.zeros(B, T, HV, V, device=device, dtype=dtype)
    s_offsets_cu = torch.arange(B, device=device, dtype=torch.int32)
    inter = torch.zeros(B * T * HV, V, K, device=device, dtype=torch.float32)
    cu_seqlens_dummy = torch.empty(1, device=device, dtype=torch.int32)

    with torch.no_grad():
        linear_attention_decode_mtp(
            q_4d,
            k_4d,
            v_4d,
            s_cute,
            inter,
            out_cute,
            decay_scales=decay_scales,
            s_offsets=s_offsets_cu,
            cu_seqlens=cu_seqlens_dummy,
            softmax_scale=scale,
            T=T,
            cache_intermediate_states=cache_inter,
            disable_state_update=disable_su,
            is_varlen=False,
        )

    rmse_cu = rmse_vs_ref(out_cute)

    # ── KVBuffer verify + state-update setup ───────────────────────────────
    s_kvbuf = fresh_state_vk()  # [B, HV, V, K]
    out_kvbuf = torch.zeros(B, T, HV, V, device=device, dtype=dtype)
    h0_indices_kv = torch.arange(B, device=device, dtype=torch.int32)
    accepted_len_kv = torch.full((B,), T, device=device, dtype=torch.int32)

    with torch.no_grad():
        linear_attention_verify_kvbuffer(
            q_4d,
            k_4d,
            v_4d,
            s_kvbuf,
            out_kvbuf,
            decay_scales,
            h0_indices_kv,
            scale,
            T,
        )
        s_kvbuf_warmup = fresh_state_vk()
        linear_attention_state_update_kvbuffer(
            k_4d,
            v_4d,
            s_kvbuf_warmup,
            decay_scales,
            h0_indices_kv,
            accepted_len_kv,
            T,
        )

    rmse_kv = rmse_vs_ref(out_kvbuf)

    # ==================================================================
    # Kernel-only timing: pre-compiled handles, no Python overhead
    # ==================================================================

    # ---- cuLA kernel-only setup ----
    major, _ = get_device_sm_version(q_4d.device)
    use_packed_fma = major >= 10
    stream_handle = cuda_drv.CUstream(torch.cuda.current_stream().cuda_stream)

    # ---- SGLang: Triton kernel is already "kernel-only" (no Python wrapper overhead).
    #      We just avoid the redundant .clone() on state S, since seg_la_mtp_kernel
    #      does NOT write back to S (it writes to CACHES only). ----
    s_sg_bench = s_sglang  # no clone needed, kernel only reads S
    c_sg_bench = caches_sglang

    def kernel_sglang():
        run_sglang_mtp(
            q_3d,
            k_3d,
            v_3d,
            s_sg_bench,
            c_sg_bench,
            s_offsets_sg,
            cache_indices_sg,
            decay_scales,
            meta,
            scale,
            K,
            T,
            K_SPLIT_DIM,
            V_SPLIT_DIM,
        )

    # ---- SGLang commit setup ----
    step_indices_sg = torch.full((B,), T - 1, device=device, dtype=torch.int32)

    def kernel_sglang_commit():
        run_sglang_commit(
            s_sg_bench,
            c_sg_bench,
            s_offsets_sg_i32,
            step_indices_sg,
            B,
            H,
            K,
            V,
            T,
        )

    # ---- cuLA KVBuffer with actual buffer write/read ----
    k_buf_bench = torch.zeros(pool_size, T, H, K, device=device, dtype=dtype)
    v_buf_bench = torch.zeros(pool_size, T, HV, V, device=device, dtype=dtype)

    # Trigger compilation for write_kv=True variant
    s_kvbuf_compile = fresh_state_vk()
    out_compile = torch.zeros(B, T, HV, V, device=device, dtype=dtype)
    linear_attention_verify_kvbuffer(
        q_4d,
        k_4d,
        v_4d,
        s_kvbuf_compile,
        out_compile,
        decay_scales,
        h0_indices_kv,
        scale,
        T,
        k_buf=k_buf_bench,
        v_buf=v_buf_bench,
    )

    # linear_attention_verify_kvbuffer dispatches by T: MMA kernel for T>=MMA_MIN_T,
    # shuffle kernel otherwise. Fetch the matching pre-compiled handle for timing.
    tile_v_kv, vec_size_kv, ilp_rows_kv, _ = get_mtp_config(B, T, HV, V, True)
    if T >= MMA_MIN_T:
        # match the MMA kernel's ilp_rows=8 override (M=8 fragment fill)
        if ilp_rows_kv < 8 and (tile_v_kv // 4) % 8 == 0:
            ilp_rows_kv = 8
        verify_buf_cache = _get_compiled_verify_kvbuffer_kernel(
            B,
            T,
            H,
            HV,
            K,
            V,
            pool_size,
            scale,
            tile_v_kv,
            vec_size_kv,
            ilp_rows_kv,
            True,  # write_kv
        )
    else:
        # shuffle kernel: cache_key has no use_smem_v slot
        verify_buf_cache = _get_compiled_verify_kvbuffer_kernel_shuffle(
            B,
            T,
            H,
            HV,
            K,
            V,
            pool_size,
            scale,
            tile_v_kv,
            vec_size_kv,
            ilp_rows_kv,
            use_packed_fma,
            True,  # write_kv
        )
    compiled_verify_buf = verify_buf_cache["compiled"]

    s_kvbuf_kk_vb = fresh_state_vk().view(pool_size * HV, V, K)
    out_kvbuf_kk = torch.empty(B, T, HV, V, device=device, dtype=dtype)

    def kernel_kvbuf_verify_with_write():
        compiled_verify_buf(
            s_kvbuf_kk_vb,
            decay_scales,
            q_4d,
            k_4d,
            v_4d,
            out_kvbuf_kk,
            h0_indices_kv,
            k_buf_bench,
            v_buf_bench,
            stream_handle,
        )

    # Trigger compilation for read_from_buf=True variant
    s_kvbuf_warmup2 = fresh_state_vk()
    linear_attention_state_update_kvbuffer(
        k_4d,
        v_4d,
        s_kvbuf_warmup2,
        decay_scales,
        h0_indices_kv,
        accepted_len_kv,
        T,
        k_buf=k_buf_bench,
        v_buf=v_buf_bench,
    )

    tile_v_su, vec_size_su, ilp_rows_su, _ = get_mtp_config(B, T, HV, V, False)
    update_buf_cache_key = (
        B,
        T,
        H,
        HV,
        K,
        V,
        pool_size,
        tile_v_su,
        vec_size_su,
        ilp_rows_su,
        use_packed_fma,
        True,  # read_from_buf
    )
    update_buf_cache = _get_compiled_state_update_kernel(*update_buf_cache_key)
    compiled_update_buf = update_buf_cache["compiled"]

    s_kvbuf_kk_ub = fresh_state_vk().view(pool_size * HV, V, K)

    def kernel_kvbuf_update_from_buf():
        compiled_update_buf(
            s_kvbuf_kk_ub,
            decay_scales,
            k_4d,
            v_4d,
            h0_indices_kv,
            accepted_len_kv,
            k_buf_bench,
            v_buf_bench,
            stream_handle,
        )

    with torch.no_grad():
        cu_vfy_ms = benchmark_fn(kernel_kvbuf_verify_with_write)
        cu_cmt_ms = benchmark_fn(kernel_kvbuf_update_from_buf)
        if _HAVE_SGLANG:
            sg_vfy_ms = benchmark_fn(kernel_sglang)
            sg_cmt_ms = benchmark_fn(kernel_sglang_commit)
        else:
            sg_vfy_ms = sg_cmt_ms = nan

    sg_total_ms = sg_vfy_ms + sg_cmt_ms
    cu_total_ms = cu_vfy_ms + cu_cmt_ms

    return {
        "B": B,
        "T": T,
        "sg_vfy_ms": sg_vfy_ms,
        "sg_cmt_ms": sg_cmt_ms,
        "sg_total_ms": sg_total_ms,
        "cu_vfy_ms": cu_vfy_ms,
        "cu_cmt_ms": cu_cmt_ms,
        "cu_total_ms": cu_total_ms,
        "speedup": (sg_total_ms / cu_total_ms) if _HAVE_SGLANG else nan,
        "rmse_sg": rmse_sg,
        "rmse_cu": rmse_cu,
        "rmse_kv": rmse_kv,
    }
baseline: AVAILABLE (sg_* columns active)") + else: + print(f" SGLang baseline: UNAVAILABLE — sg_* columns show nan. ({_SGLANG_ERR})") + print(" set LA_SGLANG_PYTHON=/path/to/sglang/python to enable the comparison.") + + hdr = ( + f"{'B':>4} | {'T':>3} | " + f"{'sg_vfy(ms)':>10} | {'sg_cmt(ms)':>10} | {'sg_total':>9} | " + f"{'cu_vfy(ms)':>10} | {'cu_cmt(ms)':>10} | {'cu_total':>9} | " + f"{'speedup':>7} | " + f"{'rmse_sg':>9} | {'rmse_cu':>9} | {'rmse_kv':>9}" + ) + print(f"\n{hdr}") + print("─" * len(hdr)) + + for T_val in args.T: + for B in args.batch_sizes: + r = run_config(B, T_val, H, K, V, args.layer_idx, args.num_layers) + print( + f"{r['B']:>4} | {r['T']:>3} | " + f"{r['sg_vfy_ms']:>10.4f} | {r['sg_cmt_ms']:>10.4f} | {r['sg_total_ms']:>9.4f} | " + f"{r['cu_vfy_ms']:>10.4f} | {r['cu_cmt_ms']:>10.4f} | {r['cu_total_ms']:>9.4f} | " + f"{r['speedup']:>6.2f}x | " + f"{r['rmse_sg']:>9.6f} | {r['rmse_cu']:>9.6f} | {r['rmse_kv']:>9.6f}" + ) + print() + + # Memory comparison + sg_mem = B * T_val * H * K * V * 4 + cu_mem = B * T_val * (H * K + H * V) * 2 + print(f"Memory per-pool (B={args.batch_sizes[-1]}, T={args.T[-1]}):") + print(f" SGLang intermediate caches: {sg_mem / 1e6:.1f} MB") + print(f" cuLA KV buffer: {cu_mem / 1e6:.1f} MB") + print(f" Ratio: {sg_mem / cu_mem:.0f}×") + + print("\nColumns:") + print(" sg_vfy : seg_la_mtp_kernel (Triton, SGLang upstream)") + print(" sg_cmt : fused_mamba_state_scatter_with_mask (Triton, SGLang)") + print(" cu_vfy : verify_kvbuffer with KV buffer write (CuTe DSL)") + print(" cu_cmt : state_update_kvbuffer reading from buffer (CuTe DSL)") + print(" speedup : sg_total / cu_total") + print(" rmse_* : RMSE vs PyTorch reference") + + +if __name__ == "__main__": + main() diff --git a/cula/lightning/__init__.py b/cula/lightning/__init__.py index fb5e5635..a97d5169 100644 --- a/cula/lightning/__init__.py +++ b/cula/lightning/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # 
limitations under the License. +from cula.lightning.la_decode_mtp import linear_attention_decode_mtp +from cula.lightning.la_state_update_kvbuffer import linear_attention_state_update_kvbuffer +from cula.lightning.la_verify_kvbuffer import linear_attention_verify_kvbuffer from cula.ops.la_decode import linear_attention_decode from cula.ops.lightning_attn_sm100 import ( LinearAttentionChunkwiseDecay, @@ -24,4 +27,7 @@ "lightning_attn_fwd", "lightning_attn_fwd_varlen", "linear_attention_decode", + "linear_attention_decode_mtp", + "linear_attention_verify_kvbuffer", + "linear_attention_state_update_kvbuffer", ] diff --git a/cula/lightning/la_decode_mtp.py b/cula/lightning/la_decode_mtp.py new file mode 100644 index 00000000..3c150cfc --- /dev/null +++ b/cula/lightning/la_decode_mtp.py @@ -0,0 +1,560 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Lightning Attention MTP (Multi-Token Processing) Decode Kernel. + +Processes T > 1 tokens in one launch with h held in registers across the +whole T-loop. Targeted at speculative-decoding verify scenarios. + +Per timestep: + h_t = exp(-decay_scales[h]) * h_{t-1} + k_t ⊗ v_t + o_t = (h_t @ q_t) * softmax_scale + +`decay_scales` is per-head and time-invariant, so `r_decay` is computed ONCE +outside the T-loop. + +Grid: (B * HV * num_v_tiles, 1, 1). Each block handles one [tile_v] slice +across all T timesteps; h for that slice stays in registers. 
# ============================================================================
# FMA pair helpers (packed F32x2 on SM100; scalar fallback on SM90)
# ============================================================================
@cute.jit
def la_update_pair(h_lo, h_hi, k_lo, k_hi, v_j, decay, use_packed_fma: cutlass.Constexpr[bool]):
    """Inner LA recurrence on a (lo, hi) pair: h = h*decay + k*v_j.

    Args:
        h_lo, h_hi: The two state values being advanced.
        k_lo, k_hi: The key values paired with ``h_lo`` / ``h_hi``.
        v_j: Value scalar multiplied into both key lanes.
        decay: Decay factor multiplied into both state lanes.
        use_packed_fma: Compile-time switch. True issues two
            ``cute.arch.fma_packed_f32x2`` calls; False uses plain scalar
            multiply/add expressions.

    Returns:
        The updated ``(h_lo, h_hi)`` pair.
    """
    if cutlass.const_expr(use_packed_fma):
        # h *= decay (packed mul implemented as FMA with src_c=0)
        h_lo, h_hi = cute.arch.fma_packed_f32x2(
            src_a=(h_lo, h_hi),
            src_b=(decay, decay),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
        # h += k * v_j
        h_lo, h_hi = cute.arch.fma_packed_f32x2(
            src_a=(k_lo, k_hi),
            src_b=(v_j, v_j),
            src_c=(h_lo, h_hi),
        )
        return h_lo, h_hi
    else:
        # Scalar fallback: same recurrence written as ordinary arithmetic.
        return h_lo * decay + k_lo * v_j, h_hi * decay + k_hi * v_j


@cute.jit
def hq_dot_pair(h_lo, h_hi, q_lo, q_hi, sum_lo, sum_hi, use_packed_fma: cutlass.Constexpr[bool]):
    """Accumulate dot product over a (lo, hi) pair: sum += h * q.

    Args:
        h_lo, h_hi: The two state values.
        q_lo, q_hi: The query values paired with ``h_lo`` / ``h_hi``.
        sum_lo, sum_hi: Running partial sums, one per lane of the pair.
        use_packed_fma: Compile-time switch. True issues one
            ``cute.arch.fma_packed_f32x2``; False uses scalar multiply/add.

    Returns:
        The updated ``(sum_lo, sum_hi)`` pair. The two partial sums are kept
        separate; combining them is left to the caller.
    """
    if cutlass.const_expr(use_packed_fma):
        return cute.arch.fma_packed_f32x2(
            src_a=(h_lo, h_hi),
            src_b=(q_lo, q_hi),
            src_c=(sum_lo, sum_hi),
        )
    else:
        return h_lo * q_lo + sum_lo, h_hi * q_hi + sum_hi
def get_mtp_config(B: int, T: int, HV: int, V: int, disable_state_update: bool) -> tuple:
    """Pick (tile_v, vec_size, ilp_rows, use_smem_v) based on work units.

    Thresholds ported from GDN MTP (B200 grid search on Qwen3.5, HV=64).
    LA's per-step compute is ~30% lighter (no delta rule), so we may need
    to retune; the structure is preserved for now.

    Args:
        B: Batch size.
        T: Number of tokens processed per launch.
        HV: Number of value heads.
        V: Value head dimension; ``tile_v`` is clamped to it.
        disable_state_update: When False and ``T <= 2``, the largest work-unit
            bucket switches to ``ilp_rows=8`` without the SMEM v preload.

    Returns:
        ``(tile_v, vec_size, ilp_rows, use_smem_v)``.

    Raises:
        AssertionError: If the (clamped) V-tile cannot be split evenly into
            4 row-groups of whole ``ilp_rows`` chunks. That includes a tile
            smaller than 4 rows and a tile that is not a multiple of 4; in
            both cases some V-rows would belong to no group and would be
            silently left uncomputed.
    """
    # The selection below assumes 4 row-groups per block (the assertion message
    # has always reported num_groups=4).
    num_groups = 4
    work_units = B * HV
    vec_size = 4

    if work_units <= 64:
        tile_v, ilp_rows, use_smem_v = 8, 2, False
    elif work_units <= 128:
        tile_v, ilp_rows, use_smem_v = 16, 4, False
    elif work_units <= 448:
        if T <= 2:
            tile_v, ilp_rows, use_smem_v = 16, 2, False
        else:
            tile_v, ilp_rows, use_smem_v = 32, 4, False
    elif work_units <= 1024:
        tile_v, ilp_rows, use_smem_v = 32, 4, False
    else:
        tile_v = 64
        use_smem_v = True
        ilp_rows = 4
        if not disable_state_update and T <= 2:
            ilp_rows = 8
            use_smem_v = False

    tile_v = min(tile_v, V)
    rows_per_group = tile_v // num_groups
    # Three conditions, all required for full row coverage:
    #   * rows_per_group > 0      -- otherwise 0 % ilp_rows == 0 would pass and
    #                                the chunk loop would run zero iterations;
    #   * tile_v % num_groups == 0 -- otherwise the tile's trailing rows are
    #                                owned by no group;
    #   * rows_per_group % ilp_rows == 0 -- whole ILP chunks per group.
    divides_cleanly = (
        rows_per_group > 0
        and tile_v % num_groups == 0
        and rows_per_group % ilp_rows == 0
    )
    assert divides_cleanly, (
        f"tile_v={tile_v} / num_groups=4 / ilp_rows={ilp_rows} doesn't divide cleanly "
        f"(rows_per_group={rows_per_group}); the ILP loop would run zero iterations "
        f"or leave V-rows unprocessed."
    )
    return tile_v, vec_size, ilp_rows, use_smem_v
------------------------------------------------------------------ + # SMEM allocation (sVdata + sOutput only — LA has no Phase 1 work) + # ------------------------------------------------------------------ + smem = cutlass.utils.SmemAllocator() + sVdata = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, tile_v), stride=(tile_v, 1)), 16) + sOutput = smem.allocate_tensor(cutlass.BFloat16, cute.make_layout((T, tile_v), stride=(tile_v, 1)), 16) + + # ------------------------------------------------------------------ + # Register tensors + # ------------------------------------------------------------------ + r_q = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_k = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_q_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + r_k_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + # r_h always declared with 8 rows; ilp_rows constexpr picks which are used. + r_h = cute.make_rmem_tensor(cute.make_layout((8, vec_size), stride=(vec_size, 1)), cutlass.Float32) + + if cache_idx >= 0: + # r_decay is a T-loop invariant — computed ONCE. + r_decay = cute.exp(-cutlass.Float32(decay_scales[i_h]), fastmath=USE_FAST_MATH) + + # Optional v preload to SMEM (cooperative load across the whole block). + if cutlass.const_expr(use_smem_v): + for i_t in cutlass.range_constexpr(T): + v_tile_start = i_v * tile_v + if tidx < tile_v: + v_global_idx = v_tile_start + tidx + if v_global_idx < V: + sVdata[(i_t, tidx)] = cutlass.Float32(v[i_n, i_t, i_hv, v_global_idx]) + cute.arch.barrier() + + rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups + flat_state_idx = cache_idx * HV + i_hv + + # Process `ilp_rows` V-rows per iteration. 
ilp_rows is a compile-time + # constant, so range_constexpr fully unrolls the slot loops below — the + # generated SASS is identical to hand-unrolling each ilp_rows value, but + # one loop covers ilp_rows ∈ {2, 4, 8}. + num_chunks: cutlass.Constexpr[int] = rows_per_group // ilp_rows + for chunk in cutlass.range_constexpr(num_chunks): + v_idx_0 = i_v * tile_v + group_idx * rows_per_group + chunk * ilp_rows + if v_idx_0 + (ilp_rows - 1) < V: + # Load ilp_rows h-state rows ONCE; they stay register-resident across T. + for slot in cutlass.range_constexpr(ilp_rows): + h_tile = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_0 + slot, lane_in_group), + ) + cute.autovec_copy(h_tile, cute.slice_(r_h, (slot, None))) + + for i_t in cutlass.range_constexpr(T): + # ---- inline q/k load for this t ---- + q_tile = cute.local_tile( + q, + (1, 1, 1, vec_size), + (i_n, i_t, i_h, lane_in_group), + ) + k_tile = cute.local_tile( + k, + (1, 1, 1, vec_size), + (i_n, i_t, i_h, lane_in_group), + ) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) * scale + r_k[i] = cutlass.Float32(r_k_bf16[i]) + + # Per-row dot-product accumulators (lo, hi) — zeroed each t step. 
+ r_dot_lo = cute.make_rmem_tensor(cute.make_layout((ilp_rows,), stride=(1,)), cutlass.Float32) + r_dot_hi = cute.make_rmem_tensor(cute.make_layout((ilp_rows,), stride=(1,)), cutlass.Float32) + for slot in cutlass.range_constexpr(ilp_rows): + r_dot_lo[slot] = cutlass.Float32(0.0) + r_dot_hi[slot] = cutlass.Float32(0.0) + + # ---- fused decay + rank-1 update (per V-row) ---- + for slot in cutlass.range_constexpr(ilp_rows): + if cutlass.const_expr(use_smem_v): + r_v_s = sVdata[(i_t, v_idx_0 - i_v * tile_v + slot)] + else: + r_v_s = cutlass.Float32(v[i_n, i_t, i_hv, v_idx_0 + slot]) + for j in cutlass.range_constexpr(0, vec_size, 2): + r_h[slot, j], r_h[slot, j + 1] = la_update_pair( + r_h[slot, j], + r_h[slot, j + 1], + r_k[j], + r_k[j + 1], + r_v_s, + r_decay, + use_packed_fma, + ) + + # ---- optional intermediate-state cache ---- + if cutlass.const_expr(cache_intermediate_states): + flat_idx = i_n * T * HV + i_t * HV + i_hv + for slot in cutlass.range_constexpr(ilp_rows): + inter_tile = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v_idx_0 + slot, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (slot, None)), inter_tile) + + # ---- o_t = h_t @ q_t (per-row warp reduce) ---- + for slot in cutlass.range_constexpr(ilp_rows): + for j in cutlass.range_constexpr(0, vec_size, 2): + r_dot_lo[slot], r_dot_hi[slot] = hq_dot_pair( + r_h[slot, j], + r_h[slot, j + 1], + r_q[j], + r_q[j + 1], + r_dot_lo[slot], + r_dot_hi[slot], + use_packed_fma, + ) + r_acc = r_dot_lo[slot] + r_dot_hi[slot] + for offset in [16, 8, 4, 2, 1]: + r_acc += cute.arch.shuffle_sync_bfly(r_acc, offset=offset, mask=-1, mask_and_clamp=31) + r_dot_lo[slot] = r_acc # reuse slot for final result + + # ---- writeback ---- + if lane_in_group == 0: + if cutlass.const_expr(use_smem_v): + vla = v_idx_0 - i_v * tile_v + for slot in cutlass.range_constexpr(ilp_rows): + sOutput[(i_t, vla + slot)] = cutlass.BFloat16(r_dot_lo[slot]) + else: + for slot in 
cutlass.range_constexpr(ilp_rows): + o[(i_n, i_t, i_hv, v_idx_0 + slot)] = cutlass.BFloat16(r_dot_lo[slot]) + + # Final state writeback + if cutlass.const_expr(not disable_state_update): + for slot in cutlass.range_constexpr(ilp_rows): + h_tile_out = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_0 + slot, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (slot, None)), h_tile_out) + + # Cooperative output writeback (only when use_smem_v staged outputs to SMEM) + if cutlass.const_expr(use_smem_v): + cute.arch.barrier() + v_tile_base = i_v * tile_v + for t_idx in cutlass.range_constexpr(T): + if tidx < tile_v: + v_global = v_tile_base + tidx + if v_global < V: + o[(i_n, t_idx, i_hv, v_global)] = sOutput[(t_idx, tidx)] + + +# ============================================================================ +# Launcher +# ============================================================================ +@cute.jit +def run_la_verify_kernel_mtp( + h0_source: cute.Tensor, + intermediate_states: cute.Tensor, + decay_scales: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + cu_seqlens: cute.Tensor, + scale: cutlass.Constexpr[float], + B: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], + vec_size: cutlass.Constexpr[int], + ilp_rows: cutlass.Constexpr[int], + use_smem_v: cutlass.Constexpr[bool], + use_packed_fma: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + cache_intermediate_states: cutlass.Constexpr[bool], + is_varlen: cutlass.Constexpr[bool], + stream: cuda.CUstream, +): + _, v_dim, _ = ( + h0_source.layout.shape[0], + h0_source.layout.shape[1], + h0_source.layout.shape[2], + ) + + num_v_tiles = cute.ceil_div(v_dim, tile_v) + grid_size = B * HV * num_v_tiles + + smem_bytes = ( + 4 * T * tile_v # 
def linear_attention_decode_mtp(
    q: torch.Tensor,  # [B, T, H, K] bf16
    k: torch.Tensor,  # [B, T, H, K] bf16
    v: torch.Tensor,  # [B, T, HV, V] bf16
    s: torch.Tensor,  # [pool_size, HV, V, K] fp32
    intermediate_states: torch.Tensor,  # [pool_size*T*HV, V, K] fp32 (or dummy)
    out: torch.Tensor,  # [B, T, HV, V] bf16
    decay_scales: torch.Tensor,  # [H] fp32
    s_offsets: torch.Tensor,  # [B] int32 (-1 to skip)
    cu_seqlens: torch.Tensor,  # [B+1] int32 (reserved; see note below)
    softmax_scale: float,
    T: int,
    cache_intermediate_states: bool,
    disable_state_update: bool,
    is_varlen: bool,
) -> None:
    """
    Lightning Attention multi-token decode (T > 1).

    Writes to ``out``; updates ``s`` in place unless ``disable_state_update`` is True;
    writes ``intermediate_states`` when ``cache_intermediate_states`` is True.

    NOTE: For any batch ``i`` where ``s_offsets[i] < 0`` the kernel skips that batch
    entirely — ``out[i]`` is LEFT UNCHANGED, and neither ``s`` nor
    ``intermediate_states`` is written for that slot. Callers must initialize ``out``
    to a known value (e.g. ``torch.zeros``) before the call if any downstream code
    may read those slots.

    NOTE: ``is_varlen`` and ``cu_seqlens`` are reserved in the signature to keep the
    public API stable, but the early-stop branch is NOT implemented yet — same as
    upstream flashinfer GDN MTP, which also exposes the flag without consuming it.
    Callers should pass ``is_varlen=False`` and any int32 tensor for ``cu_seqlens``.
    The kernel descriptor is built with ``assumed_align=16``, so even the dummy
    ``cu_seqlens`` must be 16-byte aligned; pass a fresh ``torch.empty(N, dtype=int32)``
    (CUDA allocator guarantees alignment) — do NOT pass a slice that may misalign.

    Raises:
        AssertionError: If ``q.shape[1] != T``; if ``K != 128`` (the kernel derives
            its lane layout from ``K // vec_size`` and reduces across a full warp,
            which only lines up for K=128); if ``HV`` is not a positive multiple of
            ``H`` (the kernel maps value heads to q/k heads with
            ``i_hv // (HV // H)``); or if ``V`` is not a multiple of the selected
            ``ilp_rows``.
    """
    B, T_q, H, K = q.shape
    assert T_q == T, f"q.shape[1]={T_q} doesn't match T={T}"
    # Same guard as linear_attention_state_update_kvbuffer: both kernels share the
    # threads_per_group = K // vec_size lane layout.
    assert K == 128, f"K={K} != 128: kernel hardcodes K=128 (threads_per_group, lane K-coverage)"
    _, _, HV, V = v.shape
    # HV // H is used as a divisor inside the kernel; HV < H would make it zero and
    # a non-multiple would index past the last q/k head.
    assert H > 0 and HV >= H and HV % H == 0, (
        f"HV={HV} must be a positive multiple of H={H}: kernel maps v-heads to q/k heads via i_hv // (HV // H)"
    )
    pool_size = s.shape[0]

    tile_v, vec_size, ilp_rows, use_smem_v = get_mtp_config(B, T, HV, V, disable_state_update)
    assert V % ilp_rows == 0, f"V={V} % ilp_rows={ilp_rows} ≠ 0: partial row-blocks would be silently skipped"
    major, _ = get_device_sm_version(q.device)
    # Packed F32x2 FMA is selected for SM major version >= 10.
    use_packed_fma = major >= 10

    cache_key = (
        B,
        T,
        H,
        HV,
        K,
        V,
        pool_size,
        softmax_scale,
        disable_state_update,
        cache_intermediate_states,
        is_varlen,
        tile_v,
        vec_size,
        ilp_rows,
        use_smem_v,
        use_packed_fma,
    )
    # functools.cache hands back the same mutable dict for the same key; the
    # compiled handle is stored in it on first use.
    cache = _get_compiled_la_mtp_kernel(*cache_key)

    # Flatten (pool_size, HV) so the kernel can index states with one flat index.
    h0_view = s.view(pool_size * HV, V, K)

    # One stream handle serves both the (optional) compile and the launch.
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    if "compiled" not in cache:
        cache["compiled"] = cute.compile(
            run_la_verify_kernel_mtp,
            from_dlpack(h0_view, assumed_align=16),
            from_dlpack(intermediate_states, assumed_align=16),
            from_dlpack(decay_scales, assumed_align=16),
            from_dlpack(q, assumed_align=16),
            from_dlpack(k, assumed_align=16),
            from_dlpack(v, assumed_align=16),
            from_dlpack(out, assumed_align=16),
            from_dlpack(s_offsets, assumed_align=16),
            from_dlpack(cu_seqlens, assumed_align=16),
            scale=softmax_scale,
            B=B,
            T=T,
            H=H,
            HV=HV,
            K=K,
            V=V,
            tile_v=tile_v,
            vec_size=vec_size,
            ilp_rows=ilp_rows,
            use_smem_v=use_smem_v,
            use_packed_fma=use_packed_fma,
            disable_state_update=disable_state_update,
            cache_intermediate_states=cache_intermediate_states,
            is_varlen=is_varlen,
            stream=stream,
            options="--enable-tvm-ffi",
        )

    compiled = cache["compiled"]
    compiled(
        h0_view,
        intermediate_states,
        decay_scales,
        q,
        k,
        v,
        out,
        s_offsets,
        cu_seqlens,
        stream,
    )
index 00000000..351f58c3 --- /dev/null +++ b/cula/lightning/la_state_update_kvbuffer.py @@ -0,0 +1,313 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Lightning Attention KVBuffer state-update kernel (paper Eq. 8 for LA). + +After a parallel-verify cycle, advances the pooled state from h_init to +h_state_L for a per-batch accepted prefix length L = accepted_len[b]: + + h_running = h_init + for i in 0..L-1: + h_running = exp(-decay_scales[h]) * h_running + k_i ⊗ v_i + s[cache_idx] = h_running + +The loop body is bit-identical to the baseline T-loop body, so at L == T the +result is bit-equivalent to running the baseline with disable_state_update=False. + +Reads s, k, v; writes s. Never touches q or o. + +Grid: (B * HV * num_v_tiles, 1, 1), 128 threads/block — identical layout to the +baseline verify kernel, so the state write aligns with the verify kernel's h0 read. 
@cute.kernel
def la_state_update_kernel(
    h0_source: cute.Tensor,  # [pool_size * HV, V, K] fp32 (read + written in place)
    decay_scales: cute.Tensor,  # [H] fp32
    k: cute.Tensor,  # [B, T, H, K] bf16
    v: cute.Tensor,  # [B, T, HV, V] bf16
    h0_indices: cute.Tensor,  # [B] int32
    accepted_len: cute.Tensor,  # [B] int32
    k_buf: cute.Tensor,  # [pool_size, T, H, K] bf16 (READ when read_from_buf)
    v_buf: cute.Tensor,  # [pool_size, T, HV, V] bf16 (READ when read_from_buf)
    vec_size: cutlass.Constexpr[int],
    num_v_tiles: cutlass.Constexpr[int],
    tile_v: cutlass.Constexpr[int],
    B: cutlass.Constexpr[int],
    T: cutlass.Constexpr[int],
    H: cutlass.Constexpr[int],
    HV: cutlass.Constexpr[int],
    K: cutlass.Constexpr[int],
    V: cutlass.Constexpr[int],
    ilp_rows: cutlass.Constexpr[int],
    use_packed_fma: cutlass.Constexpr[bool],
    read_from_buf: cutlass.Constexpr[bool],
):
    """Advance one pooled state slice by ``accepted_len[i_n]`` recurrence steps, in place.

    The block index is decomposed into (batch ``i_n``, value head ``i_hv``,
    V-tile ``i_v``). For each group of ``ilp_rows`` state rows owned by the
    calling thread, the rows are loaded from ``h0_source``, updated with
    ``h = decay * h + k_i * v_i`` for ``i`` in ``0..L-1``, and stored back to the
    same location. Nothing is read or written when ``h0_indices[i_n] < 0`` or
    ``accepted_len[i_n] <= 0``.

    ``read_from_buf`` is a compile-time switch: when True, k/v are read from
    ``k_buf`` / ``v_buf`` indexed by the pool slot (``cache_idx``); otherwise from
    ``k`` / ``v`` indexed by the batch position (``i_n``).
    """
    tidx, _, _ = cute.arch.thread_idx()
    lane_id = tidx % 32
    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)

    threads_per_group: cutlass.Constexpr[int] = K // vec_size  # 32
    groups_per_warp: cutlass.Constexpr[int] = 32 // threads_per_group  # 1
    num_groups: cutlass.Constexpr[int] = 4 * groups_per_warp  # 4

    # Position of this thread inside its row-group, and the group's index in the block.
    lane_in_group = lane_id % threads_per_group
    group_in_warp = lane_id // threads_per_group
    group_idx = warp_idx * groups_per_warp + group_in_warp

    # Block -> (batch i_n, value head i_hv, V-tile i_v); i_h is the q/k head that
    # value head i_hv maps to.
    block_idx, _, _ = cute.arch.block_idx()
    i_v = block_idx % num_v_tiles
    tmp = block_idx // num_v_tiles
    i_hv = tmp % HV
    i_n = tmp // HV
    i_h = i_hv // (HV // H)

    cache_idx = h0_indices[i_n]
    L = accepted_len[i_n]

    r_k = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32)
    r_k_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16)
    # Declared with 8 rows; only the first ilp_rows slots are touched below.
    r_h = cute.make_rmem_tensor(cute.make_layout((8, vec_size), stride=(vec_size, 1)), cutlass.Float32)

    # Skip entirely for a negative pool index or an empty accepted prefix.
    if cache_idx >= 0 and L > 0:
        # Decay factor depends only on the head, so it is computed once per block.
        r_decay = cute.exp(-cutlass.Float32(decay_scales[i_h]), fastmath=USE_FAST_MATH)
        rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups
        flat_state_idx = cache_idx * HV + i_hv

        # Process `ilp_rows` V-rows per iteration. ilp_rows is a compile-time
        # constant, so range_constexpr fully unrolls the slot loops below — the
        # generated SASS is identical to hand-unrolling each ilp_rows value, but
        # one loop covers ilp_rows in {2, 4, 8}.
        num_chunks: cutlass.Constexpr[int] = rows_per_group // ilp_rows
        for chunk in cutlass.range_constexpr(num_chunks):
            v_idx_0 = i_v * tile_v + group_idx * rows_per_group + chunk * ilp_rows
            # Only process a chunk whose last row is still inside V.
            if v_idx_0 + (ilp_rows - 1) < V:
                # Load the ilp_rows h-state rows this thread owns into registers.
                for slot in cutlass.range_constexpr(ilp_rows):
                    h_tile = cute.local_tile(h0_source, (1, 1, vec_size), (flat_state_idx, v_idx_0 + slot, lane_in_group))
                    cute.autovec_copy(h_tile, cute.slice_(r_h, (slot, None)))

                # Recurrence: h = decay * h + k_i (x) v_i, for i in 0..L-1.
                # L is a runtime value, so this loop uses cutlass.range rather than
                # range_constexpr.
                for i in cutlass.range(0, L, unroll=0):
                    if cutlass.const_expr(read_from_buf):
                        k_tile = cute.local_tile(k_buf, (1, 1, 1, vec_size), (cache_idx, i, i_h, lane_in_group))
                    else:
                        k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i, i_h, lane_in_group))
                    cute.autovec_copy(k_tile, r_k_bf16)
                    # Widen this thread's k slice from bf16 to fp32.
                    for j in cutlass.range_constexpr(vec_size):
                        r_k[j] = cutlass.Float32(r_k_bf16[j])
                    for slot in cutlass.range_constexpr(ilp_rows):
                        if cutlass.const_expr(read_from_buf):
                            r_v_s = cutlass.Float32(v_buf[cache_idx, i, i_hv, v_idx_0 + slot])
                        else:
                            r_v_s = cutlass.Float32(v[i_n, i, i_hv, v_idx_0 + slot])
                        # Update the row two K-elements at a time.
                        for j in cutlass.range_constexpr(0, vec_size, 2):
                            r_h[slot, j], r_h[slot, j + 1] = la_update_pair(
                                r_h[slot, j], r_h[slot, j + 1], r_k[j], r_k[j + 1], r_v_s, r_decay, use_packed_fma
                            )

                # Write the advanced state back in place.
                for slot in cutlass.range_constexpr(ilp_rows):
                    h_out = cute.local_tile(h0_source, (1, 1, vec_size), (flat_state_idx, v_idx_0 + slot, lane_in_group))
                    cute.autovec_copy(cute.slice_(r_h, (slot, None)), h_out)
use_packed_fma, + read_from_buf, + ).launch( + grid=(grid_size, 1, 1), + block=[NUM_THREADS_MTP, 1, 1], + stream=stream, + ) + + +@functools.cache +def _get_compiled_state_update_kernel( + B: int, + T: int, + H: int, + HV: int, + K: int, + V: int, + pool_size: int, + tile_v: int, + vec_size: int, + ilp_rows: int, + use_packed_fma: bool, + read_from_buf: bool, +): + return {} + + +def linear_attention_state_update_kvbuffer( + k: torch.Tensor, # [B, T, H, K] bf16 — read when k_buf is None + v: torch.Tensor, # [B, T, HV, V] bf16 — read when v_buf is None + s: torch.Tensor, # [pool_size, HV, V, K] fp32, WRITTEN IN PLACE + decay_scales: torch.Tensor, # [H] fp32 + h0_indices: torch.Tensor, # [B] int32, -1 to skip + accepted_len: torch.Tensor, # [B] int32, in [0, T] + T: int, + k_buf: torch.Tensor | None = None, # [pool_size, T, H, K] bf16 + v_buf: torch.Tensor | None = None, # [pool_size, T, HV, V] bf16 +) -> None: + """ + Advance pooled state from h_init to h_state_L per batch (KVBuffer Eq. 8). + + When k_buf and v_buf are provided, reads k,v from pool-indexed buffers + instead of batch-indexed input tensors. 
+ """ + B, T_k, H, K = k.shape + assert T_k == T, f"k.shape[1]={T_k} doesn't match T={T}" + assert K == 128, f"K={K} != 128: kernel hardcodes K=128 (threads_per_group, lane K-coverage)" + _, _, HV, V = v.shape + pool_size = s.shape[0] + + read_from_buf = k_buf is not None and v_buf is not None + if (k_buf is None) != (v_buf is None): + raise ValueError("k_buf and v_buf must both be None or both be provided") + + tile_v, vec_size, ilp_rows, _use_smem_v = get_mtp_config(B, T, HV, V, False) + assert V % ilp_rows == 0, f"V={V} % ilp_rows={ilp_rows} ≠ 0: partial row-blocks would be silently skipped" + major, _ = get_device_sm_version(k.device) + use_packed_fma = major >= 10 + + cache_key = ( + B, + T, + H, + HV, + K, + V, + pool_size, + tile_v, + vec_size, + ilp_rows, + use_packed_fma, + read_from_buf, + ) + cache = _get_compiled_state_update_kernel(*cache_key) + + h0_view = s.view(pool_size * HV, V, K) + + if not read_from_buf: + k_buf_t = torch.empty(1, 1, 1, 1, device=k.device, dtype=torch.bfloat16) + v_buf_t = torch.empty(1, 1, 1, 1, device=k.device, dtype=torch.bfloat16) + else: + k_buf_t = k_buf + v_buf_t = v_buf + + if "compiled" not in cache: + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + compiled = cute.compile( + run_la_state_update_kernel, + from_dlpack(h0_view, assumed_align=16), + from_dlpack(decay_scales, assumed_align=16), + from_dlpack(k, assumed_align=16), + from_dlpack(v, assumed_align=16), + from_dlpack(h0_indices, assumed_align=16), + from_dlpack(accepted_len, assumed_align=16), + from_dlpack(k_buf_t, assumed_align=16), + from_dlpack(v_buf_t, assumed_align=16), + B=B, + T=T, + H=H, + HV=HV, + K=K, + V=V, + tile_v=tile_v, + vec_size=vec_size, + ilp_rows=ilp_rows, + use_packed_fma=use_packed_fma, + read_from_buf=read_from_buf, + stream=stream, + options="--enable-tvm-ffi", + ) + cache["compiled"] = compiled + + compiled = cache["compiled"] + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + compiled( + h0_view, + 
decay_scales, + k, + v, + h0_indices, + accepted_len, + k_buf_t, + v_buf_t, + stream, + ) diff --git a/cula/lightning/la_verify_kvbuffer.py b/cula/lightning/la_verify_kvbuffer.py new file mode 100644 index 00000000..2b9e568d --- /dev/null +++ b/cula/lightning/la_verify_kvbuffer.py @@ -0,0 +1,942 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Lightning Attention KVBuffer verify kernel (paper Eq. 7 for LA). + +Closed-form parallel verification — computes each draft step's output directly +from (h0, k, v) without materializing the intermediate states: + + o_t = alpha^{t+1} * (h0 @ q_t * scale) <- "term1" (HQ) + + sum_{i=0..t} alpha^{t-i} * (q_t . k_i) * scale * v_i <- "term2" (QK·V) + +The two dot-product GEMMs run on tensor cores via inline-PTX mma.sync.m16n8k8 +(TF32). Operands are staged in fp32 SMEM (manual fragment addressing — no +LdMatrix/StMatrix). Everything downstream of the GEMMs is plain scalar math. + +PARALLELISM + Grid: (B * HV * num_v_tiles, 1, 1) — one block per (sequence, v-head, V-tile) + Block: 128 threads = 4 warps. Each warp owns `rows_per_group` output V-rows. + +PIPELINE (per block) + Stage 0 cooperative load q*scale, k -> SMEM (sQ, sK) + Stage 1 GEMM2: QK[t,i] = q_t . 
k_i (warp 0 only) -> s_qk_scaled + Stage 2 per V-row-block: load h0 -> SMEM, GEMM1: HQ = h0 @ q_t, + then scalar combine term1+term2 -> o + +MMA m16n8k8 FRAGMENT MAP (lane = gid*4 + tig, gid=lane//4 in 0..7, tig=lane%4 in 0..3) + A[16,8] row-major : a0=A[gid,tig] a1=A[gid+8,tig] a2=A[gid,tig+4] a3=A[gid+8,tig+4] + B[8,8] col-major : b0=B[tig,gid] b1=B[tig+4,gid] + C[16,8] : c0=C[gid,2tig] c1=C[gid,2tig+1] c2=C[gid+8,2tig] c3=C[gid+8,2tig+1] + We only have 8 valid rows (BT=8), so A rows 8..15 are fed as zeros and the + corresponding outputs c2,c3 / e2,e3 are unused padding. +""" + +import functools + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass._mlir.dialects import arith as _arith +from cutlass._mlir.dialects import llvm as _llvm +from cutlass.cute.runtime import from_dlpack +from cutlass.cutlass_dsl import T as _T +from cutlass.cutlass_dsl import dsl_user_op + +from cula.lightning.la_decode_mtp import ( + NUM_THREADS_MTP, + get_mtp_config, + hq_dot_pair, +) +from cula.utils import USE_FAST_MATH, get_device_sm_version + +# Dispatch threshold between the two verify implementations. +# The MMA (tensor-core) kernel wins at T>=4 (matches at T=4, +45% at T=8 vs the +# shuffle kernel), but the shuffle kernel wins at small T (T<=2) where the MMA +# GEMMs are under-utilised and its larger SMEM footprint caps occupancy. +# See docs/la_verify_kvbuffer_dev_history.md §6 for the full benchmark. 
@dsl_user_op
def _mma_m16n8k8_tf32(a0, a1, a2, a3, b0, b1, c0, c1, c2, c3, *, loc=None, ip=None):
    """Emit one mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 via inline PTX.

    Computes D = A @ B + C for a single warp-level MMA, where each argument is
    this lane's fragment element (see the fragment map in the module docstring).

    Args:
        a0, a1, a2, a3: this lane's four A-fragment values. They are bitcast to
            i32 and handed to the instruction in "r" (32-bit integer) registers.
        b0, b1: this lane's two B-fragment values, bitcast to i32 the same way.
        c0, c1, c2, c3: this lane's four fp32 accumulator inputs ("f" registers).
        loc, ip: MLIR location / insertion point, forwarded to every op built here.

    Returns:
        Tuple ``(d0, d1, d2, d3)`` of ``cutlass.Float32`` — the four fp32
        accumulator outputs extracted from the instruction's result struct.
    """
    f32 = _T.f32()
    i32 = _T.i32()

    # Operand helper for A/B: unwrap a DSL value to its IR value when it has one,
    # then reinterpret the 32 bits as i32 (a bitcast, not a numeric conversion).
    def _bits(v):
        vv = v.ir_value(loc=loc, ip=ip) if hasattr(v, "ir_value") else v
        return _arith.bitcast(i32, vv, loc=loc, ip=ip)

    # Operand helper for C: unwrap to the IR value only; the accumulator stays fp32.
    def _f(v):
        return v.ir_value(loc=loc, ip=ip) if hasattr(v, "ir_value") else v

    # The instruction yields four fp32 values, modelled as a literal LLVM struct.
    res_ty = _llvm.StructType.get_literal([f32, f32, f32, f32])
    res = _llvm.inline_asm(
        res_ty,
        # Operand order must match the placeholders below: $4..$7 = A, $8..$9 = B,
        # $10..$13 = C ($0..$3 are the four outputs).
        [_bits(a0), _bits(a1), _bits(a2), _bits(a3), _bits(b0), _bits(b1), _f(c0), _f(c1), _f(c2), _f(c3)],
        "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {$0,$1,$2,$3}, {$4,$5,$6,$7}, {$8,$9}, {$10,$11,$12,$13};",
        # Constraints: 4 fp32 outputs, 6 integer-register inputs (A, B), 4 fp32 inputs (C).
        "=f,=f,=f,=f,r,r,r,r,r,r,f,f,f,f",
        # NOTE(review): flagged as side-effecting — presumably so the compiler does
        # not eliminate or reorder the warp-collective instruction; confirm.
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=_llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
    # Unpack the result struct into DSL Float32 values, one per accumulator slot.
    d0 = cutlass.Float32(_llvm.extractvalue(f32, res, [0], loc=loc, ip=ip))
    d1 = cutlass.Float32(_llvm.extractvalue(f32, res, [1], loc=loc, ip=ip))
    d2 = cutlass.Float32(_llvm.extractvalue(f32, res, [2], loc=loc, ip=ip))
    d3 = cutlass.Float32(_llvm.extractvalue(f32, res, [3], loc=loc, ip=ip))
    return d0, d1, d2, d3
cute.Tensor, # [pool_size, T, HV, V] bf16 (WRITTEN when write_kv) + vec_size: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], + scale: cutlass.Constexpr[float], + B: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + ilp_rows: cutlass.Constexpr[int], + write_kv: cutlass.Constexpr[bool], +): + tidx, _, _ = cute.arch.thread_idx() + lane_id = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # MMA lane decomposition (see fragment map in module docstring). + gid = lane_id // 4 # 0..7: row index within the MMA tile + tig = lane_id % 4 # 0..3: k-pair within the current 8-wide K-slab + + # 4 warps/block; each warp owns a disjoint set of output V-rows. All 32 lanes + # of a warp cooperate over the full K dimension (K=128, vec_size=4). + NUM_WARPS: cutlass.Constexpr[int] = 4 + + # Block -> (sequence n, v-head hv, V-tile i_v); i_h maps the v-head to its q/k head. + block_idx, _, _ = cute.arch.block_idx() + i_v = block_idx % num_v_tiles + tmp = block_idx // num_v_tiles + i_hv = tmp % HV + i_n = tmp // HV + i_h = i_hv // (HV // H) + + cache_idx = h0_indices[i_n] + + # ---- Per-lane registers ---- + r_decay_pow = cute.make_rmem_tensor(cute.make_layout((T + 1,), stride=(1,)), cutlass.Float32) + r_q_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + r_k_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + + # ---- SMEM (all fp32; MMA bitcasts fp32->TF32, no separate conversion) ---- + # KP = K+4 pads the row stride so 132%32=4: the gid*4+tig access pattern then + # hits 32 distinct banks, giving conflict-free SMEM reads in both GEMMs. + KP: cutlass.Constexpr[int] = K + 4 + smem = cutlass.utils.SmemAllocator() + # GEMM operands. sQ holds q*scale, doubles as GEMM2-A and GEMM1-B. 
+ sQ = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, KP), stride=(KP, 1)), 16) + sK = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, KP), stride=(KP, 1)), 16) + # h0, one [BT, K] region per warp (each warp does GEMM1 for its own V-rows). + sH0 = smem.allocate_tensor(cutlass.Float32, cute.make_layout((NUM_WARPS, BT, KP), stride=(BT * KP, KP, 1)), 16) + # Decay-masked QK coefficients [T, T], produced by GEMM2, consumed by every warp. + s_qk_scaled = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, T), stride=(T, 1)), 16) + # v is lane-invariant within a warp; stage it once in SMEM and broadcast-read. + sVbuf = smem.allocate_tensor(cutlass.Float32, cute.make_layout((NUM_WARPS, T, BT), stride=(T * BT, BT, 1)), 16) + + if cache_idx >= 0: + alpha = cute.exp(-cutlass.Float32(decay_scales[i_h]), fastmath=USE_FAST_MATH) + + r_decay_pow[0] = cutlass.Float32(1.0) + for t in cutlass.range_constexpr(1, T + 1): + r_decay_pow[t] = r_decay_pow[t - 1] * alpha + + rows_per_group: cutlass.Constexpr[int] = tile_v // NUM_WARPS + flat_state_idx = cache_idx * HV + i_hv + + # ================================================================ + # Stage 0: cooperative load q*scale, k -> SMEM (sQ, sK), fp32. + # Warp w loads tokens {w, w+4, ...}; within a token, lane_id covers the + # K dimension (vec_size contiguous elements each). Rows T..BT-1 are the + # MMA M-padding and are zeroed. 
+ # ================================================================ + tokens_per_warp: cutlass.Constexpr[int] = (BT + NUM_WARPS - 1) // NUM_WARPS + for tt in cutlass.range_constexpr(tokens_per_warp): + t_tok = tt * NUM_WARPS + warp_idx + if t_tok < T: + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, t_tok, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, t_tok, i_h, lane_id)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) + for c in cutlass.range_constexpr(vec_size): + col = lane_id * vec_size + c + sQ[(t_tok, col)] = cutlass.Float32(r_q_bf16[c]) * scale + sK[(t_tok, col)] = cutlass.Float32(r_k_bf16[c]) + # Persist k to the pool buffer while it is already in registers. + if cutlass.const_expr(write_kv): + if i_v == 0 and i_hv % (HV // H) == 0: + kb_tile = cute.local_tile(k_buf, (1, 1, 1, vec_size), (cache_idx, t_tok, i_h, lane_id)) + cute.autovec_copy(r_k_bf16, kb_tile) + if t_tok >= T and t_tok < BT: + for c in cutlass.range_constexpr(vec_size): + col = lane_id * vec_size + c + sQ[(t_tok, col)] = cutlass.Float32(0.0) + sK[(t_tok, col)] = cutlass.Float32(0.0) + + cute.arch.barrier() + + # ================================================================ + # Stage 1: GEMM2 — QK[t,i] = q_t . k_i, accumulated over the full K. + # A = Q[8,K] (rows = tokens), B = K[8,K] read col-major as K^T. Warp 0 + # alone has enough lanes (M=N=T<=8), so the other warps skip this. 
+ # ================================================================ + if warp_idx == 0: + c0 = cutlass.Float32(0.0) + c1 = cutlass.Float32(0.0) + c2 = cutlass.Float32(0.0) # c2,c3 = padding rows 8..15, unused + c3 = cutlass.Float32(0.0) + for ks in cutlass.range_constexpr(K // 8): + kb = ks * 8 + a0 = sQ[(gid, kb + tig)] + a1 = cutlass.Float32(0.0) + a2 = sQ[(gid, kb + tig + 4)] + a3 = cutlass.Float32(0.0) + b0 = sK[(gid, kb + tig)] + b1 = sK[(gid, kb + tig + 4)] + c0, c1, c2, c3 = _mma_m16n8k8_tf32(a0, a1, a2, a3, b0, b1, c0, c1, c2, c3) + + # c0,c1 hold QK[gid, 2tig], QK[gid, 2tig+1]. Keep the causal lower + # triangle, pre-multiply by the decay alpha^{t-i}, store coefficients. + for fi in cutlass.range_constexpr(2): + row = gid + col = 2 * tig + fi + cv = c1 if cutlass.const_expr(fi == 1) else c0 + if row < T and col < T: + if col <= row: + s_qk_scaled[(row, col)] = r_decay_pow[row - col] * cv + else: + s_qk_scaled[(row, col)] = cutlass.Float32(0.0) + + cute.arch.barrier() + + # ================================================================ + # Stage 2: for each block of `ilp_rows` V-rows owned by this warp, + # load h0 -> SMEM, run GEMM1 (HQ = h0 @ q_t), then combine the two terms. + # ================================================================ + num_row_blocks: cutlass.Constexpr[int] = rows_per_group // ilp_rows + for row_block in cutlass.range_constexpr(num_row_blocks): + v_base = i_v * tile_v + warp_idx * rows_per_group + row_block * ilp_rows + if v_base + (ilp_rows - 1) < V: + # (a) Coalesced h0 load: lane_id indexes vec_size contiguous K + # elements, so the 32 lanes read one full contiguous row per step + # (no over-fetch). Each warp fills its own sH0 region. 
+ sH0_w = sH0[(warp_idx, None, None)] # [BT, KP] + gH0 = h0_source[(flat_state_idx, None, None)] # [V, K] + for row in cutlass.range_constexpr(ilp_rows): + h_g = cute.local_tile(gH0, (1, vec_size), (v_base + row, lane_id)) + h_s = cute.local_tile(sH0_w, (1, vec_size), (row, lane_id)) + cute.autovec_copy(h_g, h_s) + # Zero the M-padding rows (ilp_rows..BT-1). GEMM1 reads all BT rows; + # their outputs are unused, but leaving stale/NaN SMEM as MMA inputs + # is unclean — explicitly zero so the fragment is well-defined. + for row in cutlass.range_constexpr(ilp_rows, BT): + for c in cutlass.range_constexpr(vec_size): + sH0_w[(row, lane_id * vec_size + c)] = cutlass.Float32(0.0) + cute.arch.sync_warp() # make sH0 writes visible to this warp's GEMM1 + + # (b) GEMM1: HQ[row, t] = h0_row . q_t, over the full K. + # A = sH0 (this warp's V-rows), B = sQ read col-major as Q^T. + e0 = cutlass.Float32(0.0) + e1 = cutlass.Float32(0.0) + e2 = cutlass.Float32(0.0) # e2,e3 = padding rows 8..15, unused + e3 = cutlass.Float32(0.0) + for ks in cutlass.range_constexpr(K // 8): + kb = ks * 8 + a0 = sH0[(warp_idx, gid, kb + tig)] + a1 = cutlass.Float32(0.0) + a2 = sH0[(warp_idx, gid, kb + tig + 4)] + a3 = cutlass.Float32(0.0) + b0 = sQ[(gid, kb + tig)] + b1 = sQ[(gid, kb + tig + 4)] + e0, e1, e2, e3 = _mma_m16n8k8_tf32(a0, a1, a2, a3, b0, b1, e0, e1, e2, e3) + # e0,e1 now hold HQ[gid, 2tig], HQ[gid, 2tig+1] (gid = V-row index). + + # (c) Stage v in SMEM (lane-invariant within the warp) and persist it. + if lane_id < ilp_rows: + for t in cutlass.range_constexpr(T): + vv = v[i_n, t, i_hv, v_base + lane_id] + sVbuf[(warp_idx, t, lane_id)] = cutlass.Float32(vv) + if cutlass.const_expr(write_kv): + v_buf[(cache_idx, t, i_hv, v_base + lane_id)] = vv + + # (d) Combine: o[t, row] = alpha^{t+1}*HQ[row,t] + sum_i qk[t,i]*v[i,row]. + # The (t, row) output grid has T*ilp_rows entries. Distribute them + # across the 32 lanes in a grid-stride fashion: lane L handles outputs + # L, L+32, L+64, ... 
so each lane emits ceil(T*ilp_rows/32) of them. + # This keeps every lane doing useful work for ANY T (T=4 -> 1 each, + # T=8 -> 2 each, T=2 -> half the lanes), with no redundant compute and + # no SMEM reshuffle — HQ is fetched straight from its owner lane. + num_out: cutlass.Constexpr[int] = T * ilp_rows + outs_per_lane: cutlass.Constexpr[int] = (num_out + 31) // 32 + for oj in cutlass.range_constexpr(outs_per_lane): + out_idx = lane_id + oj * 32 + my_t = out_idx // ilp_rows + my_slot = out_idx % ilp_rows + # shuffle_sync must execute on ALL lanes (warp-collective), so it + # stays outside the my_t None: + """ + Closed-form parallel verify (KVBuffer Eq. 7). Writes out; does not touch s. + + When k_buf and v_buf are provided, also writes k,v to pool-indexed buffers + so the caller can free the original k,v tensors after this call returns. + + Dispatches between two equivalent implementations by draft depth T: the + tensor-core MMA kernel below for T >= MMA_MIN_T, and the warp-shuffle kernel + for smaller T (where MMA's GEMMs are under-utilised). Both share the same + interface, grid, and KVBuffer write semantics. 
+ """ + if T < MMA_MIN_T: + return linear_attention_verify_kvbuffer_shuffle( + q, + k, + v, + s, + out, + decay_scales, + h0_indices, + softmax_scale, + T, + k_buf=k_buf, + v_buf=v_buf, + ) + + B, T_q, H, K = q.shape + assert T_q == T, f"q.shape[1]={T_q} doesn't match T={T}" + assert K == 128, f"K={K} != 128: kernel hardcodes K=128 (threads_per_group, KP=K+4, lane K-coverage)" + _, _, HV, V = v.shape + pool_size = s.shape[0] + + write_kv = k_buf is not None and v_buf is not None + if (k_buf is None) != (v_buf is None): + raise ValueError("k_buf and v_buf must both be None or both be provided") + + tile_v, vec_size, ilp_rows, _use_smem_v = get_mtp_config(B, T, HV, V, True) + assert T <= 8, f"T={T} > 8: MMA kernel's BT=8 token staging only covers T ≤ 8" + # The MMA tile has M=8 valid rows, so process 8 V-rows per warp per block: + # this fills the fragment (vs ilp_rows=4 wasting half the MMA) and halves the + # number of row-blocks. Only applies when the V-rows-per-warp is a multiple of 8. + if ilp_rows < 8 and (tile_v // 4) % 8 == 0: + ilp_rows = 8 + # Re-check after the promotion above: a partial row-block (V not a multiple of + # the final ilp_rows) would be silently skipped by the kernel's bounds guard. 
+ assert V % ilp_rows == 0, f"V={V} % ilp_rows={ilp_rows} ≠ 0: partial row-blocks would be silently skipped" + + cache_key = ( + B, + T, + H, + HV, + K, + V, + pool_size, + softmax_scale, + tile_v, + vec_size, + ilp_rows, + write_kv, + ) + cache = _get_compiled_verify_kvbuffer_kernel(*cache_key) + + h0_view = s.view(pool_size * HV, V, K) + + if not write_kv: + k_buf_t = torch.empty(1, 1, 1, 1, device=q.device, dtype=torch.bfloat16) + v_buf_t = torch.empty(1, 1, 1, 1, device=q.device, dtype=torch.bfloat16) + else: + k_buf_t = k_buf + v_buf_t = v_buf + + if "compiled" not in cache: + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + compiled = cute.compile( + run_la_verify_kvbuffer_kernel, + from_dlpack(h0_view, assumed_align=16), + from_dlpack(decay_scales, assumed_align=16), + from_dlpack(q, assumed_align=16), + from_dlpack(k, assumed_align=16), + from_dlpack(v, assumed_align=16), + from_dlpack(out, assumed_align=16), + from_dlpack(h0_indices, assumed_align=16), + from_dlpack(k_buf_t, assumed_align=16), + from_dlpack(v_buf_t, assumed_align=16), + scale=softmax_scale, + B=B, + T=T, + H=H, + HV=HV, + K=K, + V=V, + tile_v=tile_v, + vec_size=vec_size, + ilp_rows=ilp_rows, + write_kv=write_kv, + stream=stream, + options="--enable-tvm-ffi", + ) + cache["compiled"] = compiled + + compiled = cache["compiled"] + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + compiled( + h0_view, + decay_scales, + q, + k, + v, + out, + h0_indices, + k_buf_t, + v_buf_t, + stream, + ) + + +# =========================================================================== +# Warp-shuffle verify kernel (baseline). Dispatched for small T (T < MMA_MIN_T) +# by linear_attention_verify_kvbuffer above. Uses butterfly shuffle reduce for +# the dot products instead of tensor-core MMA — h0 stays in registers (no SMEM +# fragment staging), giving higher occupancy that wins when T is small. 
# Warp-shuffle verify kernel.
#
# Each block handles one (sequence i_n, value-head i_hv, V-tile i_v) triple,
# decoded from the linear block index. For every draft step t it writes
#     o[i_n, t, i_hv, row] = alpha^{t+1} * (h0[row, :] . q_t * scale)
#                          + sum_{i<=t} alpha^{t-i} * (q_t . k_i) * scale * v_i[row]
# with alpha = exp(-decay_scales[i_h]). h0 is only read; when write_kv is set,
# k and v are additionally copied into the pool-indexed k_buf / v_buf.
#
# Constraints visible in the body:
#   * r_h and o_partial are sized for 8 rows, so ilp_rows must be <= 8.
#   * the butterfly reduce uses offsets 16..1, i.e. it sums over all 32 lanes;
#     this equals a per-group reduce only when threads_per_group == 32
#     (K // vec_size == 32) — assumes the host wrapper enforces this; confirm.
#   * all work is guarded by cache_idx >= 0 (sequences with a negative pool
#     index produce no writes).
@cute.kernel
def la_verify_kvbuffer_shuffle_kernel(
    h0_source: cute.Tensor,  # [pool_size * HV, V, K] fp32 (READ ONLY)
    decay_scales: cute.Tensor,  # [H] fp32
    q: cute.Tensor,  # [B, T, H, K] bf16
    k: cute.Tensor,  # [B, T, H, K] bf16
    v: cute.Tensor,  # [B, T, HV, V] bf16
    o: cute.Tensor,  # [B, T, HV, V] bf16 (WRITTEN)
    h0_indices: cute.Tensor,  # [B] int32
    k_buf: cute.Tensor,  # [pool_size, T, H, K] bf16 (WRITTEN when write_kv)
    v_buf: cute.Tensor,  # [pool_size, T, HV, V] bf16 (WRITTEN when write_kv)
    vec_size: cutlass.Constexpr[int],
    num_v_tiles: cutlass.Constexpr[int],
    tile_v: cutlass.Constexpr[int],
    scale: cutlass.Constexpr[float],
    B: cutlass.Constexpr[int],
    T: cutlass.Constexpr[int],
    H: cutlass.Constexpr[int],
    HV: cutlass.Constexpr[int],
    K: cutlass.Constexpr[int],
    V: cutlass.Constexpr[int],
    ilp_rows: cutlass.Constexpr[int],
    use_packed_fma: cutlass.Constexpr[bool],
    write_kv: cutlass.Constexpr[bool],
):
    tidx, _, _ = cute.arch.thread_idx()
    lane_id = tidx % 32
    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)

    # A "group" is the set of lanes that together cover the K dimension of one
    # row (vec_size elements per lane). The trailing values are for K=128, vec_size=4.
    threads_per_group: cutlass.Constexpr[int] = K // vec_size  # 32
    groups_per_warp: cutlass.Constexpr[int] = 32 // threads_per_group  # 1
    num_groups: cutlass.Constexpr[int] = 4 * groups_per_warp  # 4

    lane_in_group = lane_id % threads_per_group
    group_in_warp = lane_id // threads_per_group
    group_idx = warp_idx * groups_per_warp + group_in_warp

    # Block -> (sequence i_n, v-head i_hv, V-tile i_v); i_h is the q/k head that
    # the v-head i_hv shares (HV // H v-heads per q/k head).
    block_idx, _, _ = cute.arch.block_idx()
    i_v = block_idx % num_v_tiles
    tmp = block_idx // num_v_tiles
    i_hv = tmp % HV
    i_n = tmp // HV
    i_h = i_hv // (HV // H)

    # Pool slot holding this sequence's initial state; negative means "skip".
    cache_idx = h0_indices[i_n]

    # ---- Per-lane registers ----
    r_q_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16)
    r_k_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16)
    # Up to 8 h0 rows (this lane's vec_size-wide K slice of each) and 8 output accumulators.
    r_h = cute.make_rmem_tensor(cute.make_layout((8, vec_size), stride=(vec_size, 1)), cutlass.Float32)
    r_decay_pow = cute.make_rmem_tensor(cute.make_layout((T + 1,), stride=(1,)), cutlass.Float32)
    o_partial = cute.make_rmem_tensor(cute.make_layout((8,), stride=(1,)), cutlass.Float32)

    # ---- SMEM ----
    smem = cutlass.utils.SmemAllocator()
    # Decay-scaled q.k coefficients; only the lower triangle (i <= t) is written and read.
    s_qk_scaled = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, T), stride=(T, 1)), 16)
    # v staged to SMEM (block-shared over the whole v-tile). v has no K dim, so
    # keeping it in per-lane registers wasted 8*T regs/thread and capped occupancy;
    # SMEM costs only T*tile_v*4 bytes and is read warp-uniformly (broadcast).
    sVdata = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, tile_v), stride=(tile_v, 1)), 16)
    # q (scaled) and k staged to SMEM. They depend only on lane_in_group (NOT on
    # warp/group), so a single copy of 32 K-slices is shared by all 4 warps —
    # this also removes the redundant per-warp q/k loads. Lane-minor layout
    # (T, vec_size, 32) keeps the 32 lanes of a warp on consecutive banks
    # (conflict-free); cost is 2 * T*vec_size*32*4 bytes (~8KB at T=8).
    s_q = smem.allocate_tensor(
        cutlass.Float32,
        cute.make_layout((T, vec_size, threads_per_group), stride=(vec_size * threads_per_group, threads_per_group, 1)),
        16,
    )
    s_k = smem.allocate_tensor(
        cutlass.Float32,
        cute.make_layout((T, vec_size, threads_per_group), stride=(vec_size * threads_per_group, threads_per_group, 1)),
        16,
    )

    # cache_idx is the same for every thread of the block, so the barriers
    # inside this branch are reached by all threads or by none.
    if cache_idx >= 0:
        alpha = cute.exp(-cutlass.Float32(decay_scales[i_h]), fastmath=USE_FAST_MATH)

        # alpha^0 .. alpha^T (T+1 powers; term1 uses alpha^{t+1})
        r_decay_pow[0] = cutlass.Float32(1.0)
        for t in cutlass.range_constexpr(1, T + 1):
            r_decay_pow[t] = r_decay_pow[t - 1] * alpha

        rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups
        # Row of the flattened [pool_size * HV, V, K] state view for (cache_idx, i_hv).
        flat_state_idx = cache_idx * HV + i_hv

        # Stage all T q (scaled) and k (fp32) into SMEM. q/k are warp-independent,
        # so only warp 0 (its 32 lanes cover the full K dim) loads them once.
        # The k_buf write is fused here, replacing the old per-warp redundant store.
        if warp_idx == 0:
            for t in cutlass.range_constexpr(T):
                q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, t, i_h, lane_id))
                k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, t, i_h, lane_id))
                cute.autovec_copy(q_tile, r_q_bf16)
                cute.autovec_copy(k_tile, r_k_bf16)
                for j in cutlass.range_constexpr(vec_size):
                    # softmax scale is folded into q once, here.
                    s_q[(t, j, lane_id)] = cutlass.Float32(r_q_bf16[j]) * scale
                    s_k[(t, j, lane_id)] = cutlass.Float32(r_k_bf16[j])

                # Write k to buffer — gated: only one block per (b, h, t) writes
                # (the first V-tile of the first v-head that maps to this q/k head).
                if cutlass.const_expr(write_kv):
                    if i_v == 0 and i_hv % (HV // H) == 0:
                        kb_tile = cute.local_tile(k_buf, (1, 1, 1, vec_size), (cache_idx, t, i_h, lane_id))
                        cute.autovec_copy(r_k_bf16, kb_tile)

        # Cooperative v load: first tile_v threads each stage one v-row for all T
        # steps into SMEM. v_buf write (when enabled) is fused here — every
        # (cache_idx, t, hv, v_row) is written exactly once by its owning thread.
        v_tile_start = i_v * tile_v
        for t in cutlass.range_constexpr(T):
            if tidx < tile_v:
                v_global_idx = v_tile_start + tidx
                # Bounds guard for a last V-tile that runs past V.
                if v_global_idx < V:
                    vv = v[i_n, t, i_hv, v_global_idx]
                    sVdata[(t, tidx)] = cutlass.Float32(vv)
                    if cutlass.const_expr(write_kv):
                        v_buf[(cache_idx, t, i_hv, v_global_idx)] = vv

        cute.arch.barrier()  # q/k/v staged → visible to all warps

        # Phase 1: cooperative QK matrix — 4 warps split T*(T+1)/2 qk dot products.
        # Warp w handles rows where min(t, T-1-t) % 4 == w (head-tail pairing) so that
        # each warp's total row-length is balanced: heavy tail rows are paired with light
        # head rows, making per-warp work ≈ T*(T+1)/8 regardless of T.
        for t_assign in cutlass.range_constexpr(T):
            if min(t_assign, T - 1 - t_assign) % 4 == warp_idx:
                for i in cutlass.range_constexpr(t_assign + 1):
                    # Two partial sums (even / odd K elements) fed to hq_dot_pair.
                    qk_lo = cutlass.Float32(0.0)
                    qk_hi = cutlass.Float32(0.0)
                    for j in cutlass.range_constexpr(0, vec_size, 2):
                        qk_lo, qk_hi = hq_dot_pair(
                            s_q[t_assign, j, lane_in_group],
                            s_q[t_assign, j + 1, lane_in_group],
                            s_k[i, j, lane_in_group],
                            s_k[i, j + 1, lane_in_group],
                            qk_lo,
                            qk_hi,
                            use_packed_fma,
                        )
                    qk = qk_lo + qk_hi
                    # Butterfly all-reduce: after the 5 steps every lane holds the full sum.
                    for offset in [16, 8, 4, 2, 1]:
                        qk += cute.arch.shuffle_sync_bfly(qk, offset=offset, mask=-1, mask_and_clamp=31)
                    if lane_in_group == 0:
                        # Pre-multiply by alpha^{t-i} so Phase 2 needs a single FMA per term.
                        s_qk_scaled[(t_assign, i)] = r_decay_pow[t_assign - i] * qk

        cute.arch.barrier()  # s_qk_scaled written by Phase 1 → read by Phase 2

        # Phase 2: each group walks its rows_per_group V-rows in chunks of ilp_rows.
        num_row_blocks: cutlass.Constexpr[int] = rows_per_group // ilp_rows
        for row_block in cutlass.range_constexpr(num_row_blocks):
            v_base = i_v * tile_v + group_idx * rows_per_group + row_block * ilp_rows
            v_local = group_idx * rows_per_group + row_block * ilp_rows  # offset within sVdata's v-tile
            # Only full row-blocks are processed; a partial one (V not a multiple of
            # ilp_rows) is skipped entirely by this guard.
            if v_base + (ilp_rows - 1) < V:
                # Load h_init rows (persistent across the T loop).
                for slot in cutlass.range_constexpr(ilp_rows):
                    h_tile = cute.local_tile(h0_source, (1, 1, vec_size), (flat_state_idx, v_base + slot, lane_in_group))
                    cute.autovec_copy(h_tile, cute.slice_(r_h, (slot, None)))

                for t in cutlass.range_constexpr(T):
                    # term1: alpha^{t+1} * (h_init @ q_t) (per-slot warp reduce)
                    for slot in cutlass.range_constexpr(ilp_rows):
                        hq_lo = cutlass.Float32(0.0)
                        hq_hi = cutlass.Float32(0.0)
                        for j in cutlass.range_constexpr(0, vec_size, 2):
                            hq_lo, hq_hi = hq_dot_pair(
                                r_h[slot, j],
                                r_h[slot, j + 1],
                                s_q[t, j, lane_in_group],
                                s_q[t, j + 1, lane_in_group],
                                hq_lo,
                                hq_hi,
                                use_packed_fma,
                            )
                        hq = hq_lo + hq_hi
                        for offset in [16, 8, 4, 2, 1]:
                            hq += cute.arch.shuffle_sync_bfly(hq, offset=offset, mask=-1, mask_and_clamp=31)
                        o_partial[slot] = r_decay_pow[t + 1] * hq

                    # term2: read pre-computed decay-scaled qk + staged v from SMEM
                    for i in cutlass.range_constexpr(t + 1):
                        coeff = s_qk_scaled[(t, i)]
                        for slot in cutlass.range_constexpr(ilp_rows):
                            o_partial[slot] = o_partial[slot] + coeff * sVdata[(i, v_local + slot)]

                    # writeback (all lanes hold the reduced value; lane 0 writes)
                    if lane_in_group == 0:
                        for slot in cutlass.range_constexpr(ilp_rows):
                            o[(i_n, t, i_hv, v_base + slot)] = cutlass.BFloat16(o_partial[slot])
def linear_attention_verify_kvbuffer_shuffle(
    q: torch.Tensor,  # [B, T, H, K] bf16
    k: torch.Tensor,  # [B, T, H, K] bf16
    v: torch.Tensor,  # [B, T, HV, V] bf16
    s: torch.Tensor,  # [pool_size, HV, V, K] fp32, READ ONLY
    out: torch.Tensor,  # [B, T, HV, V] bf16, WRITTEN
    decay_scales: torch.Tensor,  # [H] fp32
    h0_indices: torch.Tensor,  # [B] int32, -1 to skip
    softmax_scale: float,
    T: int,
    k_buf: torch.Tensor | None = None,  # [pool_size, T, H, K] bf16, WRITTEN
    v_buf: torch.Tensor | None = None,  # [pool_size, T, HV, V] bf16, WRITTEN
) -> None:
    """
    Host entry point for the warp-shuffle KVBuffer verify kernel (Eq. 7, closed form).

    Fills ``out`` with the outputs of all T draft steps computed from the pooled
    initial state ``s``; ``s`` itself is only read. If both ``k_buf`` and
    ``v_buf`` are given, k and v are also copied into those pool-indexed buffers
    so the caller may release the original k/v tensors once this returns.

    Batch entries whose ``h0_indices[b]`` is negative are skipped: ``out[b]`` is
    LEFT UNCHANGED, so callers must pre-initialize ``out`` if those slots are
    read downstream.

    The kernel is compiled on first use for a given (shape, scale, tiling,
    packed-FMA, write_kv) combination and reused afterwards; every launch goes
    to the current torch CUDA stream.

    Raises:
        AssertionError: if ``q.shape[1] != T``, ``K != 128``, or V is not a
            multiple of the selected ``ilp_rows``.
        ValueError: if exactly one of ``k_buf`` / ``v_buf`` is provided.
    """
    B, T_q, H, K = q.shape
    assert T_q == T, f"q.shape[1]={T_q} doesn't match T={T}"
    assert K == 128, f"K={K} != 128: kernel hardcodes K=128 (threads_per_group, lane K-coverage)"
    _, _, HV, V = v.shape
    pool_size = s.shape[0]

    # The two KV buffers are all-or-nothing.
    has_k_buf = k_buf is not None
    has_v_buf = v_buf is not None
    if has_k_buf != has_v_buf:
        raise ValueError("k_buf and v_buf must both be None or both be provided")
    write_kv = has_k_buf and has_v_buf

    # Tiling comes from the shared MTP heuristic; its fourth value is unused here.
    tile_v, vec_size, ilp_rows, _ = get_mtp_config(B, T, HV, V, True)
    assert V % ilp_rows == 0, f"V={V} % ilp_rows={ilp_rows} ≠ 0: partial row-blocks would be silently skipped"
    sm_major, _ = get_device_sm_version(q.device)
    use_packed_fma = sm_major >= 10

    # One mutable dict per distinct configuration; it holds the compiled handle.
    kernel_cache = _get_compiled_verify_kvbuffer_kernel_shuffle(
        B,
        T,
        H,
        HV,
        K,
        V,
        pool_size,
        softmax_scale,
        tile_v,
        vec_size,
        ilp_rows,
        use_packed_fma,
        write_kv,
    )

    # Collapse (pool slot, v-head) into one leading axis, as the kernel indexes it.
    state_flat = s.view(pool_size * HV, V, K)

    if write_kv:
        k_buf_t, v_buf_t = k_buf, v_buf
    else:
        # Dummy tensors when write_kv=False (never accessed by kernel)
        k_buf_t = torch.empty(1, 1, 1, 1, device=q.device, dtype=torch.bfloat16)
        v_buf_t = torch.empty(1, 1, 1, 1, device=q.device, dtype=torch.bfloat16)

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    launch = kernel_cache.get("compiled")
    if launch is None:
        launch = cute.compile(
            run_la_verify_kvbuffer_shuffle_kernel,
            from_dlpack(state_flat, assumed_align=16),
            from_dlpack(decay_scales, assumed_align=16),
            from_dlpack(q, assumed_align=16),
            from_dlpack(k, assumed_align=16),
            from_dlpack(v, assumed_align=16),
            from_dlpack(out, assumed_align=16),
            from_dlpack(h0_indices, assumed_align=16),
            from_dlpack(k_buf_t, assumed_align=16),
            from_dlpack(v_buf_t, assumed_align=16),
            scale=softmax_scale,
            B=B,
            T=T,
            H=H,
            HV=HV,
            K=K,
            V=V,
            tile_v=tile_v,
            vec_size=vec_size,
            ilp_rows=ilp_rows,
            use_packed_fma=use_packed_fma,
            write_kv=write_kv,
            stream=stream,
            options="--enable-tvm-ffi",
        )
        kernel_cache["compiled"] = launch

    launch(
        state_flat,
        decay_scales,
        q,
        k,
        v,
        out,
        h0_indices,
        k_buf_t,
        v_buf_t,
        stream,
    )
def torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T, cache_intermediate_states=False, disable_state_update=False):
    """Pure PyTorch reference for multi-token Lightning Attention decode.

    Runs the recurrence sequentially over the T tokens, per value head:

        S_t = exp(-decay) * S_{t-1} + k_t (outer) v_t
        o_t = (q_t * scale) @ S_t

    Each of the H q/k heads is shared by ``HV // H`` consecutive value heads.

    Args:
        q, k: [B, T, H, D] bf16
        v: [B, T, HV, D] bf16
        state: [B, HV, D, D] fp32 (K-major, V-minor); not modified
        decay_scales: [H] fp32 (positive; the decay factor is exp(-x))
        scale: float multiplier applied to q
        T: number of tokens to process (steps 0..T-1 of q/k/v are read)
        cache_intermediate_states: also return the state after every step
        disable_state_update: return a copy of the input state instead of
            the state after the last step

    Returns:
        out: [B, T, HV, D] bf16
        state_new: [B, HV, D, D] fp32
        inter: [B*T*HV, D, D] fp32, row index ``b*T*HV + t*HV + hv``; None when
            ``cache_intermediate_states`` is False
    """
    B, _, H, D = q.shape
    HV = v.shape[2]
    heads_per_group = HV // H  # value heads sharing one q/k head (loop-invariant)

    q_f = q.float() * scale
    k_f, v_f = k.float(), v.float()
    decay_per_q_head = torch.exp(-decay_scales)
    decay_per_hv = decay_per_q_head.repeat_interleave(heads_per_group).view(1, HV, 1, 1)

    state_running = state.clone()
    out = torch.zeros(B, T, HV, D, dtype=torch.bfloat16, device=q.device)

    inter = None
    inter_steps = None
    if cache_intermediate_states:
        inter = torch.zeros(B * T * HV, D, D, dtype=torch.float32, device=q.device)
        # [B, T, HV, D, D] view over the same storage: step t of every batch is
        # then filled with one strided write instead of a per-batch Python loop.
        inter_steps = inter.view(B, T, HV, D, D)

    for t in range(T):
        # Broadcast q/k from the H q/k heads to the HV value heads.
        q_hv = q_f[:, t].repeat_interleave(heads_per_group, dim=1)
        k_hv = k_f[:, t].repeat_interleave(heads_per_group, dim=1)
        v_t = v_f[:, t]
        # Decay the previous state, then add the rank-1 update k_t v_t^T.
        state_running = state_running * decay_per_hv + k_hv.unsqueeze(-1) * v_t.unsqueeze(-2)
        out[:, t] = torch.einsum("bhk,bhkv->bhv", q_hv, state_running).bfloat16()
        if inter_steps is not None:
            inter_steps[:, t] = state_running

    state_final = state.clone() if disable_state_update else state_running
    return out, state_final, inter
+ """ + B, HV, K, V = state_4d.shape + H = q.shape[2] + assert HV % H == 0, "HV must be a multiple of H" + + # pretranspose: [B, HV, V, K] + s_cute = state_4d.permute(0, 1, 3, 2).contiguous().clone() + out = torch.zeros(B, T, HV, V, device=q.device, dtype=torch.bfloat16) + s_offsets = torch.arange(B, device=q.device, dtype=torch.int32) + + if cache_intermediate_states: + inter = torch.zeros(B * T * HV, V, K, device=q.device, dtype=torch.float32) + else: + inter = torch.empty(1, 1, 1, device=q.device, dtype=torch.float32) # dummy + + cu_seqlens = torch.empty(1, device=q.device, dtype=torch.int32) # dummy when is_varlen=False + + linear_attention_decode_mtp( + q, + k, + v, + s_cute, + inter, + out, + decay_scales=decay_scales, + s_offsets=s_offsets, + cu_seqlens=cu_seqlens, + softmax_scale=scale, + T=T, + cache_intermediate_states=cache_intermediate_states, + disable_state_update=disable_state_update, + is_varlen=False, + ) + + # convert state back: [B, HV, V, K] -> [B, HV, K, V] + state_out = s_cute.permute(0, 1, 3, 2).contiguous() + + if cache_intermediate_states: + # inter (kernel): [B*T*HV, V, K] -> ref layout [B*T*HV, K, V] + inter_out = inter.permute(0, 2, 1).contiguous() + else: + inter_out = None + + return out, state_out, inter_out + + +# --------------------------------------------------------------------------- +# Tests vs PyTorch reference +# --------------------------------------------------------------------------- +# Each (B, T) below targets a distinct heuristic config (with H=HV=64): +# B=1, T=4: work_units=64 → tile_v=8, ilp=2, smem_v=False +# B=2, T=2: work_units=128 → tile_v=16, ilp=4, smem_v=False +# B=2, T=4: work_units=128 → tile_v=16, ilp=4, smem_v=False +# B=8, T=4: work_units=512 → tile_v=32, ilp=4, smem_v=False +# B=32, T=2: work_units=2048 → tile_v=64, ilp=8, smem_v=False (state_update ON) +# B=32, T=4: work_units=2048 → tile_v=64, ilp=4, smem_v=True +@pytest.mark.parametrize( + "B,T,expected_config", + [ + (1, 4, "tile_v=8_ilp=2"), + (2, 
2, "tile_v=16_ilp=4"), + (2, 4, "tile_v=16_ilp=4"), + (8, 4, "tile_v=32_ilp=4"), + (32, 2, "tile_v=64_ilp=8"), + (32, 4, "tile_v=64_ilp=4_smem_v"), + ], +) +def test_output_vs_torch_ref(B, T, expected_config): + _skip_if_no_sm90_or_later() + H, HV, D = 64, 64, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + + q, k, v, state = make_inputs(B, T, H, HV, D) + o_ref, state_ref, _ = torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T) + o_cute, state_cute, _ = run_la_mtp(q, k, v, state, decay_scales, scale, T) + + # Output check + rmse = torch.sqrt(torch.mean((o_cute.float() - o_ref.float()) ** 2)).item() + max_ref = torch.abs(o_ref.float()).max().item() + rel = rmse / (max_ref + 1e-8) + assert rel < 0.01, f"B={B} T={T} [{expected_config}]: output rel RMSE {rel:.6f} too large" + + # State check + state_rmse = torch.sqrt(torch.mean((state_cute - state_ref) ** 2)).item() + state_max = torch.abs(state_ref).max().item() + state_rel = state_rmse / (state_max + 1e-8) + assert state_rel < 0.001, f"B={B} T={T} [{expected_config}]: state rel RMSE {state_rel:.6f} too large" + + +@pytest.mark.parametrize("H,HV", [(16, 16), (8, 32), (16, 64)]) # MHA + GQA +def test_different_heads(H, HV): + """GQA support: HV is multiple of H; q/k indexed by i_h = i_hv // (HV//H).""" + _skip_if_no_sm90_or_later() + B, T, D = 4, 4, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + + q, k, v, state = make_inputs(B, T, H, HV, D) + o_ref, state_ref, _ = torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T) + o_cute, state_cute, _ = run_la_mtp(q, k, v, state, decay_scales, scale, T) + + rmse = torch.sqrt(torch.mean((o_cute.float() - o_ref.float()) ** 2)).item() + max_ref = torch.abs(o_ref.float()).max().item() + assert rmse / (max_ref + 1e-8) < 0.01, f"H={H} HV={HV}: output mismatch" + + state_rmse = torch.sqrt(torch.mean((state_cute - state_ref) ** 2)).item() + state_max = 
torch.abs(state_ref).max().item() + assert state_rmse / (state_max + 1e-8) < 0.001, f"H={H} HV={HV}: state mismatch" + + +def test_disable_state_update(): + """h0_source remains bitwise-equal to the input snapshot.""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + + q, k, v, state = make_inputs(B, T, H, HV, D) + state_snapshot = state.clone() + + _, state_out, _ = run_la_mtp( + q, + k, + v, + state, + decay_scales, + scale, + T, + disable_state_update=True, + ) + assert torch.equal(state_out, state_snapshot), "state was mutated despite disable_state_update=True" + + +def test_cache_intermediate_states(): + """Each per-t slice of inter matches the reference state_running at that step.""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + + q, k, v, state = make_inputs(B, T, H, HV, D) + _, _, inter_ref = torch_la_mtp_ref( + q, + k, + v, + state, + decay_scales, + scale, + T, + cache_intermediate_states=True, + ) + _, _, inter_cute = run_la_mtp( + q, + k, + v, + state, + decay_scales, + scale, + T, + cache_intermediate_states=True, + ) + + rmse = torch.sqrt(torch.mean((inter_cute - inter_ref) ** 2)).item() + max_ref = torch.abs(inter_ref).max().item() + assert rmse / (max_ref + 1e-8) < 0.001, f"intermediate states mismatch, rel_rmse={rmse / (max_ref + 1e-8):.6f}" + + inter_cute_v = inter_cute.view(B, T, HV, D, D) + inter_ref_v = inter_ref.view(B, T, HV, D, D) + for b in range(B): + for t in range(T): + slot_c = inter_cute_v[b, t] + slot_r = inter_ref_v[b, t] + slot_rmse = torch.sqrt(torch.mean((slot_c - slot_r) ** 2)).item() + slot_max = torch.abs(slot_r).max().item() + assert slot_rmse / (slot_max + 1e-8) < 0.001, ( + f"(b={b}, t={t}) intermediate mismatch, rel_rmse={slot_rmse / (slot_max + 1e-8):.6f}" + ) + + assert not 
torch.allclose(inter_cute_v[0, 0], inter_cute_v[0, 1]) + + +def test_skip_with_negative_offset(): + """s_offsets[i]=-1: that batch's `out` slot stays at initial value.""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + + q, k, v, state = make_inputs(B, T, H, HV, D) + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + sentinel = 123.0 + out = torch.full((B, T, HV, D), sentinel, device=q.device, dtype=torch.bfloat16) + s_offsets = torch.arange(B, device=q.device, dtype=torch.int32) + s_offsets[2] = -1 # skip batch index 2 + + inter = torch.empty(1, 1, 1, device=q.device, dtype=torch.float32) + cu_seqlens = torch.empty(1, device=q.device, dtype=torch.int32) + linear_attention_decode_mtp( + q, + k, + v, + s_cute, + inter, + out, + decay_scales=decay_scales, + s_offsets=s_offsets, + cu_seqlens=cu_seqlens, + softmax_scale=scale, + T=T, + cache_intermediate_states=False, + disable_state_update=False, + is_varlen=False, + ) + # batch 2 should be untouched (sentinel value) + assert torch.all(out[2] == torch.full_like(out[2], sentinel)), "skipped batch was modified" + # other batches should differ from sentinel + assert not torch.all(out[0] == torch.full_like(out[0], sentinel)), "non-skipped batch unchanged" + + +def test_skip_with_negative_offset_cache_intermediate(): + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + + q, k, v, state = make_inputs(B, T, H, HV, D) + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + out = torch.zeros(B, T, HV, D, device=q.device, dtype=torch.bfloat16) + s_offsets = torch.arange(B, device=q.device, dtype=torch.int32) + s_offsets[2] = -1 + + inter_sentinel = 7.5 + inter = torch.full((B * T * HV, D, D), inter_sentinel, device=q.device, dtype=torch.float32) + cu_seqlens = torch.empty(1, 
device=q.device, dtype=torch.int32) + + linear_attention_decode_mtp( + q, + k, + v, + s_cute, + inter, + out, + decay_scales=decay_scales, + s_offsets=s_offsets, + cu_seqlens=cu_seqlens, + softmax_scale=scale, + T=T, + cache_intermediate_states=True, + disable_state_update=False, + is_varlen=False, + ) + + skipped = inter[2 * T * HV : 3 * T * HV] + assert torch.all(skipped == inter_sentinel), ( + f"intermediate_states for skipped batch was written (min={skipped.min().item()}, max={skipped.max().item()})" + ) + + others = torch.cat([inter[: 2 * T * HV], inter[3 * T * HV :]], dim=0) + assert not torch.all(others == inter_sentinel), "non-skipped intermediate slots were not written" + + +def test_zero_decay(): + """With decay=0: state_new = state_old + k⊗v (no decay applied).""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + scale = D**-0.5 + decay_scales = torch.zeros(H, device="cuda", dtype=torch.float32) + + q, k, v, state = make_inputs(B, T, H, HV, D) + o_ref, _, _ = torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T) + o_cute, _, _ = run_la_mtp(q, k, v, state, decay_scales, scale, T) + + rmse = torch.sqrt(torch.mean((o_cute.float() - o_ref.float()) ** 2)).item() + max_ref = torch.abs(o_ref.float()).max().item() + assert rmse / (max_ref + 1e-8) < 0.01, "zero decay: output mismatch" + + +def test_zero_state(): + """With zero initial state.""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.ones(H, device="cuda", dtype=torch.float32) + + q, k, v, _ = make_inputs(B, T, H, HV, D) + state = torch.zeros(B, HV, D, D, device="cuda", dtype=torch.float32) + o_ref, _, _ = torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T) + o_cute, _, _ = run_la_mtp(q, k, v, state, decay_scales, scale, T) + + rmse = torch.sqrt(torch.mean((o_cute.float() - o_ref.float()) ** 2)).item() + max_ref = torch.abs(o_ref.float()).max().item() + assert rmse / (max_ref + 1e-8) < 0.01, "zero state: output 
mismatch" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/test_la_kvbuffer.py b/tests/test_la_kvbuffer.py new file mode 100644 index 00000000..dc0be2a6 --- /dev/null +++ b/tests/test_la_kvbuffer.py @@ -0,0 +1,669 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the KVBuffer verify + state-update kernels.""" + +import pathlib +import sys + +import pytest +import torch + +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent)) + +from cula.lightning.la_decode_mtp import linear_attention_decode_mtp +from cula.lightning.la_state_update_kvbuffer import linear_attention_state_update_kvbuffer +from cula.lightning.la_verify_kvbuffer import linear_attention_verify_kvbuffer + + +# --------------------------------------------------------------------------- +# Pure PyTorch reference for multi-token Lightning Attention decode +# --------------------------------------------------------------------------- +def torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T, cache_intermediate_states=False, disable_state_update=False): + """Pure PyTorch reference. 
+ + Args: + q, k: [B, T, H, D] bf16 + v: [B, T, HV, D] bf16 + state: [B, HV, D, D] fp32 (K-major, V-minor) + decay_scales: [H] fp32 (positive; kernel does exp(-x)) + scale: float + T: int + cache_intermediate_states: cache per-step state to inter + disable_state_update: do not update state_new at end + + Returns: + out: [B, T, HV, D] bf16 + state_new: [B, HV, D, D] fp32 + inter: [B*T*HV, D, D] fp32 or None + """ + B, _, H, D = q.shape + HV = v.shape[2] + q_f = q.float() * scale + k_f, v_f = k.float(), v.float() + decay_per_q_head = torch.exp(-decay_scales) + decay_per_hv = decay_per_q_head.repeat_interleave(HV // H).view(1, HV, 1, 1) + + state_running = state.clone() + out = torch.zeros(B, T, HV, D, dtype=torch.bfloat16, device=q.device) + inter = torch.zeros(B * T * HV, D, D, dtype=torch.float32, device=q.device) if cache_intermediate_states else None + + for t in range(T): + q_hv = q_f[:, t].repeat_interleave(HV // H, dim=1) + k_hv = k_f[:, t].repeat_interleave(HV // H, dim=1) + v_t = v_f[:, t] + state_running = state_running * decay_per_hv + k_hv.unsqueeze(-1) * v_t.unsqueeze(-2) + out[:, t] = torch.einsum("bhk,bhkv->bhv", q_hv, state_running).bfloat16() + if cache_intermediate_states: + for b in range(B): + inter[b * T * HV + t * HV : b * T * HV + (t + 1) * HV] = state_running[b] + + state_final = state.clone() if disable_state_update else state_running + return out, state_final, inter + + +def _skip_if_no_sm90_or_later(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + cc = torch.cuda.get_device_capability("cuda") + if cc[0] < 9: + pytest.skip(f"requires SM90+, got SM{cc[0]}{cc[1]}") + + +def _make_inputs(B, T, H, HV, D, device="cuda", seed=42): + torch.manual_seed(seed) + q = torch.randn(B, T, H, D, device=device, dtype=torch.bfloat16) + k = torch.randn(B, T, H, D, device=device, dtype=torch.bfloat16) + v = torch.randn(B, T, HV, D, device=device, dtype=torch.bfloat16) + state = torch.randn(B, HV, D, D, device=device, dtype=torch.float32) 
* 0.01 + return q, k, v, state + + +def test_state_update_L0_no_op(): + """accepted_len=0 everywhere: s must be byte-for-byte unchanged.""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + _, k, v, state = _make_inputs(B, T, H, HV, D) + + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() # [B, HV, V, K] + s_snapshot = s_cute.clone() + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + accepted_len = torch.zeros(B, device="cuda", dtype=torch.int32) + + linear_attention_state_update_kvbuffer( + k, + v, + s_cute, + decay_scales, + h0_indices, + accepted_len, + T, + ) + assert torch.equal(s_cute, s_snapshot), "L=0 must leave state unchanged" + + +def _ref_state_after_L(state, k, v, decay_scales, L_per_batch, T): + """state[B,HV,K,V] fp32; returns the per-batch state after L recurrent steps.""" + B, HV, K, V = state.shape + H = k.shape[2] + k_f, v_f = k.float(), v.float() + decay_per_q_head = torch.exp(-decay_scales) + decay_per_hv = decay_per_q_head.repeat_interleave(HV // H).view(HV, 1, 1) + out = state.clone() + for b in range(B): + L = int(L_per_batch[b].item()) + running = state[b].clone() + for i in range(L): + k_hv = k_f[b, i].repeat_interleave(HV // H, dim=0) # [HV, K] + v_i = v_f[b, i] # [HV, V] + running = running * decay_per_hv + k_hv.unsqueeze(-1) * v_i.unsqueeze(-2) + out[b] = running + return out + + +@pytest.mark.parametrize( + "B,T,H,HV,D", + [(4, 4, 16, 16, 128), (8, 4, 64, 64, 128), (4, 3, 16, 16, 128), (8, 7, 64, 64, 128)], +) +def test_state_update_full_accept(B, T, H, HV, D): + """accepted_len=T everywhere: bit-exact vs baseline recurrence reference.""" + _skip_if_no_sm90_or_later() + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + _, k, v, state = _make_inputs(B, T, H, HV, D) + + L_per_batch = torch.full((B,), T, device="cuda", dtype=torch.int32) + ref = _ref_state_after_L(state, k, v, decay_scales, 
L_per_batch, T) # [B,HV,K,V] + + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() # [B,HV,V,K] + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + linear_attention_state_update_kvbuffer( + k, + v, + s_cute, + decay_scales, + h0_indices, + L_per_batch, + T, + ) + got = s_cute.permute(0, 1, 3, 2).contiguous() # back to [B,HV,K,V] + rmse = torch.sqrt(torch.mean((got - ref) ** 2)).item() + rel = rmse / (torch.abs(ref).max().item() + 1e-8) + assert rel < 1e-3, f"full-accept state rel RMSE {rel:.6f} too large" + + +@pytest.mark.parametrize("L", [0, 1, 3]) +def test_state_update_partial(L): + """Uniform accepted_len=L across all batches.""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + _, k, v, state = _make_inputs(B, T, H, HV, D) + + L_per_batch = torch.full((B,), L, device="cuda", dtype=torch.int32) + ref = _ref_state_after_L(state, k, v, decay_scales, L_per_batch, T) + + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + linear_attention_state_update_kvbuffer( + k, + v, + s_cute, + decay_scales, + h0_indices, + L_per_batch, + T, + ) + got = s_cute.permute(0, 1, 3, 2).contiguous() + rel = torch.sqrt(torch.mean((got - ref) ** 2)).item() / (torch.abs(ref).max().item() + 1e-8) + assert rel < 1e-3, f"L={L} state rel RMSE {rel:.6f}" + + +def test_state_update_per_batch_L(): + """accepted_len varies per batch: [0, 1, T-1, T].""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + _, k, v, state = _make_inputs(B, T, H, HV, D) + + L_per_batch = torch.tensor([0, 1, T - 1, T], device="cuda", dtype=torch.int32) + ref = _ref_state_after_L(state, k, v, decay_scales, L_per_batch, T) + + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + h0_indices = torch.arange(B, device="cuda", 
dtype=torch.int32) + linear_attention_state_update_kvbuffer( + k, + v, + s_cute, + decay_scales, + h0_indices, + L_per_batch, + T, + ) + got = s_cute.permute(0, 1, 3, 2).contiguous() + for b in range(B): + rel = torch.sqrt(torch.mean((got[b] - ref[b]) ** 2)).item() / (torch.abs(ref[b]).max().item() + 1e-8) + assert rel < 1e-3, f"batch {b} (L={int(L_per_batch[b])}) rel RMSE {rel:.6f}" + + +def test_state_update_skip_negative_h0_indices(): + """h0_indices[b]=-1: that pool slot is untouched even with accepted_len>0.""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + _, k, v, state = _make_inputs(B, T, H, HV, D) + + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + snapshot_b2 = s_cute[2].clone() + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + h0_indices[2] = -1 + L_per_batch = torch.full((B,), T, device="cuda", dtype=torch.int32) + + linear_attention_state_update_kvbuffer( + k, + v, + s_cute, + decay_scales, + h0_indices, + L_per_batch, + T, + ) + assert torch.equal(s_cute[2], snapshot_b2), "skipped batch slot was modified" + + +def test_verify_skip_negative_h0_indices(): + """h0_indices[b]=-1: out[b] stays at its sentinel value.""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + q, k, v, state = _make_inputs(B, T, H, HV, D) + + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + sentinel = 123.0 + out = torch.full((B, T, HV, D), sentinel, device="cuda", dtype=torch.bfloat16) + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + h0_indices[2] = -1 + + linear_attention_verify_kvbuffer( + q, + k, + v, + s_cute, + out, + decay_scales, + h0_indices, + scale, + T, + ) + assert torch.all(out[2] == sentinel), "skipped batch out slot was modified" + + +@pytest.mark.parametrize( + "B,T", + [(1, 4), (2, 2), (2, 4), (8, 
4), (32, 2), (32, 4), (2, 1), (2, 3), (8, 5), (8, 7)], +) +def test_verify_outputs_match_ref(B, T): + """Verify kernel o matches torch_la_mtp_ref across the baseline configs.""" + _skip_if_no_sm90_or_later() + H, HV, D = 64, 64, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + q, k, v, state = _make_inputs(B, T, H, HV, D) + + o_ref, _, _ = torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T) + + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + out = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16) + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + linear_attention_verify_kvbuffer( + q, + k, + v, + s_cute, + out, + decay_scales, + h0_indices, + scale, + T, + ) + rel = torch.sqrt(torch.mean((out.float() - o_ref.float()) ** 2)).item() / (torch.abs(o_ref.float()).max().item() + 1e-8) + assert rel < 1e-2, f"B={B} T={T}: verify output rel RMSE {rel:.6f} too large" + + +@pytest.mark.parametrize("H,HV", [(16, 16), (8, 32), (16, 64)]) +def test_verify_different_heads(H, HV): + _skip_if_no_sm90_or_later() + B, T, D = 4, 4, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + q, k, v, state = _make_inputs(B, T, H, HV, D) + o_ref, _, _ = torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T) + + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + out = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16) + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + linear_attention_verify_kvbuffer( + q, + k, + v, + s_cute, + out, + decay_scales, + h0_indices, + scale, + T, + ) + rel = torch.sqrt(torch.mean((out.float() - o_ref.float()) ** 2)).item() / (torch.abs(o_ref.float()).max().item() + 1e-8) + assert rel < 1e-2, f"H={H} HV={HV}: verify output mismatch {rel:.6f}" + + +def test_verify_zero_decay(): + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + scale = D**-0.5 + decay_scales = 
torch.zeros(H, device="cuda", dtype=torch.float32) + q, k, v, state = _make_inputs(B, T, H, HV, D) + o_ref, _, _ = torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T) + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + out = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16) + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + linear_attention_verify_kvbuffer(q, k, v, s_cute, out, decay_scales, h0_indices, scale, T) + rel = torch.sqrt(torch.mean((out.float() - o_ref.float()) ** 2)).item() / (torch.abs(o_ref.float()).max().item() + 1e-8) + assert rel < 1e-2, f"zero decay: {rel:.6f}" + + +def test_verify_zero_state(): + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 4, 4, 16, 16, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.ones(H, device="cuda", dtype=torch.float32) + q, k, v, _ = _make_inputs(B, T, H, HV, D) + state = torch.zeros(B, HV, D, D, device="cuda", dtype=torch.float32) + o_ref, _, _ = torch_la_mtp_ref(q, k, v, state, decay_scales, scale, T) + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + out = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16) + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + linear_attention_verify_kvbuffer(q, k, v, s_cute, out, decay_scales, h0_indices, scale, T) + rel = torch.sqrt(torch.mean((out.float() - o_ref.float()) ** 2)).item() / (torch.abs(o_ref.float()).max().item() + 1e-8) + assert rel < 1e-2, f"zero state: {rel:.6f}" + + +def test_end_to_end_equivalence_with_baseline(): + """KVBuffer (verify + state_update L=T) == baseline (cache_inter=T, disable=T).""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 8, 4, 64, 64, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + q, k, v, state = _make_inputs(B, T, H, HV, D) + + # ---- Baseline: capture out + all intermediate states ---- + s_base = state.permute(0, 1, 3, 2).contiguous().clone() # [B,HV,V,K] + out_base = torch.zeros(B, T, HV, D, 
device="cuda", dtype=torch.bfloat16) + s_offsets = torch.arange(B, device="cuda", dtype=torch.int32) + inter = torch.zeros(B * T * HV, D, D, device="cuda", dtype=torch.float32) # [.,V,K] + cu_seqlens = torch.empty(1, device="cuda", dtype=torch.int32) + linear_attention_decode_mtp( + q, + k, + v, + s_base, + inter, + out_base, + decay_scales=decay_scales, + s_offsets=s_offsets, + cu_seqlens=cu_seqlens, + softmax_scale=scale, + T=T, + cache_intermediate_states=True, + disable_state_update=True, + is_varlen=False, + ) + + # ---- KVBuffer: verify writes out; state-update (L=T) writes state ---- + s_kv = state.permute(0, 1, 3, 2).contiguous().clone() # [B,HV,V,K] + out_kv = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16) + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + linear_attention_verify_kvbuffer( + q, + k, + v, + s_kv, + out_kv, + decay_scales, + h0_indices, + scale, + T, + ) + accepted_len = torch.full((B,), T, device="cuda", dtype=torch.int32) + linear_attention_state_update_kvbuffer( + k, + v, + s_kv, + decay_scales, + h0_indices, + accepted_len, + T, + ) + + # (a) outputs match + rel_o = torch.sqrt(torch.mean((out_kv.float() - out_base.float()) ** 2)).item() / ( + torch.abs(out_base.float()).max().item() + 1e-8 + ) + assert rel_o < 1e-2, f"output mismatch vs baseline: {rel_o:.6f}" + + # (b) updated state == baseline's last intermediate slice [B,HV,V,K] + inter_v = inter.view(B, T, HV, D, D) # [B,T,HV,V,K] + last_state = inter_v[:, T - 1] # [B,HV,V,K] + rel_s = torch.sqrt(torch.mean((s_kv - last_state) ** 2)).item() / (torch.abs(last_state).max().item() + 1e-8) + assert rel_s < 1e-3, f"state mismatch vs baseline last intermediate: {rel_s:.6f}" + + +@pytest.mark.parametrize("B,T", [(4, 4), (8, 2), (32, 4)]) +def test_verify_writes_kv_buffer(B, T): + """Verify kernel with k_buf/v_buf writes correct copies of k and v.""" + _skip_if_no_sm90_or_later() + H, HV, D = 64, 64, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, 
device="cuda", dtype=torch.float32) / H + q, k, v, state = _make_inputs(B, T, H, HV, D) + + pool_size = B + s_cute = state.permute(0, 1, 3, 2).contiguous().clone() + out = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16) + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + k_buf = torch.zeros(pool_size, T, H, D, device="cuda", dtype=torch.bfloat16) + v_buf = torch.zeros(pool_size, T, HV, D, device="cuda", dtype=torch.bfloat16) + + linear_attention_verify_kvbuffer( + q, + k, + v, + s_cute, + out, + decay_scales, + h0_indices, + scale, + T, + k_buf=k_buf, + v_buf=v_buf, + ) + + for b in range(B): + pool_idx = h0_indices[b].item() + assert torch.equal(k_buf[pool_idx], k[b]), f"k_buf mismatch at batch {b}" + assert torch.equal(v_buf[pool_idx], v[b]), f"v_buf mismatch at batch {b}" + + +def test_verify_output_unchanged_with_kv_write(): + """Output o is identical whether k_buf/v_buf are provided or not.""" + _skip_if_no_sm90_or_later() + B, T, H, HV, D = 8, 4, 64, 64, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H + q, k, v, state = _make_inputs(B, T, H, HV, D) + + pool_size = B + s1 = state.permute(0, 1, 3, 2).contiguous().clone() + s2 = s1.clone() + out_no_buf = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16) + out_with_buf = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16) + h0_indices = torch.arange(B, device="cuda", dtype=torch.int32) + + linear_attention_verify_kvbuffer( + q, + k, + v, + s1, + out_no_buf, + decay_scales, + h0_indices, + scale, + T, + ) + + k_buf = torch.zeros(pool_size, T, H, D, device="cuda", dtype=torch.bfloat16) + v_buf = torch.zeros(pool_size, T, HV, D, device="cuda", dtype=torch.bfloat16) + linear_attention_verify_kvbuffer( + q, + k, + v, + s2, + out_with_buf, + decay_scales, + h0_indices, + scale, + T, + k_buf=k_buf, + v_buf=v_buf, + ) + + assert torch.equal(out_no_buf, out_with_buf), "kv write should not affect output" + + 
@pytest.mark.parametrize("B,T,H,HV,D", [(4, 4, 16, 16, 128), (8, 4, 64, 64, 128)])
def test_state_update_from_buffer(B, T, H, HV, D):
    """State update fed via k_buf/v_buf must equal state update fed via raw k, v.

    Runs ``linear_attention_state_update_kvbuffer`` twice from the same initial
    state with ``accepted_len == T`` for every batch: once without buffers and
    once with ``k_buf``/``v_buf`` pre-filled with the same k, v rows at the pool
    slots named by ``h0_indices``. The two resulting states must be bitwise equal.
    """
    _skip_if_no_sm90_or_later()
    decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H
    _, k, v, state = _make_inputs(B, T, H, HV, D)

    pool_size = B
    h0_indices = torch.arange(B, device="cuda", dtype=torch.int32)
    L_per_batch = torch.full((B,), T, device="cuda", dtype=torch.int32)

    # Path A: read from raw k, v.
    # The state is handed to the kernel with its last two axes swapped.
    s_raw = state.permute(0, 1, 3, 2).contiguous().clone()
    linear_attention_state_update_kvbuffer(
        k,
        v,
        s_raw,
        decay_scales,
        h0_indices,
        L_per_batch,
        T,
    )

    # Path B: read from buffer (fill buffer with same k, v).
    k_buf = torch.zeros(pool_size, T, H, D, device="cuda", dtype=torch.bfloat16)
    v_buf = torch.zeros(pool_size, T, HV, D, device="cuda", dtype=torch.bfloat16)
    # One scatter per buffer instead of a Python loop calling .item() (which
    # forces a device->host sync on every iteration). Indices are unique here.
    pool_slots = h0_indices.long()
    k_buf[pool_slots] = k
    v_buf[pool_slots] = v

    # NOTE(review): the raw k/v passed below are identical to the buffer
    # contents, so this equality alone cannot tell which source the kernel
    # actually read. Consider perturbing the raw tensors if the API permits.
    s_buf = state.permute(0, 1, 3, 2).contiguous().clone()
    linear_attention_state_update_kvbuffer(
        k,
        v,
        s_buf,
        decay_scales,
        h0_indices,
        L_per_batch,
        T,
        k_buf=k_buf,
        v_buf=v_buf,
    )

    assert torch.equal(s_raw, s_buf), "buffer-read state must match raw-read state"


def test_verify_skip_negative_indices_no_buffer_write():
    """h0_indices[b]=-1: that batch's k_buf/v_buf slot is untouched, others are written.

    Buffers start filled with a sentinel. After the verify call, the slot of the
    skipped batch must still hold the sentinel, and every non-skipped batch must
    have its k/v rows copied into its own pool slot. Without the second check the
    test would also pass for a kernel that never writes the buffers at all.
    """
    _skip_if_no_sm90_or_later()
    B, T, H, HV, D = 4, 4, 16, 16, 128
    scale = D**-0.5
    decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H
    q, k, v, state = _make_inputs(B, T, H, HV, D)

    pool_size = B
    sentinel = 42.0
    k_buf = torch.full((pool_size, T, H, D), sentinel, device="cuda", dtype=torch.bfloat16)
    v_buf = torch.full((pool_size, T, HV, D), sentinel, device="cuda", dtype=torch.bfloat16)
    k_buf_snap = k_buf.clone()
    v_buf_snap = v_buf.clone()

    s_cute = state.permute(0, 1, 3, 2).contiguous().clone()
    out = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16)
    h0_indices = torch.arange(B, device="cuda", dtype=torch.int32)
    skipped = 2
    h0_indices[skipped] = -1

    linear_attention_verify_kvbuffer(
        q,
        k,
        v,
        s_cute,
        out,
        decay_scales,
        h0_indices,
        scale,
        T,
        k_buf=k_buf,
        v_buf=v_buf,
    )

    assert torch.equal(k_buf[skipped], k_buf_snap[skipped]), "skipped batch k_buf slot was modified"
    assert torch.equal(v_buf[skipped], v_buf_snap[skipped]), "skipped batch v_buf slot was modified"

    # For non-skipped batches h0_indices[b] == b, so slot b must now hold k[b]/v[b].
    for b in range(B):
        if b == skipped:
            continue
        assert torch.equal(k_buf[b], k[b]), f"k_buf not written for non-skipped batch {b}"
        assert torch.equal(v_buf[b], v[b]), f"v_buf not written for non-skipped batch {b}"


def test_end_to_end_with_buffer():
    """Full pipeline: verify(+kv write) → state_update(from buffer) matches baseline.

    The reference path runs verify followed by state update without any k/v
    buffers; the buffer path runs the same two kernels with ``k_buf``/``v_buf``
    supplied to both. Outputs and final states must be bitwise equal.
    """
    _skip_if_no_sm90_or_later()
    B, T, H, HV, D = 8, 4, 64, 64, 128
    scale = D**-0.5
    decay_scales = 0.3 * torch.arange(H, device="cuda", dtype=torch.float32) / H
    q, k, v, state = _make_inputs(B, T, H, HV, D)

    pool_size = B
    h0_indices = torch.arange(B, device="cuda", dtype=torch.int32)

    # Reference: existing end-to-end (no buffer)
    s_ref = state.permute(0, 1, 3, 2).contiguous().clone()
    out_ref = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16)
    linear_attention_verify_kvbuffer(
        q,
        k,
        v,
        s_ref,
        out_ref,
        decay_scales,
        h0_indices,
        scale,
        T,
    )
    # Accept all T tokens for every batch.
    accepted_len = torch.full((B,), T, device="cuda", dtype=torch.int32)
    linear_attention_state_update_kvbuffer(
        k,
        v,
        s_ref,
        decay_scales,
        h0_indices,
        accepted_len,
        T,
    )

    # Buffer path: verify writes buffer, state_update reads buffer
    s_buf = state.permute(0, 1, 3, 2).contiguous().clone()
    out_buf = torch.zeros(B, T, HV, D, device="cuda", dtype=torch.bfloat16)
    k_buf = torch.zeros(pool_size, T, H, D, device="cuda", dtype=torch.bfloat16)
    v_buf = torch.zeros(pool_size, T, HV, D, device="cuda", dtype=torch.bfloat16)

    linear_attention_verify_kvbuffer(
        q,
        k,
        v,
        s_buf,
        out_buf,
        decay_scales,
        h0_indices,
        scale,
        T,
        k_buf=k_buf,
        v_buf=v_buf,
    )
    # NOTE(review): raw k/v are passed alongside the buffers here as well, so
    # this check presumably cannot distinguish a buffer read from a raw read.
    linear_attention_state_update_kvbuffer(
        k,
        v,
        s_buf,
        decay_scales,
        h0_indices,
        accepted_len,
        T,
        k_buf=k_buf,
        v_buf=v_buf,
    )

    assert torch.equal(out_ref, out_buf), "output mismatch with buffer pipeline"
    assert torch.equal(s_ref, s_buf), "state mismatch with buffer pipeline"