[Refactor] Refactor examples into eager style #1948

LeiWang1999 wants to merge 2 commits into tile-ai:main from
Conversation
… and performance

- Refactored the `matmul` function to accept tensors directly instead of dimensions, enhancing usability.
- Simplified the `flashattn_bwd_preprocess` and `flashattn_bwd_postprocess` functions by updating parameter names and types for better readability.
- Adjusted kernel context initialization and memory allocation for shared tensors in both forward and backward attention functions.
- Enhanced the handling of dynamic shapes and added support for variable-length sequences in attention sink examples.
- Cleaned up example scripts by removing unnecessary comments and improving code structure for better maintainability.
📝 Walkthrough

This PR systematically refactors TileLang JIT kernel examples across 60+ files, converting them from a "factory pattern" where functions return nested

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Actionable comments posted: 17
Note
Due to the large number of review comments, Critical severity comments were prioritized as inline comments.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
examples/flash_attention/example_gqa_fwd_varlen.py (1)
145-150: ⚠️ Potential issue | 🟡 Minor

Remove duplicate random seed setting.

`tilelang.testing.set_random_seed(0)` is called twice (lines 145 and 150). Remove one of these calls.

Proposed fix

```diff
     tilelang.testing.set_random_seed(0)
     if is_causal:
         total_flops *= 0.5
-    tilelang.testing.set_random_seed(0)
     dtype = torch.float16
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention/example_gqa_fwd_varlen.py` around lines 145 - 150, Remove the duplicate call to tilelang.testing.set_random_seed(0) — keep a single invocation (either before or after the is_causal block) to preserve deterministic behavior; update the area around the is_causal check and total_flops adjustment so only one tilelang.testing.set_random_seed(0) call remains (refer to tilelang.testing.set_random_seed and the is_causal / total_flops code block to locate the redundant line).

examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)
33-36: ⚠️ Potential issue | 🟡 Minor

Error message inconsistent with assertion.

The assertion checks for `T.bfloat16`, but the error message states "float16". This could confuse users when debugging dtype mismatches.

📝 Proposed fix

```diff
 assert out_dtype in [
     T.bfloat16,
     T.float32,
-], "Currently only float16 and float32 are supported"
+], "Currently only bfloat16 and float32 are supported"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py` around lines 33 - 36, The assertion message is inconsistent with the checked dtypes: update the error string to match the actual allowed types (T.bfloat16 and T.float32) used in the assert for variable out_dtype; e.g., change the message to "Currently only bfloat16 and float32 are supported" so it references T.bfloat16 and T.float32 consistently with the assert on out_dtype.

examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py (1)
161-174: ⚠️ Potential issue | 🔴 Critical

Undefined variable `D` — use `K` instead.

Line 161 extracts the dimension as `K` from `k.shape`, but lines 171-174 use `D`, which is undefined in this scope. This will cause a `NameError` at runtime.

🐛 Proposed fix

```diff
 def parallel_nsa_fwd(
     ...
 ):
     B, C_SEQ_LEN, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
     batch = len(offsets) - 1
     HQ = q.shape[2]
     G = HQ // H
     BS = block_size
     WS = window_size
     o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
     native_sparse_attention_varlen(
-        q.view(C_SEQ_LEN, HQ, D),
-        k.view(C_SEQ_LEN, H, D),
-        v.view(C_SEQ_LEN, H, D),
+        q.view(C_SEQ_LEN, HQ, K),
+        k.view(C_SEQ_LEN, H, K),
+        v.view(C_SEQ_LEN, H, K),
         o_slc.view(C_SEQ_LEN, HQ, V),
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py` around lines 161 - 174, The code uses an undefined symbol D when calling native_sparse_attention_varlen; replace D with the correct dimension variable K (extracted earlier as part of B, C_SEQ_LEN, H, K, V, S = *k.shape ...) so the q/k/v/o_slc reshapes use K instead of D; update the calls/reshape expressions around native_sparse_attention_varlen (references: q.view(..., D) k.view(..., D) v.view(..., D) o_slc.view(..., V)) to use K for the embedding dimension.

examples/deepseek_mhc/example_mhc_post.py (1)
16-45: ⚠️ Potential issue | 🟠 Major

Allocation and loop bounds should derive from tensor-extracted dimensions, not unconstrained scalar parameters.

Tensor shapes are inferred via `T.const()` (lines 18–22), giving symbolic dimensions (hc_a, hc_c, h, h_d, h_x). However, memory allocations at lines 32–41 and loop bounds at line 45 use scalar function parameters (hc, hidden) instead. Without explicit guards, if the caller passes scalar values differing from actual tensor shapes, allocations will mismatch tensor extents, and the line 46 indexing into tensor `b[i_n, 0, i0_h * h_blk]` (where `b` has shape [n, hc_c, h]) can exceed bounds.

Align all allocation and loop-bound references to use the extracted symbolic dimensions (prefer `h` over scalar `hidden` at line 45; anchor `hc`-based allocations to a specific tensor dimension).

Suggested direction

```diff
-    for i0_h in T.Pipelined(T.ceildiv(hidden, h_blk), num_stages=2):
+    for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2):
```

```diff
 def mhc_post(
     x: torch.Tensor,
     residual: torch.Tensor,
     post_layer_mix: torch.Tensor,
     comb_res_mix: torch.Tensor,
 ) -> torch.Tensor:
     out = torch.empty_like(residual)
+    hc, hidden = residual.shape[1], residual.shape[2]
-    mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, residual.shape[-2], residual.shape[-1])
+    mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, hc, hidden)
     return out
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_mhc/example_mhc_post.py` around lines 16 - 45, The kernel currently uses scalar parameters hc and hidden for allocations and loop bounds, which can mismatch the tensor-inferred symbolic dims (hc_a, hc_b, hc_c, h, h_d, h_x) and cause out-of-bounds accesses; update mhc_post_tilelang to derive sizes from the T.const() symbols: replace uses of hc with the appropriate per-tensor symbols (e.g., a_local should be (hc_a, hc_b), c_local size hc_c, x_shared/x_local should use hc_a and the h_x extent, b_shared/b_local should use hc_c and h, d_shared/d_local should use h_d), compute h_blk = math.gcd(h, h_blk) and change the loop bound T.ceildiv(hidden, h_blk) to T.ceildiv(h, h_blk) so all allocations and tiling iterate over the actual tensor dimensions (use hc_a/hc_b/hc_c/h/h_d/h_x where relevant).

examples/cast/example_group_per_split_token_cast_to_fp8.py (1)
204-211: ⚠️ Potential issue | 🟠 Major

Bug: dtype comparisons use strings but `dtype` is a TileLang type.

The global `dtype = T.bfloat16` (line 8) is a TileLang type object, not a string. The comparisons `dtype == "float"`, `dtype == "float16"`, and `dtype == "bfloat16"` will always be `False`, causing all valid dtypes to raise `ValueError`.

Compare with `main()` (lines 160-167), which correctly uses `dtype == T.float`, `dtype == T.float16`, and `dtype == T.bfloat16`.

🐛 Proposed fix

```diff
 def run_regression_perf(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
     if batch_sizes is None:
         batch_sizes = [2048, 6144]
-    if dtype == "float":
+    if dtype == T.float:
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
-    elif dtype == "float16":
+    elif dtype == T.float16:
         x = torch.randn(M, N, device="cuda", dtype=torch.float16)
-    elif dtype == "bfloat16":
+    elif dtype == T.bfloat16:
         x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/cast/example_group_per_split_token_cast_to_fp8.py` around lines 204 - 211, The dtype comparisons are using string literals but the module-level dtype is a TileLang type (e.g., T.bfloat16); change the conditional in the block that creates x (the branches that set x = torch.randn(...)) to compare against TileLang constants (dtype == T.float, dtype == T.float16, dtype == T.bfloat16) and map those to the correct torch dtypes (torch.float32, torch.float16, torch.bfloat16) so the right branch runs instead of raising the ValueError; update the conditionals and keep the same torch.randn calls for each matching TileLang dtype.

examples/dsa_sparse_finetune/sparse_mla_bwd.py (1)
246-288: ⚠️ Potential issue | 🟠 Major

`return_kernel` is now a no-op.

The wrapper still exposes `return_kernel`, but this branch disappeared in the refactor and the function always launches preprocess/bwd/postprocess eagerly. That is a silent API regression for callers that previously used `return_kernel=True` to stage a reusable kernel; either restore the branch or remove the parameter.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/dsa_sparse_finetune/sparse_mla_bwd.py` around lines 246 - 288, The wrapper sparse_mla_bwd exposes return_kernel but no longer respects it; restore the original behavior by adding a branch that, when return_kernel=True, returns the staged kernel objects (the results/closures produced by preprocess/bwd/postprocess) instead of eagerly running them, or remove the return_kernel parameter from sparse_mla_bwd and all callers to avoid the silent API regression; locate the symbols preprocess, bwd, postprocess and the return_kernel parameter in sparse_mla_bwd to implement the fix (either reintroduce the conditional that returns kernels when return_kernel is True, or delete the parameter and update callers).

examples/deepseek_v32/sparse_mla_bwd.py (1)
231-255: ⚠️ Potential issue | 🟠 Major

`return_kernel` is no longer honored here.

The wrapper still accepts `return_kernel`, but it now always executes the eager preprocess/bwd/postprocess path and never returns a reusable kernel handle. Keeping the flag while ignoring it is a silent API break; either restore the old behavior or drop the parameter.
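The staged-vs-eager contract both comments want preserved can be sketched in plain Python (hypothetical names and arithmetic stand-ins; the real wrapper stages TileLang kernels, not Python closures):

```python
def preprocess(x):
    return x * 2          # stand-in for the Delta/preprocess stage

def bwd(x, delta):
    return x + delta      # stand-in for the main backward kernel

def postprocess(grad):
    return grad - 1       # stand-in for the dq/postprocess stage

def sparse_mla_bwd_like(x, return_kernel=False):
    """Toy wrapper: honor return_kernel instead of silently ignoring it."""
    def kernel(inp):
        delta = preprocess(inp)
        return postprocess(bwd(inp, delta))
    if return_kernel:
        return kernel     # staged: hand back a reusable callable
    return kernel(x)      # eager: launch the chain immediately
```

The eager default matches the refactor; the `return_kernel=True` branch is what disappeared.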
🟠 Major comments (13)
examples/deepseek_v32/fp8_lighting_indexer.py-186-190 (1)
186-190: ⚠️ Potential issue | 🟠 Major

Missing bounds check causes potential out-of-bounds write.

When `seq_len_kv` is not a multiple of `block_K` (4096), the final iteration writes to indices beyond the tensor bounds. For example, if `seq_len_kv=5000`:

- `T.ceildiv(5000, 4096) = 2` loop iterations
- Iteration 1 computes `idx` values from 4096 to 8191
- Writes to `Logits[bx, idx]` for `idx >= 5000` are out-of-bounds

Add a bounds check before the write.

Proposed fix

```diff
 for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)):
     for k_i in T.serial(block_K // threads):
         idx = n_i * block_K + k_i * threads + tx
-        if idx < cu_k_s or idx >= cu_k_e:
-            Logits[bx, idx] = -T.infinity(T.float32)
+        if idx < seq_len_kv:
+            if idx < cu_k_s or idx >= cu_k_e:
+                Logits[bx, idx] = -T.infinity(T.float32)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_v32/fp8_lighting_indexer.py` around lines 186 - 190, The loop that writes Logits[bx, idx] can write past seq_len_kv when seq_len_kv is not a multiple of block_K; update the inner loop (the block using T.Pipelined / T.ceildiv with variables block_K, threads, tx computing idx) to include an additional bounds check such as "if idx < seq_len_kv and idx >= 0 and idx >= cu_k_s and idx < cu_k_e" (or simply add "and idx < seq_len_kv" to the existing cu_k_s/cu_k_e condition) before performing Logits[bx, idx] = -T.infinity(T.float32) so no out-of-bounds writes occur.

examples/attention_sink/example_gqa_sink_fwd_varlen.py-101-123 (1)
101-123: ⚠️ Potential issue | 🟠 Major

Clamp empty KV windows before entering `T.Pipelined`.

Line 119 can go negative on batches shorter than `max_seqlen_q` when `window_size` is set, because the extra `bx` tiles may compute `start > end`. That makes the pipelined K loop invalid for ragged inputs.

🩹 Minimal fix

```diff
-        loop_range = end - start
+        loop_range = T.max(0, end - start)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/attention_sink/example_gqa_sink_fwd_varlen.py` around lines 101 - 123, The pipelined K-loop can receive a negative range when start > end (ragged inputs with window_size), so clamp the window before calling T.Pipelined by computing loop_range = max(0, end - start) and only enter the T.Pipelined loop when loop_range > 0; keep the existing start/end usage for computing actual_k and the K_unpad slice (symbols: start, end, loop_range, T.Pipelined, actual_k, K_unpad).

examples/deepseek_v32/sparse_mla_fwd_seesaw.py-91-98 (1)
91-98: ⚠️ Potential issue | 🟠 Major

Initialize `lse` for the queries the grid skips.

When `CP0` is true, the x-grid intentionally skips the first `kv_stride - 1` queries, but the post-fixup only repairs `out`. The returned `lse` slice for that prefix is never written, so callers get garbage values whenever `kv_stride > 1`.

Also applies to: 505-508
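A plain-Python toy of this failure mode (hypothetical values; the real buffer is a CUDA tensor, and the neutral value should match what the kernel already uses for masked rows):

```python
import math

def fwd_lse(seq_len, kv_stride, fix_prefix=True):
    """The grid only covers rows >= kv_stride - 1; the skipped prefix must be
    initialized explicitly or callers read whatever was in the buffer."""
    p = max(0, kv_stride - 1)
    lse = [None] * seq_len            # stand-in for an uninitialized buffer
    for i in range(p, seq_len):       # rows the x-grid actually visits
        lse[i] = float(i)             # stand-in for the real log-sum-exp
    if fix_prefix:
        for i in range(p):            # the fix: write the skipped prefix
            lse[i] = -math.inf
    return lse
```

Without the fix, `fwd_lse(4, 3, fix_prefix=False)` leaves the first two slots unwritten, which is exactly the garbage-prefix symptom described above.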
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_v32/sparse_mla_fwd_seesaw.py` around lines 91 - 98, The kernel skips the first (kv_stride - 1) queries when CP0 is true but never initializes the corresponding lse entries, leaving garbage returned to callers; inside the T.Kernel that creates (bx,by,bz) (the CP0 path), set the prefix length p = max(0, kv_stride - 1) (or compute p = max(0, seq_len - kv_stride + 1) depending on the existing grid math) and explicitly initialize lse[:p] to the correct neutral/underflow value used elsewhere (e.g., -inf or the same filled value used when masking attention) before any writes to out/other buffers, and apply the same initialization fix to the other analogous kernel instance flagged in the comment (the later T.Kernel usage around the second occurrence).

examples/deepseek_mla/example_mla_decode.py-245-255 (1)
245-255: ⚠️ Potential issue | 🟠 Major

Both correctness checks still use the stale six-argument reference.

Lines 255 and 290 validate the eager `(Q, Q_pe, KV, K_pe)` kernel against `ref_program(q, q_pe, kv, k_pe, glse, Output_partial)`. Either slim `ref_program` down to four inputs or pass a 4-arg wrapper at both call sites, otherwise `main()` and `run_regression_perf()` stay out of sync with the refactor.

Also applies to: 280-290
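The 4-arg wrapper option can be sketched like this (hypothetical stand-in body; only the signatures mirror the example):

```python
def ref_program(q, q_pe, kv, k_pe, glse, output_partial):
    # stand-in for the legacy six-argument reference implementation;
    # the last two inputs are scratch buffers the eager kernel no longer takes
    return q + q_pe + kv + k_pe

def ref_program_4(q, q_pe, kv, k_pe):
    """Shim matching the eager (Q, Q_pe, KV, K_pe) entrypoint: supply the
    now-unused scratch arguments so the old reference keeps working."""
    return ref_program(q, q_pe, kv, k_pe, glse=None, output_partial=None)
```

Passing `ref_program_4` at both `assert_allclose` call sites realigns validation without rewriting the reference itself.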
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_mla/example_mla_decode.py` around lines 245 - 255, The correctness checks call profiler.assert_allclose(ref_program, ...) but ref_program still expects six args (q, q_pe, kv, k_pe, glse, Output_partial); make the calls and reference consistent by either slimming ref_program to accept only the four inputs used by the eager kernel (Q, Q_pe, KV, K_pe) or add a thin 4-arg wrapper (e.g., ref_program4) that forwards those four arguments to the existing ref_program with suitable defaults for glse and Output_partial, and then use that 4-arg function at both profiler.assert_allclose call sites (the checks invoked from main() and run_regression_perf()) so the kernel and reference signatures match.

examples/deepseek_mla/example_mla_decode_ws.py-569-579 (1)
569-579: ⚠️ Potential issue | 🟠 Major

Keep the reference callable aligned with the compiled entrypoint.

Line 579 still hands `assert_allclose` the pre-refactor six-argument `ref_program`, even though the compiled kernel now exposes only `(Q, Q_pe, KV, K_pe)`. The validation step stays stale until those unused parameters are removed or shimmed away here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_mla/example_mla_decode_ws.py` around lines 569 - 579, The reference callable passed to profiler.assert_allclose is still the old six-argument ref_program but the compiled entrypoint now exposes only (Q, Q_pe, KV, K_pe); update the validation to use a reference callable that matches the compiled signature by either changing ref_program to accept only (Q, Q_pe, KV, K_pe) or provide a small shim/wrapper (e.g., lambda or def ref_shim(Q, Q_pe, KV, K_pe): return ref_program(Q, Q_pe, KV, K_pe, <shim_K>, <shim_V>)) and pass that shim to profiler.assert_allclose so the callable signatures align with the compiled entrypoint.

examples/deepseek_mla/example_mla_decode_persistent.py-208-219 (1)
208-219: ⚠️ Potential issue | 🟠 Major

Update `ref_program` to the new eager signature.

Line 219 now validates a kernel whose public tensor inputs are just `(Q, Q_pe, KV, K_pe)`, but `ref_program` still declares the removed `glse` and `Output_partial` parameters. Either drop those params from `ref_program` or wrap it in a 4-arg shim before calling `assert_allclose`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_mla/example_mla_decode_persistent.py` around lines 208 - 219, ref_program's signature still expects the removed glse and Output_partial parameters but profiler.assert_allclose (via compiled.get_profiler()) now supplies only (Q, Q_pe, KV, K_pe); update ref_program to match by either removing the obsolete glse and Output_partial parameters from ref_program's declaration or create a 4-argument shim that accepts (Q, Q_pe, KV, K_pe) and calls the old ref_program supplying appropriate default/dummy values for glse and Output_partial before passing to profiler.assert_allclose; target the ref_program symbol (or place the shim right before the call to profiler.assert_allclose) so the profiler and ref_program signatures match.

examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py-276-278 (1)
276-278: ⚠️ Potential issue | 🟠 Major

Fix the reference call to match `_get_inputs()`.

`profiler._get_inputs()` now reflects the eager compiled entrypoint, but line 278 still splats those tensors into a six-argument `ref_program`. Call a 4-input reference here, or update `ref_program`, before this correctness path can work reliably again.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py` around lines 276 - 278, The correctness check is calling ref_program with six args while profiler._get_inputs() now returns the eager compiled entrypoint inputs; update the call to match the 4-input reference signature by passing only the first four tensors from profiler._get_inputs() (use profiler._get_inputs() -> input_tensors and call ref_program with those four), or alternatively update ref_program to accept the full set if that is intended; adjust either the ref_program signature or the caller at the site where compiled(*input_tensors) and ref_program(...) are invoked so both use the same 4-input contract (referencing profiler._get_inputs(), compiled, and ref_program).

examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py-269-275 (1)
269-275: ⚠️ Potential issue | 🟠 Major

`--autotune` no longer changes what gets benchmarked.

Lines 274-275 unconditionally compile the fixed `(BLOCK_N, BLOCK_H, num_split, threads)` config after the autotune branch, so the later profiler path ignores the autotuned result. With `--autotune`, the script now does extra work and still reports the hard-coded kernel.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py` around lines 269 - 275, The profiling currently always compiles the hard-coded (BLOCK_N, BLOCK_H, num_split, threads) config, ignoring the autotuned run; change the compile call to match the invocation: if enable_autotune is true, call flashmla_decode.compile(Q, Q_pe, KV, K_pe) (no fixed params) so the compiled kernel matches the autotuned result and avoid the extra fixed-params compile, otherwise compile with the explicit block_N, block_H, num_split, threads arguments as before; ensure the profiler is taken from that conditionally-selected compiled object (use compiled.get_profiler()).

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py-388-397 (1)
388-397: ⚠️ Potential issue | 🟠 Major

This perf harness is timing setup and an empty workload.

The timed closure recreates `SparseFlashAttn`, regenerates Q/K/V, and leaves `block_indices_tensor` entirely -1. Even with the combine guard fixed, that still measures setup cost and a no-op path instead of sparse decode throughput.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py` around lines 388 - 397, The timed harness in run_kernel_only recreates SparseFlashAttn and regenerates Q/K/V each run and passes an all -1 block_indices_tensor, so you end up timing setup and a no-op path; to fix, move creation of SparseFlashAttn, Q, K, V, cache_seqlens and block_indices_tensor outside the timed closure so they are allocated once, populate block_indices_tensor with realistic non -1 block indices (or a mix representative of sparse_ratio) for each batch/head_kv instead of leaving it all -1, and then call sparse_kernel(...) repeatedly inside the timed loop to measure real sparse decode throughput (refer to run_kernel_only, SparseFlashAttn, block_indices_tensor, cache_seqlens, Q/K/V).

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py-410-419 (1)
410-419: ⚠️ Potential issue | 🟠 Major

Move setup out of the timed closure.

`run_kernel_only()` recreates the module, allocates/fills every CUDA tensor, and benchmarks an all-false mask on each sample. That makes the reported number mostly reflect setup work and the empty-path behavior rather than sparse decode performance.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py` around lines 410 - 419, The timed closure run_kernel_only is recreating the SparseFlashAttn model and allocating/filling CUDA tensors (Q, K, V, cache_seqlens, block_mask_tensor, num_blocks) on each invocation, skewing the benchmark; move the setup out of run_kernel_only so the model, dtype, Q, K, V, cache_seqlens, num_blocks and block_mask_tensor are created once before timing and run_kernel_only only calls model(Q, K, V, block_mask_tensor, cache_seqlens) (and if you need per-iteration variability, update only the minimal tensors inside the timed loop).

examples/convolution/example_convolution_autotune.py-77-95 (1)
77-95: ⚠️ Potential issue | 🟠 Major

`dtype`/`accum_dtype` are public knobs, but the kernel hard-codes fp16/fp32.

Lines 94-95 bind the inputs with the caller-provided types, then lines 100-101 overwrite those variables before the shared/output tensors are created. A non-default call will therefore mix caller-typed inputs with fp16/fp32 internals. Either honor the parameters end-to-end or remove them from the signature.
Also applies to: 100-101
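The shadowing pattern is easy to reproduce in isolation (toy Python; the real code overwrites TileLang dtype objects rather than strings):

```python
def conv_buggy(accum_dtype="float64"):
    accum_dtype = "float32"   # hard-coded override: the caller's knob is dead
    return accum_dtype

def conv_fixed(accum_dtype="float64"):
    return accum_dtype        # honor the parameter end-to-end
```

Whatever the caller passes, the buggy version silently ignores it, which is the mixed-internals hazard flagged above.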
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/convolution/example_convolution_autotune.py` around lines 77 - 95, The convolution function currently binds the input tensors to the caller-provided dtype/accum_dtype but then overwrites those types with hard-coded fp16/fp32 (reassignments of dtype and accum_dtype), causing mixed-type internals; update the function so it either removes dtype/accum_dtype from the signature or — preferably — stop overriding the passed-in dtype and accum_dtype and use the provided dtype/accum_dtype when declaring shared/output tensors and any internal casts (look for the places where dtype and accum_dtype are reassigned to T.float16/T.float32 and remove those reassignments so that data/kernel_weight and tensors created later respect the original dtype and accum_dtype parameters).

examples/blocksparse_gemm/example_blocksparse_gemm.py-125-132 (1)
125-132: ⚠️ Potential issue | 🟠 Major

Generate `BlockMask` with the same tiling the kernel is using.

Lines 125-127, 144-145, and 180-181 size the mask with `DEFAULT_BLOCK_*` and floor division, while the kernel indexes `BlockMask[by, bx, k]` using the active `block_M/N/K` and a `ceildiv` grid. Any autotune candidate that changes tile sizes, or any input with remainder tiles, will misalign the mask and can read the wrong sparsity decisions. Please either rebuild the mask per config, or keep tile sizes fixed and size the tensor with ceiling division/assert divisibility.

Also applies to: 143-158, 179-193
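The mismatch comes down to floor versus ceiling division when sizing the mask (plain-Python illustration with example sizes):

```python
def ceildiv(a, b):
    return -(-a // b)

M, block_M = 1000, 128
mask_rows = M // block_M          # floor: 7 rows allocated for the mask
grid_rows = ceildiv(M, block_M)   # ceil: 8 tiles launched by the kernel
# The kernel's last tile indexes mask row 7, one past the allocated mask.
```

Sizing the mask with `ceildiv` (or asserting `M % block_M == 0`) removes the gap.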
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/blocksparse_gemm/example_blocksparse_gemm.py` around lines 125 - 132, The mask is built using DEFAULT_BLOCK_M/N/K and floor division (mask_shape = (M // block_M, N // block_N, K // block_K)) but the kernel and autotuner use the active tile sizes and ceildiv when indexing BlockMask[by, bx, k], causing misalignment for different tile candidates or remainder tiles; fix by generating/rebuilding block_mask using the kernel's actual tile sizes (use ceildiv(M, tile_M), ceildiv(N, tile_N), ceildiv(K, tile_K)) taken from the compiled kernel or best_config (e.g., after blocksparse_matmul.compile/getting best_config) or alternatively enforce fixed tiles and size the mask with ceiling division or assert exact divisibility so BlockMask indexing aligns with the kernel's block_M/block_N/block_K.

examples/convolution/example_convolution_autotune.py-92-99 (1)
92-99: ⚠️ Potential issue | 🟠 Major

Guard rectangular filters or compute `OW` from `KW`.

Lines 97-99 collapse the filter shape to `K = KH` and use that single extent for both output dimensions. After the tensor-first API change, `kernel_weight` can be rectangular, so `KH != KW` will produce the wrong width/output indexing.

💡 One safe option

```diff
-    K = KH
-    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
-    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
+    assert KH == KW, "Only square kernels are supported"
+    OH = (H + 2 * P - D * (KH - 1) - 1) // S + 1
+    OW = (W + 2 * P - D * (KW - 1) - 1) // S + 1
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/convolution/example_convolution_autotune.py` around lines 92 - 99, The code incorrectly collapses the kernel shape by setting K = KH and using it for both spatial dims, which breaks rectangular kernels; change the logic to keep distinct kernel heights and widths (use KH and KW separately), compute OH using KH (or K_h) and OW using KW (or K_w), and update any downstream uses of K to use the appropriate per-dimension symbol (e.g., replace references to K when computing OW/indexing with KW or K_w); ensure kernel_weight is treated as [KH, KW, C, F] when computing output sizes and indexing.
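For reference, the standard per-dimension output-extent formula, computed separately for height and width (plain Python, hypothetical helper name):

```python
def conv_out_size(in_size, k, stride=1, pad=0, dil=1):
    """Output extent of a convolution along one spatial dimension:
    (in + 2*pad - dil*(k - 1) - 1) // stride + 1"""
    return (in_size + 2 * pad - dil * (k - 1) - 1) // stride + 1

# Rectangular 3x5 filter on a 32x32 input with padding 1:
OH = conv_out_size(32, 3, pad=1)  # height uses KH = 3
OW = conv_out_size(32, 5, pad=1)  # width uses KW = 5; collapsing to KH gives the wrong OW
```

This is why `OW` must be derived from `KW`: with `K = KH` the width extent silently comes out as if the filter were square.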
🟡 Minor comments (9)
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py-517-518 (1)
517-518: ⚠️ Potential issue | 🟡 Minor

Fix `acc_s_cast` allocation to use fragment instead of shared memory.

Line 518 allocates `acc_s_cast` with `T.alloc_shared`, but all other NSA examples (example_tilelang_nsa_fwd.py, example_tilelang_nsa_fwd_varlen.py, example_tilelang_nsa_decode.py, example_tilelang_nsa_bwd.py) use `T.alloc_fragment`. For consistency and performance (fragments stay in registers), change to `T.alloc_fragment([G, BS], dtype)`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py` around lines 517 - 518, The allocation for acc_s_cast is using T.alloc_shared but should use a fragment like the other NSA examples; update the acc_s_cast allocation to T.alloc_fragment([G, BS], dtype) (matching acc_s's shape [G, BS] and using dtype) so acc_s_cast is allocated as a fragment (register-backed) instead of shared memory for consistency and performance.

examples/attention_sink/example_mha_sink_bwd_bhsd.py-135-139 (1)
135-139: ⚠️ Potential issue | 🟡 Minor

Rename the new `O` identifiers before lint fails.

Lines 135-139 and line 490 use `O`, which Ruff flags as ambiguous (E741). Renaming those new identifiers to `out` or `output` should clear the lint regression.

Also applies to: 490-490
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/attention_sink/example_mha_sink_bwd_bhsd.py` around lines 135 - 139, The identifier O in flashattn_bwd_preprocess should be renamed to a non-ambiguous name (e.g., out or output) to avoid Ruff E741; update the function signature and its typed annotation T.Tensor[[batch, heads, seq_len, dim], T.float16] to use the new name and change all internal references to O accordingly, and apply the same renaming for the other occurrence at the site referenced around line 490 so both locations consistently use the new identifier (also verify any related uses like dO remain correct).

examples/flash_attention/example_mha_bwd_bhsd.py-88-91 (1)
88-91: ⚠️ Potential issue | 🟡 Minor

Rename `O` before Ruff fails this helper.

Lines 88-91 introduce `O`, which Ruff flags as ambiguous (E741). Renaming it to `out` keeps the new eager-style preprocess helper lint-clean.

🩹 Minimal fix

```diff
-def flashattn_bwd_preprocess(O, dO):
+def flashattn_bwd_preprocess(out, dO):
     batch, heads, seq_len, dim = T.const("batch heads seq_len dim")
-    O: T.Tensor[[batch, heads, seq_len, dim], T.float16]
+    out: T.Tensor[[batch, heads, seq_len, dim], T.float16]
     dO: T.Tensor[[batch, heads, seq_len, dim], T.float16]
     Delta = T.empty([batch, heads, seq_len], T.float32)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention/example_mha_bwd_bhsd.py` around lines 88 - 91, The helper function flashattn_bwd_preprocess declares a parameter named O which triggers Ruff E741; rename the parameter O to out (and update the corresponding type annotation and any internal uses) so the signature becomes flashattn_bwd_preprocess(out, dO) and the annotation uses out: T.Tensor[[batch, heads, seq_len, dim], T.float16]; keep dO unchanged and ensure all references inside the function and any callers are updated to use out instead of O.

examples/amd/example_amd_flash_attn_bwd.py-238-245 (1)
238-245: ⚠️ Potential issue | 🟡 Minor

Rename `O` before Ruff fails this helper.

Line 238 introduces `O`, which Ruff flags as ambiguous (E741). Renaming the parameter and its annotation to `out` keeps the new eager helper lint-clean.

🩹 Minimal fix

```diff
-@tilelang.jit
-def flashattn_bwd_preprocess(O, dO, batch: int = 1, heads: int = 8, seq_len: int = 1024, dim: int = 64):
+@tilelang.jit
+def flashattn_bwd_preprocess(out, dO, batch: int = 1, heads: int = 8, seq_len: int = 1024, dim: int = 64):
     batch, heads, seq_len, dim = T.const("batch heads seq_len dim")
     dtype = T.float16
     accum_dtype = T.float32
     blk = 32
-    O: T.Tensor[[batch, seq_len, heads, dim], dtype]
+    out: T.Tensor[[batch, seq_len, heads, dim], dtype]
     dO: T.Tensor[[batch, seq_len, heads, dim], dtype]
     Delta = T.empty([batch, heads, seq_len], accum_dtype)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/amd/example_amd_flash_attn_bwd.py` around lines 238 - 245, The helper function flashattn_bwd_preprocess uses the single-letter parameter name O which triggers Ruff E741; rename the parameter and its annotated binding from O to a clearer name (e.g., out) and update all references and the type annotation (change "O" to "out" in the parameter list and the "O: T.Tensor[[...], dtype]" line) so the function signature and internal annotations remain consistent and lint-clean while preserving dO and other symbols.

examples/flash_attention/example_gqa_bwd_tma_reduce.py-350-353 (1)
350-353: ⚠️ Potential issue | 🟡 Minor

Inconsistent copy pattern for `dk` vs `dv`.

Lines 350-351 copy `dv` to `dv_shared` then `dv_shared` to `dV`, but lines 352-353 copy `dk` to `dk_shared` then copy `dk` (not `dk_shared`) directly to `dK`. This appears inconsistent and may be a copy-paste error.

🔧 Suggested fix for consistency

```diff
         T.copy(dv, dv_shared)
         T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
         T.copy(dk, dk_shared)
-        T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
+        T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
```
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention/example_gqa_bwd_tma_reduce.py` around lines 350 - 353, The copy sequence is inconsistent: after copying dk into dk_shared you then copy dk (not dk_shared) into dK; change the second copy to use dk_shared so it mirrors the dv -> dv_shared -> dV pattern. Update the T.copy that currently uses dk as source to instead use dk_shared for writing into dK (the statements involving dk, dk_shared, and dK with indices bx, groups, bz, by, block_M).

examples/flash_attention/example_gqa_bwd.py (1)
340-343: ⚠️ Potential issue | 🟡 Minor

Same inconsistent copy pattern for `dk` as in `example_gqa_bwd_tma_reduce.py`.

Lines 340-341 copy `dv` to `dv_shared` then `dv_shared` to `dV`, but lines 342-343 copy `dk` to `dk_shared` then copy `dk` (not `dk_shared`) directly to `dK`. This is the same inconsistency noted in the TMA reduce variant.

🔧 Suggested fix for consistency

```diff
         T.copy(dv, dv_shared)
         T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
         T.copy(dk, dk_shared)
-        T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
+        T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
```
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention/example_gqa_bwd.py` around lines 340 - 343, The copy for dk is inconsistent: after copying dk into dk_shared you must copy dk_shared into dK (just like dv -> dv_shared -> dV). Update the second dk copy so that T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) is used instead of copying dk directly; this keeps the pattern for dk consistent with dv and matches the behavior in functions using dv/dv_shared and dV.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
69-76: ⚠️ Potential issue | 🟡 Minor

Fragment allocation size mismatch for `sinks`.

The `sinks` fragment is allocated with size `[heads]` but the parallel loop at lines 75-76 iterates over `block_M` elements. Since you're copying a single value `Sinks[by]` to all `block_M` positions, the fragment should be allocated as `[block_M]` to match the loop bounds and stay consistent with other position-level fragments.

🔧 Suggested fix

```diff
-        sinks = T.alloc_fragment([heads], dtype)
+        sinks = T.alloc_fragment([block_M], dtype)
```
Verify each finding against the current code and only fix it if needed. In `@examples/attention_sink/example_gqa_sink_bwd_bhsd.py` around lines 69 - 76, The sinks fragment is allocated with the wrong size—T.alloc_fragment([heads], dtype) but the loop for i in T.Parallel(block_M) writes sinks[i]; change the allocation to match the loop by allocating sinks with size [block_M] (i.e., use T.alloc_fragment([block_M], dtype)) so that sinks, the T.alloc_fragment call, and the parallel loop filling sinks from Sinks[by] are consistent.

examples/deepseek_mla/example_mla_decode_paged.py (1)
26-26: ⚠️ Potential issue | 🟡 Minor

Make the `softmax_scale` annotation match its `None` default.

Line 26 advertises `softmax_scale` as a plain `float`, but the function immediately treats `None` as a valid value. Update the annotation to `float | None` or `Optional[float]`, or drop it if `tilelang.jit` requires the scalar annotation to remain concrete.
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_mla/example_mla_decode_paged.py` at line 26, The softmax_scale parameter is annotated as float but defaults to None; update the function signature for the parameter softmax_scale in example_mla_decode_paged (the function taking softmax_scale: float = None) to use a nullable type (e.g., softmax_scale: float | None or softmax_scale: Optional[float]) and add the necessary typing import if you choose Optional; alternatively, if tilelang.jit requires a concrete scalar annotation, remove the default None or adjust the call sites to pass a concrete float so the annotation and default are consistent.

examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1)
219-223: ⚠️ Potential issue | 🟡 Minor

Import `autotune` explicitly instead of using a star import.

Line 219 uses the `@autotune` decorator, but the file relies on `from tilelang.autotuner import *` at line 3. This causes a Ruff F405 undefined name warning. Since only `autotune` is used from this import (the `get_configs()` function is defined locally), explicit import is preferred.

🧹 Minimal cleanup

```diff
-from tilelang.autotuner import *
+from tilelang.autotuner import autotune
```
Verify each finding against the current code and only fix it if needed. In `@examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py` around lines 219 - 223, Replace the star import from tilelang.autotuner with an explicit import for autotune (or add an explicit "from tilelang.autotuner import autotune") so the `@autotune` decorator used on the kernel function is defined; update the import near the top of the file where tilelang.autotuner is currently imported with a wildcard and leave get_configs() as-is since it is defined locally.
🧹 Nitpick comments (14)
examples/deepseek_mhc/example_mhc_bwd.py (1)
69-76: Unused constants `n_s` and `n_s2` create potential confusion.

The constants `n_s` and `n_s2` extracted via `T.const()` are only used in the tensor annotations (lines 70-71) but never referenced in the kernel logic. All allocations and loop bounds use `n_stream` instead:

- Line 73: `tensor_shape = [seqlen, n_stream, n_stream]`
- Line 100: `T.alloc_fragment([tilesize, n_stream, n_stream], ...)`
- Line 80: `T.Parallel(tilesize, n_stream, n_stream)`

If the input tensors' actual shape dimensions differ from `n_stream`, there could be a silent shape mismatch between what's annotated versus what's allocated. Consider either:

- Using `n_s`/`n_s2` consistently throughout the kernel, or
- Removing them if `n_stream` is the authoritative compile-time shape

♻️ Option 1: Use extracted constants consistently

```diff
     seqlen, n_s, n_s2 = T.const("seqlen n_s n_s2")
     out: T.Tensor[[seqlen, n_s, n_s2], T.float32]
     dout: T.Tensor[[seqlen, n_s, n_s2], T.float32]
-    tensor_shape = [seqlen, n_stream, n_stream]
+    tensor_shape = [seqlen, n_s, n_s2]
     dtype = T.float32
     res = T.empty(tensor_shape, dtype)
```

Then update all `n_stream` references in macros and allocations to `n_s`/`n_s2`.

♻️ Option 2: Simplify by removing unused constants

```diff
-    seqlen, n_s, n_s2 = T.const("seqlen n_s n_s2")
-    out: T.Tensor[[seqlen, n_s, n_s2], T.float32]
-    dout: T.Tensor[[seqlen, n_s, n_s2], T.float32]
+    seqlen = T.const("seqlen")
+    out: T.Tensor[[seqlen, n_stream, n_stream], T.float32]
+    dout: T.Tensor[[seqlen, n_stream, n_stream], T.float32]
```
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_mhc/example_mhc_bwd.py` around lines 69 - 76, The code defines compile-time constants n_s and n_s2 via T.const but never uses them in kernel logic (allocations and loops use n_stream), causing silent shape mismatch risk; either (A) replace n_stream with n_s/n_s2 across the kernel (update tensor_shape, T.alloc_fragment calls, and T.Parallel bounds) so annotations and allocations match the T.const annotations, or (B) remove the unused T.const("n_s n_s2") and the n_s/n_s2 symbols (and their annotations in out/dout) and keep n_stream as the single authoritative compile-time size; update the tensor_shape, res allocation, and any T.alloc_fragment / T.Parallel usages accordingly and ensure the out/dout annotations match the chosen shape variables.

examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)
140-140: Unused parameter `in_dtype`.

The `in_dtype` parameter is passed to `assert_tl_gemm_correctness` but is never used within the function. The input dtype is determined by `A_fp8.dtype` implicitly. Consider removing this parameter to avoid confusion.

🧹 Proposed fix

```diff
-def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype):
+def assert_tl_gemm_correctness(M, N, K, block_N, out_dtype, accum_dtype):
```

And update the call sites in `main()` and the `__main__` block accordingly.
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py` at line 140, The function assert_tl_gemm_correctness currently accepts an unused parameter in_dtype; remove in_dtype from its signature and from any internal references, then update all call sites (notably in main() and the __main__ block) to stop passing that argument so the parameter list matches the new signature; search for assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype) and replace calls to assert_tl_gemm_correctness(M, N, K, block_N, out_dtype, accum_dtype) (or reorder to match the updated signature) to ensure types still align with how A_fp8.dtype is used.

examples/deepseek_nsa/example_tilelang_nsa_bwd.py (1)
2-6: Duplicate `torch` import.

`torch` is imported twice (lines 2 and 6).

♻️ Remove duplicate import

```diff
 # ruff: noqa
 import torch
 from typing import Optional, Union
 from packaging.version import parse
-import torch
 import triton
```
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_nsa/example_tilelang_nsa_bwd.py` around lines 2 - 6, The file contains a duplicate import of the torch module; remove the redundant import statement so only a single "import torch" remains at the top of the module (look for the two occurrences of the "import torch" statement among the top-level imports alongside "from typing import Optional, Union" and "from packaging.version import parse")—delete one of the identical imports to avoid the duplication.

examples/deepseek_v32/inference/kernel.py (1)
147-147: Inconsistent 1D shape tuple syntax.

`(block_M)` evaluates to just `block_M` (an integer), not a single-element tuple. For consistency with other 1D allocations in this file (e.g., line 65: `(blk_m,)`), consider using a trailing comma.

Suggested fix

```diff
-        Scale_C_shared = T.alloc_shared((block_M), FP32)
+        Scale_C_shared = T.alloc_shared((block_M,), FP32)
```
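As a plain-Python illustration of the pitfall (independent of TileLang — `block_M` here is just an integer):

```python
block_M = 128

# Parentheses alone only group an expression; they do not build a tuple.
not_a_tuple = (block_M)
# The trailing comma is what creates a one-element tuple.
shape_1d = (block_M,)

print(type(not_a_tuple).__name__)  # int
print(type(shape_1d).__name__)     # tuple
```

Any shape-taking API that expects a sequence will therefore see a bare scalar when the comma is omitted.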
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_v32/inference/kernel.py` at line 147, The allocation call for Scale_C_shared uses (block_M) which is an int, not a 1-tuple; update the T.alloc_shared invocation for Scale_C_shared to use a single-element tuple syntax (block_M,) to match other 1D allocations (e.g., at blk_m,) and ensure the shape is interpreted as a 1D array by the allocator.

examples/deepseek_mhc/example_mhc_pre.py (1)
35-35: Unused symbolic dimension declaration.

This line creates two symbolic constants named literally `"_"` that are immediately discarded. If the intent was to skip dimensions from a tensor shape, there's no corresponding tensor type annotation that uses `"_ _"`. This appears to be dead code that can be removed.

🧹 Suggested fix

```diff
     n_splits_d, num_tokens, hc_mult3_d = T.const("n_splits_d num_tokens hc_mult3_d")
-    _, _ = T.const("_ _")
     hc_scale_len = T.const("hc_scale_len")
```
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_mhc/example_mhc_pre.py` at line 35, The line calling T.const("_ _") creates unused symbolic constants and should be removed; delete the expression "_, _ = T.const(\"_ _\")" from example_mhc_pre.py (or replace it with the correct tensor type annotation if you actually intended to declare anonymous dimensions), ensuring no dead code remains and that any needed symbolic dimensions are declared where a tensor shape/type uses them (e.g., in the relevant T.tensor/T.type annotation) instead of as a discarded const call.

examples/deepseek_v32/topk_selector.py (1)
69-72: Scan only the active `[start, end)` window.

Line 69 and Line 97 still iterate over `ceildiv(seq_len, BLOCK_SIZE)`, so short rows pay O(max_seq_len) work and just branch away most lanes. Basing both stage-1 loops on `l_end_idx - l_start_idx` would make the new varlen path proportional to each row's actual length.

Also applies to: 97-100
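The suggested change amounts to shifting the scan window, which a pure-Python sketch makes concrete (the `BLOCK_SIZE` and row bounds below are illustrative, not the kernel's values):

```python
from math import ceil

BLOCK_SIZE = 4

def blocks_scanned(l_start_idx, l_end_idx, seq_len):
    # Current pattern: always scan ceil(seq_len / BLOCK_SIZE) blocks,
    # regardless of how short the row actually is.
    full = ceil(seq_len / BLOCK_SIZE)
    # Suggested pattern: scan only the row's active [start, end) window;
    # each block index is then offset by l_start_idx inside the kernel.
    active = ceil((l_end_idx - l_start_idx) / BLOCK_SIZE)
    return full, active

full, active = blocks_scanned(l_start_idx=8, l_end_idx=14, seq_len=1024)
print(full, active)  # 256 blocks vs 2 blocks for a short row
```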
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_v32/topk_selector.py` around lines 69 - 72, The loop over T.serial currently uses ceildiv(seq_len, BLOCK_SIZE) causing all rows to iterate to max length; change both stage-1 loops that use T.serial(T.ceildiv(seq_len, BLOCK_SIZE)) so they instead iterate over T.ceildiv(l_end_idx - l_start_idx, BLOCK_SIZE) (i.e., base the iteration count on the active window length), and compute input_idx = s * BLOCK_SIZE + tx + l_start_idx (or otherwise offset by l_start_idx) before the existing bounds checks; update references in the same block that load input (convert_to_uint16(input[bx, input_idx])) so they use the new input_idx/offset logic to make work proportional to each row's actual length.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
135-159: Same ambiguous `O` parameter name as other files.

Consider renaming `O` to `Out` or `Output` to address the E741 static analysis warning, consistent with any changes made in other files.
Verify each finding against the current code and only fix it if needed. In `@examples/attention_sink/example_gqa_sink_bwd_bhsd.py` around lines 135 - 159, Rename the ambiguous parameter O in the function flashattn_bwd_preprocess to a clearer name (e.g., Output or Out) and update every reference inside the function (parameter list, T.copy sources, T.alloc_fragment uses that reference, and any other occurrences like the copy from O to o) to use the new identifier; keep dO and Delta unchanged and ensure the new name matches the convention used in other files to resolve the E741 static analysis warning.

examples/flash_attention/example_gqa_bwd_tma_reduce.py (2)
359-367: Unused shape unpacking variables.

The variables `BATCH`, `N_CTX`, `H`, `D_HEAD_QK` are unpacked but never used. Consider using underscore prefix per static analysis suggestion.

💡 Optional fix

```diff
 @staticmethod
 def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
-    BATCH, N_CTX, H, D_HEAD_QK = q.shape
+    _BATCH, _N_CTX, _H, _D_HEAD_QK = q.shape
     block_M = 128
     block_N = 64
```

Or simply remove the unpacking if not needed:

```diff
 @staticmethod
 def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
-    BATCH, N_CTX, H, D_HEAD_QK = q.shape
     block_M = 128
     block_N = 64
```
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention/example_gqa_bwd_tma_reduce.py` around lines 359 - 367, In the forward function, the shape unpacking line "BATCH, N_CTX, H, D_HEAD_QK = q.shape" creates unused variables; either remove the unpacking entirely or rename them to unused placeholders (e.g., _batch, _n_ctx, _h, _d_head_qk) to satisfy static analysis. Update the forward function (around the call to flashattn_fwd and ctx.save_for_backward) so no unused local names remain while preserving q, k, v, causal, groups and use_atomic behavior.
437-438: Unnecessary dummy tensor allocations in non-atomic path.

When using the split path, `flashattn_bwd_postprocess(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32))` allocates zero tensors for `dK` and `dV` just to satisfy the function signature, then immediately discards them. This wastes GPU memory.

Consider either:

- Overloading `flashattn_bwd_postprocess` to accept only `dQ` for this use case
- Or documenting this is intentional for API consistency
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention/example_gqa_bwd_tma_reduce.py` around lines 437 - 438, The call to flashattn_bwd_postprocess in the split/non-atomic path is creating throwaway zero tensors for dK and dV (torch.zeros_like(k, dtype=torch.float32)) which wastes GPU memory; update flashattn_bwd_postprocess to accept optional dK/dV (e.g., default None) or add an overload that only takes dq, and change the call in this file to pass only dq (remove the torch.zeros_like arguments) so the function avoids allocating zeros when dK/dV are not needed; ensure flashattn_bwd_postprocess checks for None and computes/returns dK/dV appropriately or returns placeholders without allocating memory, and update any other callers to the new signature if necessary.

examples/flash_attention/example_mha_bwd_bshd.py (1)
85-109: Consider renaming `O` parameter to avoid ambiguity.

Static analysis flags `O` as an ambiguous variable name (E741) since it can be confused with `0`. While `O` is conventional in attention literature for "Output", consider renaming to `Out` or `Output` for clarity.

💡 Optional rename suggestion

```diff
-def flashattn_bwd_preprocess(O, dO):
+def flashattn_bwd_preprocess(Out, dO):
     batch, seq_len, heads, dim = T.const("batch seq_len heads dim")
-    O: T.Tensor[[batch, seq_len, heads, dim], T.float16]
+    Out: T.Tensor[[batch, seq_len, heads, dim], T.float16]
     dO: T.Tensor[[batch, seq_len, heads, dim], T.float16]
```
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention/example_mha_bwd_bshd.py` around lines 85 - 109, Rename the ambiguous parameter O in function flashattn_bwd_preprocess to a clearer name (e.g., Out or Output) and update all internal references accordingly (the function signature, uses in T.copy slices and any variable names that shadow it like the local fragment o if you prefer to keep that). Specifically update the parameter O -> Out (and adjust any code that references O[bz, ...] to Out[bz, ...]) while leaving dO and local fragments (o, do, acc, delta, Delta) unchanged; run tests/lint to ensure the E741 warning is resolved.

examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (2)
151-192: Same `O` parameter naming consideration.

This varlen preprocess has the same E741 warning for ambiguous `O` variable name.
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py` around lines 151 - 192, Rename the ambiguous parameter O in flashattn_bwd_preprocess to a clearer name (e.g., output or O_tensor) and update its type annotation and every usage inside the function (all occurrences of O[...] -> output[...]). Ensure the function signature, the typed tensor declaration (previously "O: T.Tensor[...]"), and all reads that index O (used when filling o[i, j]) are updated to the new identifier so there are no remaining references to the old name.
623-624: Same dummy tensor allocation pattern in non-atomic path.

The non-atomic path allocates dummy zero tensors for `dK` and `dV` just to call `flashattn_bwd_postprocess`, then discards the results. This is consistent with the non-varlen version but wastes memory.
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py` around lines 623 - 624, The non-atomic path currently creates fresh dummy zero tensors for dK/dV just to call flashattn_bwd_postprocess(dq, ..., ...), wasting memory; fix this by reusing a single preallocated pair of dummy tensors (e.g., dummy_dk, dummy_dv) allocated once with torch.zeros_like(k, dtype=torch.float32) and torch.zeros_like(v, dtype=torch.float32) and use those in the non-atomic branch when calling flashattn_bwd_postprocess(dq, dummy_dk, dummy_dv), or alternatively update flashattn_bwd_postprocess to accept None for dK/dV and handle that case so you only pass dq (avoid per-iteration allocations); ensure references to dq, dk, dv, k, v and flashattn_bwd_postprocess are used so reviewers can locate the change.

examples/attention_sink/example_gqa_sink_bwd_varlen.py (1)
221-262: Same `O` parameter naming applies here.

The preprocess kernel uses the ambiguous `O` variable name flagged by E741.
Verify each finding against the current code and only fix it if needed. In `@examples/attention_sink/example_gqa_sink_bwd_varlen.py` around lines 221 - 262, Rename the ambiguous tensor parameter O in flashattn_bwd_preprocess to a more descriptive name (e.g., output_tensor or outputs) and update all uses inside the function (alloc/loads from O and the indexing like O[q_start_idx + ...]) to the new name; ensure the function signature, type annotation line "O: T.Tensor[[UQ, heads, dim], T.float16]" and every reference (currently O[...] and the variable passed in when this function is called) are consistently updated to the chosen descriptive identifier to avoid the E741 ambiguous name warning.

examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1)
104-105: Use `torch.randint` instead of `torch.rand` for integer-typed tensors.

Line 104 uses `torch.rand(..., dtype=getattr(torch, in_dtype))`, which is problematic when `in_dtype` is `T.int8`: `torch.rand` samples from a uniform distribution and is meant for floating-point dtypes, and the test exercises this code path with `in_dtype=T.int8`. Using `torch.randint` is clearer and more idiomatic for integer tensor initialization.

💡 Suggested fix

```diff
+    torch_dtype = getattr(torch, in_dtype)
+    if str(in_dtype) == "int8":
+        A = torch.randint(-128, 128, (M, K), device="cuda", dtype=torch_dtype)
+    else:
+        A = torch.rand(M, K, device="cuda", dtype=torch_dtype)
```
Verify each finding against the current code and only fix it if needed. In `@examples/dequantize_gemm/example_dequant_gemm_fine_grained.py` around lines 104 - 105, Replace the use of torch.rand for integer-typed inputs with torch.randint: update the tensor creation for A (currently A = torch.rand(..., dtype=getattr(torch, in_dtype))) to use torch.randint with an appropriate integer range and dtype determined by in_dtype (e.g., for int8 use a signed range like -128 to 127, and for unsigned types use 0..max), so A is created idiomatically as an integer tensor when getattr(torch, in_dtype) is an integer dtype.
```python
        sinks = T.alloc_fragment([heads], dtype)

        T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
        T.fill(acc_o, 0)
        T.fill(logsum, 0)
        T.fill(scores_max, -T.infinity(accum_dtype))
        for i in T.Parallel(block_M):
            sinks[i] = Sinks[by]
```
Allocate `sinks` per row, not per head count.

Line 69 sizes `sinks` as `[heads]`, but Lines 75-76 index it with `i` over `block_M`. With the current defaults, `block_M` is often larger than `heads`, so this writes past the fragment bounds.
🩹 Minimal fix
```diff
-        sinks = T.alloc_fragment([heads], dtype)
+        sinks = T.alloc_fragment([block_M], dtype)
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
        sinks = T.alloc_fragment([block_M], dtype)

        T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
        T.fill(acc_o, 0)
        T.fill(logsum, 0)
        T.fill(scores_max, -T.infinity(accum_dtype))
        for i in T.Parallel(block_M):
            sinks[i] = Sinks[by]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/attention_sink/example_mha_sink_bwd_bhsd.py` around lines 69 - 76, The sinks fragment is allocated with the wrong size: T.alloc_fragment([heads], dtype) but it's indexed by i in the T.Parallel(block_M) loop, causing out-of-bounds when block_M > heads; update the allocation so sinks is sized per row (use block_M as the dimension), i.e., replace the T.alloc_fragment allocation for sinks to allocate length block_M (keeping the same dtype) so the subsequent loop assigning sinks[i] = Sinks[by] is safe.
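The hazard is easy to reproduce with plain Python lists standing in for the fragment (the `heads` and `block_M` values below are illustrative defaults):

```python
heads, block_M = 8, 64
sink_value = 0.5

# Sized like the buggy fragment: one slot per head.
sinks = [0.0] * heads
overflow = False
try:
    for i in range(block_M):          # loop bound taken from the kernel
        sinks[i] = sink_value
except IndexError:                    # write past the end of the buffer
    overflow = True

# Sized per row tile, as the fix suggests: the loop fits exactly.
sinks = [0.0] * block_M
for i in range(block_M):
    sinks[i] = sink_value
```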
```python
        T.copy(BlockSparseMask[bz, by, bx, :], block_mask)

        loop_range = (
            T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
        )

        for k in T.Pipelined(loop_range, num_stages=num_stages):
            if block_mask[k] != 0:
                T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
                else:
                    T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                # To do causal softmax, we need to set the scores_max to 0 if it is -inf
                # This process is called Check_inf in FlashAttention3 code, and it only need to be done
                # in the first ceil_div(kBlockM, kBlockN) steps.
                # for i in T.Parallel(block_M):
                #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, block_N):
                    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                    # max * log_2(e)) This allows the compiler to use the ffma
                    # instruction instead of fadd and fmul separately.
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)

                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]

                T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
```
Restore the running max before rescaling.

`reduce_max(acc_s, scores_max)` now gives only the current block max, but the online softmax update still treats `scores_max` as the max over all processed blocks. If a later sparse block has a smaller row max than an earlier one, prior contributions are rescaled against the wrong baseline and the output drifts.
🩹 Proposed fix
```diff
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
             T.reduce_max(acc_s, scores_max, dim=1, clear=False)
             # To do causal softmax, we need to set the scores_max to 0 if it is -inf
             # This process is called Check_inf in FlashAttention3 code, and it only need to be done
             # in the first ceil_div(kBlockM, kBlockN) steps.
             # for i in T.Parallel(block_M):
             #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
             for i in T.Parallel(block_M):
+                scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
```
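To see why the rescale baseline matters numerically, here is a scalar Python sketch of the online log-sum-exp update with toy inputs (not the kernel's tensors): without folding the previous max into the new one, the rescale factor 2**(m_prev - m_new) can explode when a later block's max is much smaller.

```python
from math import log2

def online_lse(blocks, restore_running_max):
    """Online log-sum-exp (base 2), mimicking the kernel's block rescaling."""
    m = float("-inf")   # running max (scores_max)
    s = 0.0             # running denominator (logsum)
    for block in blocks:
        m_prev = m
        m_block = max(block)                  # reduce_max over this block only
        m = max(m_block, m_prev) if restore_running_max else m_block
        if s > 0.0:
            s *= 2.0 ** (m_prev - m)          # rescale prior contributions
        s += sum(2.0 ** (x - m) for x in block)
    return log2(s) + m

blocks = [[1100.0], [0.0]]                    # later block has a much smaller max
good = online_lse(blocks, restore_running_max=True)   # 1100.0
try:
    online_lse(blocks, restore_running_max=False)
    failed = False
except OverflowError:                         # 2.0 ** 1100 blows up
    failed = True
```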
Verify each finding against the current code and only fix it if needed. In `@examples/blocksparse_attention/example_tilelang_block_sparse_attn.py` around lines 89 - 104, The running max must be restored before rescaling: after T.reduce_max(acc_s, scores_max, dim=1, clear=False) compute the elementwise running max = max(scores_max_prev, scores_max) (update either scores_max or scores_max_prev to hold this combined max) and then use that combined max when computing scores_scale and when subtracting for acc_s; i.e., replace the current use of scores_max_prev and scores_max with the single restored running max (using symbols scores_max_prev, scores_max, scores_scale, acc_s) so prior-block contributions are rescaled against the true running maximum before T.reduce_sum(scores_sum).
```python
        max_split = T.alloc_var(T.int32)

        T.clear(lse_logsum_local)
        T.clear(o_accum_local)
        lse_max_local = -T.infinity(accum_dtype)
        for k in T.serial(num_split):
            lse_local_split = glse[bz, by, k]
            if lse_local_split != 0:
                max_split = k
                lse_max_local = T.max(lse_max_local, glse[bz, by, k])

        for k in T.Pipelined(num_split, num_stages=1):
            if k <= max_split:
                lse_local_split = glse[bz, by, k]
                lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
        lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
        for k in T.serial(num_split):
            if k <= max_split:
                for i in T.Parallel(dim_v):
                    po_local[i] = Output_partial[bz, by, k, i]
                lse_local_split = glse[bz, by, k]
                scale_local = T.exp2(lse_local_split - lse_logsum_local)
                for i in T.Parallel(dim_v):
                    o_accum_local[i] += po_local[i] * scale_local
        for i in T.Parallel(dim_v):
            Output[bz, by, i] = o_accum_local[i]

    return main
```
Guard the no-valid-split path before reading `max_split`.

If a head has no valid indices, the first loop never assigns `max_split`, but the later `k <= max_split` checks still read it. `run_regression_perf()` currently builds exactly that all-`-1` input, so this path is live.
🩹 Proposed fix
scale_local = T.alloc_var(accum_dtype)
max_split = T.alloc_var(T.int32)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local = -T.infinity(accum_dtype)
+ max_split = -1
for k in T.serial(num_split):
lse_local_split = glse[bz, by, k]
if lse_local_split != 0:
max_split = k
lse_max_local = T.max(lse_max_local, glse[bz, by, k])
- for k in T.Pipelined(num_split, num_stages=1):
- if k <= max_split:
- lse_local_split = glse[bz, by, k]
- lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
- lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
- for k in T.serial(num_split):
- if k <= max_split:
- for i in T.Parallel(dim_v):
- po_local[i] = Output_partial[bz, by, k, i]
- lse_local_split = glse[bz, by, k]
- scale_local = T.exp2(lse_local_split - lse_logsum_local)
- for i in T.Parallel(dim_v):
- o_accum_local[i] += po_local[i] * scale_local
+ if max_split >= 0:
+ for k in T.Pipelined(num_split, num_stages=1):
+ if k <= max_split:
+ lse_local_split = glse[bz, by, k]
+ lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
+ lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
+ for k in T.serial(num_split):
+ if k <= max_split:
+ for i in T.Parallel(dim_v):
+ po_local[i] = Output_partial[bz, by, k, i]
+ lse_local_split = glse[bz, by, k]
+ scale_local = T.exp2(lse_local_split - lse_logsum_local)
+ for i in T.Parallel(dim_v):
+                    o_accum_local[i] += po_local[i] * scale_local

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py`
around lines 137 - 160, The code reads max_split later even when no split was
assigned; initialize and guard its use: set max_split to -1 before the first
loop (currently max_split = T.alloc_var...) and then before the pipelined and
serial loops that check "k <= max_split" add a guard "if max_split >= 0" (or
wrap both loops in a single conditional) so you skip lse_logsum_local /
o_accum_local updates when there are no valid splits; reference symbols:
max_split, glse[bz, by, k], lse_max_local, lse_logsum_local, po_local,
Output_partial and o_accum_local.
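The sentinel-plus-guard pattern from the proposed fix can be sketched in plain Python, with NumPy arrays standing in for the kernel buffers; the helper name `combine_splits` and its shapes are illustrative, not part of the example file:

```python
import numpy as np

def combine_splits(glse, partial):
    """Fold per-split partial outputs, skipping splits whose LSE is 0.

    Mirrors the guarded pattern: max_split starts at an explicit -1
    sentinel, and the whole combine is skipped when no split was valid.
    """
    max_split = -1                       # never read uninitialized
    lse_max = -np.inf
    for k, lse in enumerate(glse):
        if lse != 0:
            max_split = k
            lse_max = max(lse_max, lse)
    out = np.zeros(partial.shape[1])
    if max_split < 0:                    # no valid split: zeros, not garbage
        return out
    logsum = 0.0
    for k in range(max_split + 1):
        logsum += 2.0 ** (glse[k] - lse_max)
    logsum = np.log2(logsum) + lse_max
    for k in range(max_split + 1):
        out += partial[k] * 2.0 ** (glse[k] - logsum)
    return out

# A head whose every split was empty (all-zero LSE) now combines to zeros:
empty = combine_splits(np.zeros(4), np.ones((4, 8)))
```

With valid splits present, the exp2/log2 weights still sum to one, so a row of identical partials combines to itself.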
if has_valid_block:
    for i, j in T.Parallel(block_H, dim_v):
        if i < valid_block_H:
            Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]

# combine
with T.Kernel(heads, batch, threads=128) as (by, bz):
    po_local = T.alloc_fragment([dim_v], accum_dtype)
    o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
    lse_local_split = T.alloc_var(accum_dtype)
    lse_logsum_local = T.alloc_var(accum_dtype)
    lse_max_local = T.alloc_var(accum_dtype)
    scale_local = T.alloc_var(accum_dtype)
    max_split = T.alloc_var(T.int32)

    T.clear(lse_logsum_local)
    T.clear(o_accum_local)
    lse_max_local = -T.infinity(accum_dtype)
    for k in T.serial(num_split):

        acc_o[i, j] /= logsum[i]
    for i in T.Parallel(block_H):
        logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale

    # TODO(lei): Support T.Parallel(valid_block_H)
    for i in T.Parallel(block_H):
        if i < valid_block_H:
            glse[bid, hid * valid_block_H + i, sid] = logsum[i]
    for i, j in T.Parallel(block_H, dim_v):
        if i < valid_block_H:
            Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]
max_split is not a safe validity proxy for an arbitrary block mask.
This kernel writes zeroed glse/partials for empty splits, then the combine pass folds every split k <= max_split as if it were valid. With a boolean block_mask, empty splits can appear before later non-empty ones, so those holes corrupt the normalization; the all-false case never initializes max_split at all. Use an explicit per-split validity marker, or write -inf for empty splits and gate each combine loop on that instead.
Also applies to: 124-157
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py`
around lines 110 - 122, The current use of max_split/valid_block_H as a proxy
for validity is unsafe because empty splits can appear before later non-empty
ones (and can be all-false), causing incorrect glse/partial folding; change the
kernel to record an explicit per-split validity marker (e.g., valid_split[bid,
hid, i]) when writing glse/Output_partial or write -inf into glse for empty
splits, then update the combine/reader loops (the sections writing/reading glse
and Output_partial—look for uses of valid_block_H, max_split, glse,
Output_partial, acc_o, logsum) to gate on that per-split marker (or check for
-inf) instead of relying on max_split so only truly valid splits are combined.
BlockMask: T.Tensor[block_mask_shape, T.int32]
num_threads = 32
print("NV", NV, "NS", NS, "B", B, "H", H)

@T.prim_func
def flash_bwd_dkv(
    Q: T.Tensor(q_shape, dtype),
    K: T.Tensor(k_shape, dtype),
    V: T.Tensor(v_shape, dtype),
    LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
    Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
    DO_slc: T.Tensor(do_slc_shape, dtype),
    DK: T.Tensor(dk_shape, dtype),
    DV: T.Tensor(dv_shape, dtype),
    BlockMask: T.Tensor(block_mask_shape, T.int32),
):
    with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
        K_shared = T.alloc_shared([BS, BK], dtype)
        V_shared = T.alloc_shared([BS, BV], dtype)
        Q_shared = T.alloc_shared([G, BK], dtype)
        qkT = T.alloc_fragment([BS, G], accum_dtype)
        qkT_cast = T.alloc_fragment([BS, G], dtype)
        dsT = T.alloc_fragment([BS, G], accum_dtype)
        dsT_cast = T.alloc_fragment([BS, G], dtype)
        lse_shared = T.alloc_shared([G], accum_dtype)
        delta = T.alloc_shared([G], accum_dtype)

        do = T.alloc_shared([G, BV], dtype)
        dv = T.alloc_fragment([BS, BV], accum_dtype)
        dk = T.alloc_fragment([BS, BK], accum_dtype)
        dq = T.alloc_fragment([BS, G], accum_dtype)

        dv_shared = T.alloc_shared([BS, BV], dtype)
        dk_shared = T.alloc_shared([BS, BK], dtype)

        i_b, i_h = i_bh // H, i_bh % H

        T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared)
        T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared)

        # [BS, BK]
        T.clear(dk)
        # [BS, BV]
        T.clear(dv)

        loop_st = i_s * BS
        loop_ed = seq_len
        for i in T.Pipelined(
            start=loop_st,
            stop=loop_ed,
            num_stages=0,
        ):
            b_m_slc = BlockMask[i_b, i, i_h, i_s]
            if b_m_slc != 0:
                # [G, BK]
                T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared)
                T.clear(qkT)
                # [BS, BK] @ [G, BK] -> [BS, G]
                T.gemm(
                    K_shared,
                    Q_shared,
                    qkT,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                # [G]
                T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared)

                for _i, _j in T.Parallel(BS, G):
                    qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j])

                for _i, _j in T.Parallel(BS, G):
                    qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0)

                # [G, BV]
                T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do)
                T.clear(dsT)
                # [BS, BV] @ [G, BV] -> [BS, G]
                T.gemm(
                    V_shared,
                    do,
                    dsT,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                T.copy(qkT, qkT_cast)
                # [BS, G] @ [G, BV] -> [BS, BV]
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                # [G]
                T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta)
                for i, j in T.Parallel(BS, G):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale

                # [BS, G] @ [G, BK] -> [BS, BK]
                T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow)

        T.copy(dv, dv_shared)
        T.copy(dk, dk_shared)
        T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV])
        T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK])

return flash_bwd_dkv
Debug print statement with undefined variable H.
Line 205 contains a debug print statement that references H, which is undefined in this scope. The kernel defines heads_kv via T.const, not H. This will cause a NameError. Additionally, line 226 uses H for computing i_h which has the same issue.
🐛 Proposed fix - remove debug print and fix H references
block_mask_shape = [batch, seq_len, heads_kv, NS]
BlockMask: T.Tensor[block_mask_shape, T.int32]
num_threads = 32
- print("NV", NV, "NS", NS, "B", B, "H", H)
- with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
+ with T.Kernel(NV, NS, B * heads_kv, threads=num_threads) as (i_v, i_s, i_bh):
K_shared = T.alloc_shared([BS, BK], dtype)
...
- i_b, i_h = i_bh // H, i_bh % H
+        i_b, i_h = i_bh // heads_kv, i_bh % heads_kv

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/deepseek_nsa/example_tilelang_nsa_bwd.py` around lines 203 - 226,
The debug print uses an undefined H and the kernel's split of i_bh uses H as
well; replace or remove the print and change references of H to the actual
constant name heads_kv (the T.const used to define number of heads for KV), e.g.
remove the print("NV", NV, "NS", NS, "B", B, "H", H) and update the split i_b,
i_h = i_bh // H, i_bh % H to use heads_kv (i_b, i_h = i_bh // heads_kv, i_bh %
heads_kv) so all references match the defined symbol like heads_kv in the
kernel.
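The corrected index split is ordinary integer divmod over the KV-head count; a tiny sketch (the sizes are made up) confirms that decomposing the flattened launch index by heads_kv enumerates every (batch, kv-head) pair exactly once:

```python
# Decompose a flattened (batch * kv_heads) launch index. Using the larger
# query-head count here (the undefined `H` in the kernel) would both size
# the grid wrong and pair batches with the wrong head.
B, heads_kv = 2, 4

def decompose(i_bh, n_heads):
    return i_bh // n_heads, i_bh % n_heads

pairs = [decompose(i, heads_kv) for i in range(B * heads_kv)]
```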
if l_new_topk <= 0:
    break

r_idx = round % 2
l_start_pos = topk - l_new_topk

T.sync_threads()
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
    input_idx = s * BLOCK_SIZE + tx
    if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
        inval_int16 = convert_to_uint16(input[bx, input_idx])
        T.atomic_add(s_histogram[inval_int16], 1)
T.fill(s_histogram, 0)
if tx == 0:
    s_num_input[r_idx ^ 1] = 0
T.sync_threads()

l_num_input = s_num_input[r_idx]
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
Accept the whole threshold bucket when it already fits.
Line 115 only exits on l_new_topk <= 0, but l_num_input <= l_new_topk is also terminal: every buffered candidate already belongs in the output. Running another radix pass in that state leaves no valid threshold bin, so s_threshold_bin_id[0] stays stale and the selection becomes undefined. This shows up for rows with exactly topk live elements and for any round where the current bucket size exactly matches the remaining budget.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/deepseek_v32/topk_selector.py` around lines 115 - 128, Before
performing another radix pass, handle the terminal case when the current bucket
already fits: if l_num_input <= l_new_topk (in the loop using l_new_topk,
l_num_input, r_idx, s_num_input and s_threshold_bin_id), treat this as a
terminal acceptance of the whole threshold bucket — update the shared selection
state (set s_threshold_bin_id[0] to the current bucket indicator / accepted
sentinel and ensure any per-round counters like s_num_input[r_idx ^ 1] are set
appropriately) and break out of the loop so no further radix passes run and
s_threshold_bin_id[0] cannot remain stale.
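The terminal condition can be sketched with a byte-wise radix select in plain Python (a simplified, single-threaded stand-in for the histogram kernel; all names are hypothetical): as soon as the surviving candidate pool fits in the remaining budget, the whole bucket is accepted and no further pass runs.

```python
def radix_topk(values, k):
    """Top-k via MSB-first byte histograms; values are non-negative ints.

    The early break is the terminal case from the review: once the
    surviving bucket fits in the remaining budget, accept it wholesale
    instead of running another radix pass.
    """
    selected = []              # values already certain to be in the top-k
    candidates = list(values)  # ties at the current prefix, still refined
    for shift in (24, 16, 8, 0):
        if len(candidates) <= k - len(selected):
            break              # whole bucket fits: stop refining
        hist = [0] * 256
        for v in candidates:
            hist[(v >> shift) & 0xFF] += 1
        need = k - len(selected)
        for digit in range(255, -1, -1):
            if hist[digit] < need:
                # entire digit bucket fits: accept it and keep walking down
                selected += [v for v in candidates if (v >> shift) & 0xFF == digit]
                need -= hist[digit]
            else:
                # budget exhausts inside this bucket: refine it next pass
                candidates = [v for v in candidates if (v >> shift) & 0xFF == digit]
                break
    return selected + candidates[: k - len(selected)]
```

A row with exactly k live elements hits the break on the very first pass, which is the state the review says the kernel currently mishandles.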
for i_i in T.Pipelined(NS, num_stages=num_stages):
    # Check which indices are valid
    for bi_i in T.Parallel(BS):
        mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)

    T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
    # Compute attention scores
    for h_i, bi_i in T.Parallel(padded_H, BS):
        acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype))

    for h_i, bi_i in T.Parallel(padded_H, BS):
        acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale
    # Load KV, V for this block of indices
    for bi_i, d_i in T.Parallel(BS, D):
        KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i]

    T.copy(acc_dp, dP_shared_cast)
    T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
    T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
    T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

    T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
    T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
    for bi_i, d_i in T.Parallel(BS, D_tail):
        KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i]
    T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
Masking does not protect these KV loads.
The kernel marks invalid entries in mask, but Line 186 and Line 191 still dereference Indices[...] unconditionally. With padded slots stored as S/-1, this can walk outside the valid KV range before the softmax mask suppresses the lane.
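The guarded-load shape this implies can be sketched with NumPy (the helper name, shapes, and zero fill are illustrative): compute validity once, clamp the index to a safe in-range dummy for the load, then zero the masked lanes so no sentinel is ever dereferenced.

```python
import numpy as np

def gather_kv_guarded(kv, idx, max_kv_i):
    """Gather rows of kv at idx, substituting zeros for masked-out slots.

    Mirrors the suggested fix: validity is computed first, the load uses
    a clamped index (so -1 or pruned entries never reach the memory
    access), and the lane is zeroed afterwards.
    """
    valid = (idx != -1) & (idx <= max_kv_i)
    safe_idx = np.where(valid, idx, 0)           # any in-range dummy works
    rows = kv[safe_idx]                          # every load is in range
    return np.where(valid[:, None], rows, 0.0)   # masked lanes contribute zero
```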
for i_i in T.Pipelined(NI, num_stages=num_stages):
    for bi_i in T.Parallel(BI):
        mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)

    for bi_i, d_i in T.Parallel(BI, D):
        KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
    for bi_i, d_i in T.Parallel(BI, D_tail):
        K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
Don't index KV with masked-out entries.
The -1/causal checks only affect mask; Line 120 and Line 122 still dereference every sparse slot. Any padded entry (-1 or the current S sentinel left by the generator) can read outside the valid sequence range before acc_s is masked.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/dsa_sparse_finetune/sparse_mla_fwd.py` around lines 115 - 122, The
code currently always dereferences KV using Indices even for masked-out slots;
update the KV loads inside the pipelined loop (the assignments to KV_shared and
K_tail_shared) to first fetch the index (e.g., idx = Indices[bos + s_i, g_i, i_i
* BI + bi_i]) and compute a validity predicate (the same check used for mask:
idx != -1 and idx <= max_kv_i), then perform a guarded load: if valid assign
KV_shared[bi_i,d_i] = KV[bos + idx, g_i, d_i] (or K_tail_shared[...] = KV[bos +
idx, g_i, D + d_i]) else assign a safe default (e.g., zero) so you never index
KV with out-of-range/padded indices; use the same loop scopes (for bi_i, d_i in
T.Parallel(BI, D) and for bi_i, d_i in T.Parallel(BI, D_tail)) and reuse
Indices, mask, KV_shared, and K_tail_shared names to locate the changes.
for i_i in T.Pipelined(NI, num_stages=num_stages):
    for bi_i in T.Parallel(BI):
        mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)

    for bi_i, d_i in T.Parallel(BI, D):
        KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
    for bi_i, d_i in T.Parallel(BI, D_tail):
        K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
Guard the KV loads with a safe index.
mask is computed here, but Line 116 and Line 118 still index KV for every lane. If Indices contains -1 or any pruned entry beyond max_kv_i, this kernel can read invalid memory before the masked softmax zeroes that lane out.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py` around lines 111 -
118, The KV loads into KV_shared and K_tail_shared in the nested loops use
Indices directly even when mask is false; change those loads in the loops that
write KV_shared and K_tail_shared (the loops referencing KV[bos + Indices[bos +
s_i, g_i, i_i * BI + bi_i], g_i, d_i] and KV[..., D + d_i]) to guard with the
computed mask by either computing a safe_index (e.g., replace out-of-range or -1
indices with a valid dummy index) or conditionally loading/storing only when
mask[bi_i] is true and writing zeros otherwise so no invalid KV memory is read
when Indices is -1 or > max_kv_i.
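The hazard is easy to reproduce in NumPy, where a -1 index silently wraps to the last row instead of faulting, much as the kernel's pointer arithmetic would land on unrelated memory; the array here is only a stand-in for the KV cache:

```python
import numpy as np

kv = np.arange(5.0)      # rows 0..4 stand in for the valid KV cache
padded_idx = -1          # the padding sentinel left in Indices

# The unguarded load "succeeds" but reads the wrong element:
wrongly_loaded = kv[padded_idx]   # silently wraps to kv[4]

# A guarded load replaces the sentinel with a harmless default instead:
loaded = kv[padded_idx] if padded_idx != -1 else 0.0
```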
result = flashattn(q, k, v, is_causal=is_causal)
best_latency = result.latency
best_config = result.config
ref_latency = result.ref_latency
Pass groups through the tune path.
Line 194 currently falls back to groups=1. For GQA inputs, that makes the tuned kernel index K and V with the query-head id instead of by // groups, so tune mode can walk past the KV-head dimension and benchmark a different kernel than the non-tuned path.
🩹 Minimal fix
- result = flashattn(q, k, v, is_causal=is_causal)
+ result = flashattn(q, k, v, is_causal=is_causal, groups=groups)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
result = flashattn(q, k, v, is_causal=is_causal, groups=groups)
best_latency = result.latency
best_config = result.config
ref_latency = result.ref_latency
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/flash_attention/example_gqa_fwd_bshd.py` around lines 194 - 197, The
call to flashattn(q, k, v, is_causal=is_causal) is omitting the groups argument
causing the tuning path to assume groups=1; update the tuned-call site to pass
the runtime groups variable (e.g., flashattn(q, k, v, is_causal=is_causal,
groups=groups)) so that the tuning selection uses the same groups value as the
non-tuned path (ensure any other tune-related invocations that construct or
compare configs also propagate the groups field).
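The mapping being protected is kv_head = query_head // groups; a small sketch with made-up head counts shows why the groups=1 fallback walks past the KV-head axis:

```python
# Hypothetical GQA head counts: 4 query heads share each KV head.
q_heads, kv_heads = 32, 8
groups = q_heads // kv_heads

# With groups threaded through, every query head maps into the KV-head axis:
kv_head_for = [by // groups for by in range(q_heads)]
in_range = all(h < kv_heads for h in kv_head_for)

# With the groups=1 fallback, query heads 8..31 index past the KV-head axis:
fallback = [by // 1 for by in range(q_heads)]
out_of_range = [h for h in fallback if h >= kv_heads]
```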
as title.

Summary by CodeRabbit

Refactor: examples rewritten in eager style, removing @T.prim_func wrappers and enabling in-kernel shape inference via T.const().