Skip to content

[Refactor] Refactor examples into eager style#1948

Closed
LeiWang1999 wants to merge 2 commits intotile-ai:mainfrom
LeiWang1999:refactor_0319
Closed

[Refactor] Refactor examples into eager style#1948
LeiWang1999 wants to merge 2 commits intotile-ai:mainfrom
LeiWang1999:refactor_0319

Conversation

@LeiWang1999
Copy link
Copy Markdown
Member

@LeiWang1999 LeiWang1999 commented Mar 19, 2026

as title.

Summary by CodeRabbit

  • Refactor
    • Simplified TileLang kernel APIs across numerous example kernels (attention, GEMM, quantization, and sparse operations) by transitioning from parameterized shape-based signatures to direct tensor-based signatures, eliminating nested @T.prim_func wrappers and enabling in-kernel shape inference via T.const().

… and performance

- Refactored the `matmul` function to accept tensors directly instead of dimensions, enhancing usability.
- Simplified the `flashattn_bwd_preprocess` and `flashattn_bwd_postprocess` functions by updating parameter names and types for better readability.
- Adjusted kernel context initialization and memory allocation for shared tensors in both forward and backward attention functions.
- Enhanced the handling of dynamic shapes and added support for variable-length sequences in attention sink examples.
- Cleaned up example scripts by removing unnecessary comments and improving code structure for better maintainability.
@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Mar 19, 2026

Important

Review skipped

Review was skipped as selected files did not have any reviewable changes.

💤 Files selected but had no reviewable changes (2)
  • tilelang/engine/phase.py
  • tilelang/language/eager/builder.py
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: eeca44e9-523a-48a4-9c6a-9a5e765d55ae

📥 Commits

Reviewing files that changed from the base of the PR and between 5cdcd9e and 37e760a.

📒 Files selected for processing (2)
  • tilelang/engine/phase.py
  • tilelang/language/eager/builder.py

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR systematically refactors TileLang JIT kernel examples across 60+ files, converting them from a "factory pattern" where functions return nested @T.prim_func callables to a "direct JIT pattern" where functions take tensor inputs directly and return computed tensor outputs. Shape parameters are replaced with tensor-based inputs, and T.const() is used for compile-time shape inference instead of runtime scalar parameters.

Changes

Cohort / File(s) Summary
Flash Attention Forward
examples/flash_attention/example_gqa_fwd_bshd.py, example_gqa_fwd_varlen.py, example_mha_fwd_bhsd.py, example_mha_fwd_bshd.py
Changed flashattn from shape-parameterized factory returning @T.prim_func to tensor-input JIT kernels; removed nested prim_func wrappers, use T.const() for shape inference, allocate and return Output/lse directly.
Flash Attention Backward
examples/flash_attention/example_gqa_bwd.py, example_gqa_bwd_tma_reduce.py, example_gqa_bwd_tma_reduce_varlen.py, example_mha_bwd_bhsd.py, example_mha_bwd_bshd.py
Refactored forward/preprocess/postprocess/backward kernels from shape-parameterized prim_func factories to direct tensor-input JIT kernels; removed nested @T.prim_func wrappers, updated signatures to accept O/dO/Q/K/V/Delta tensors, return computed gradients directly.
AMD Flash Attention
examples/amd/example_amd_flash_attn_bwd.py
Changed flashattn_bwd_preprocess to take (O, dO) and return Delta directly; changed flashattn_bwd_postprocess to take dQ_in and return dQ_out directly, removing shape parameters and nested prim_funcs.
Attention Sink Forward
examples/attention_sink/example_gqa_sink_fwd_varlen.py, example_mha_sink_fwd_bhsd.py
Changed flashattn_sink/flashattn from shape-parameterized factories to tensor-input kernels; removed nested @T.prim_func main(...) wrappers, allocate Output internally, return tensors directly.
Attention Sink Backward
examples/attention_sink/example_gqa_sink_bwd_bhsd.py, example_gqa_sink_bwd_varlen.py, example_mha_sink_bwd_bhsd.py
Refactored forward/preprocess/postprocess/backward/dsink kernels to tensor-input JIT; removed shape parameters and nested prim_funcs, use T.const() for shapes, return computed outputs directly.
Block Sparse Attention
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py, example_tilelang_sparse_gqa_decode_paged.py, example_tilelang_sparse_gqa_decode_varlen_indice.py, example_tilelang_sparse_gqa_decode_varlen_mask.py
Changed blocksparse_flashattn and sparse decode kernels from parameterized factories to tensor-input JIT; removed @T.prim_func wrappers, updated to take (Q, K, V, BlockMask/indices) tensors, allocate and return Output directly.
Block Sparse GEMM
examples/blocksparse_gemm/example_blocksparse_gemm.py
Refactored blocksparse_matmul from shape-parameterized factory returning prim_func to tensor-input JIT taking (A, B, BlockMask) with defaults; allocates and returns C directly, removed out_idx decorator.
Cast/Quantization
examples/cast/example_group_per_split_token_cast_to_fp8.py, example_per_token_cast_to_fp8.py
Changed cast kernels from shape-parameterized factories to tensor-input JIT; take (X, batch_sizes) and (X) respectively, allocate X_fp8/X_amax internally, return tensors directly.
Convolution
examples/convolution/example_convolution.py, example_convolution_autotune.py
Refactored convolution from shape-parameterized factory returning prim_func to tensor-input JIT taking (data, kernel_weight) with block/dtype parameters; allocates and returns out directly.
DeepSeek DeepGEMM
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
Changed tl_gemm from shape/dtype-parameterized factory to tensor-input JIT taking (A, B, scales_a, scales_b) with typed parameters; allocates and returns C directly, derives shapes via T.const().
DeepSeek MHC
examples/deepseek_mhc/example_mhc_bwd.py, example_mhc_post.py, example_mhc_pre.py
Updated kernels to use tensor inputs and T.const() shape inference; sinkhorn_bwd_implicit_cg takes (out, dout) and returns res; mhc_post/pre updated tensor type annotations with T.const() shape bindings.
DeepSeek MLA Decode
examples/deepseek_mla/example_mla_decode.py, example_mla_decode_paged.py, example_mla_decode_persistent.py, example_mla_decode_ws.py, amd/benchmark_mla_decode_amd_tilelang.py
Changed flashattn/mla_decode_tilelang from shape-parameterized factories to tensor-input JIT kernels; removed inner prim_funcs, allocate Output/glse internally, return tensors directly.
DeepSeek MLA KV FP8
examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py
Refactored flashattn to take (Q, Q_pe, KV, K_pe) tensors and return Output directly; removed nested prim_func, uses T.const() for shape inference.
DeepSeek NSA Forward
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py, example_tilelang_nsa_fwd.py, example_tilelang_nsa_fwd_varlen.py, example_tilelang_nsa_decode.py
Changed tilelang_sparse_attention and native_sparse_attention from shape-parameterized factories to tensor-input JIT; removed nested prim_funcs, allocate and return Output directly.
DeepSeek NSA Backward
examples/deepseek_nsa/example_tilelang_nsa_bwd.py
Refactored forward/backward kernels to tensor-input JIT; tilelang_kernel_fwd takes (Q, K, V, BlockIndices, O_slc, LSE_slc), backward kernels take explicit gradient tensors, removed inner prim_funcs and out_idx decorators.
DeepSeek V32 FP8
examples/deepseek_v32/fp8_lighting_indexer.py, inference/kernel.py
Updated mqa_attn_return_logits to take explicit tensor inputs; changed act_quant_kernel, fp8_gemm_kernel, fp8_index_kernel from shape-parameterized factories to tensor-input JIT kernels with direct output allocation/return.
DeepSeek V32 Sparse MLA
examples/deepseek_v32/sparse_mla_bwd.py, sparse_mla_fwd.py, sparse_mla_fwd_pipelined.py, sparse_mla_fwd_seesaw.py
Refactored sparse MLA kernels from shape-parameterized factories to tensor-input JIT; preprocess/postprocess/bwd take explicit tensor arguments, allocate outputs internally, return tensors directly.
DeepSeek V32 TopK
examples/deepseek_v32/topk_selector.py
Changed tl_topk_impl from (topk, in_dtype, out_dtype) parameterized factory to tensor-input JIT (input, starts, ends, topk) that allocates and returns index directly.
Dequantize GEMM
examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py, example_dequant_gemm_bf16_mxfp4_hopper.py, example_dequant_gemm_fine_grained.py, example_dequant_gemm_fp4_hopper.py, example_dequant_gemm_w4a8.py, example_dequant_groupedgemm_bf16_mxfp4_hopper.py
Refactored matmul kernels from shape/dtype-parameterized factories returning prim_funcs to tensor-input JIT taking (A, B) or (A, B, Scale, Bias) with typed block/dtype defaults; allocate and return C directly, removed out_idx decorators.
Dequantize GEMV
examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py
Changed dequantize_gemv from shape-parameterized factory returning prim_func to tensor-input JIT taking (A, B) with dtype defaults; allocates and returns C directly.
Dynamic Shape GEMM
examples/dynamic_shape/example_dynamic.py
Refactored matmul_dynamic_mnk from shape/block-parameterized factory to tensor-input JIT taking (A, B, C) with typed defaults; removed nested prim_func, uses dynamic shape binding.
Elementwise
examples/elementwise/example_elementwise_add.py
Changed elementwise_add from shape-parameterized factory returning prim_func to tensor-input JIT taking (A, B) with block/dtype defaults; allocates and returns C directly, removed out_idx.
AutoDD
examples/autodd/tilelang_buggy.py
Refactored buggy_matmul from shape-parameterized factory to tensor-input JIT taking (A, B, C) with typed defaults; removed nested @T.prim_func main(...) wrapper, updated call sites accordingly.
DSA Sparse MLA
examples/dsa_sparse_finetune/sparse_mla_bwd.py, sparse_mla_fwd.py, sparse_mla_topk_reducesum.py
Changed sparse MLA kernels from shape-parameterized factories to tensor-input JIT; preprocess/postprocess/bwd take explicit tensor arguments, allocate/return outputs directly, removed nested prim_funcs.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • chengyupku

🐰 Hop into the refactor, my friend,
From factories to tensors—this is the trend,
Prim funcs be gone, inline they run,
JIT kernels dancing in the sun!

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 17

Note

Due to the large number of review comments, Critical severity comments were prioritized as inline comments.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (7)
examples/flash_attention/example_gqa_fwd_varlen.py (1)

145-150: ⚠️ Potential issue | 🟡 Minor

Remove duplicate random seed setting.

tilelang.testing.set_random_seed(0) is called twice (lines 145 and 150). Remove one of these calls.

Proposed fix
     tilelang.testing.set_random_seed(0)
 
     if is_causal:
         total_flops *= 0.5
 
-    tilelang.testing.set_random_seed(0)
-
     dtype = torch.float16
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_gqa_fwd_varlen.py` around lines 145 - 150,
Remove the duplicate call to tilelang.testing.set_random_seed(0) — keep a single
invocation (either before or after the is_causal block) to preserve
deterministic behavior; update the area around the is_causal check and
total_flops adjustment so only one tilelang.testing.set_random_seed(0) call
remains (refer to tilelang.testing.set_random_seed and the is_causal /
total_flops code block to locate the redundant line).
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)

33-36: ⚠️ Potential issue | 🟡 Minor

Error message inconsistent with assertion.

The assertion checks for T.bfloat16 but the error message states "float16". This could confuse users when debugging dtype mismatches.

📝 Proposed fix
     assert out_dtype in [
         T.bfloat16,
         T.float32,
-    ], "Currently only float16 and float32 are supported"
+    ], "Currently only bfloat16 and float32 are supported"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py` around lines 33 -
36, The assertion message is inconsistent with the checked dtypes: update the
error string to match the actual allowed types (T.bfloat16 and T.float32) used
in the assert for variable out_dtype; e.g., change the message to "Currently
only bfloat16 and float32 are supported" so it references T.bfloat16 and
T.float32 consistently with the assert on out_dtype.
examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py (1)

161-174: ⚠️ Potential issue | 🔴 Critical

Undefined variable D - use K instead.

Line 161 extracts the dimension as K from k.shape, but lines 171-174 use D which is undefined in this scope. This will cause a NameError at runtime.

🐛 Proposed fix
 def parallel_nsa_fwd(
     ...
 ):
     B, C_SEQ_LEN, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
 
     batch = len(offsets) - 1
     HQ = q.shape[2]
     G = HQ // H
     BS = block_size
     WS = window_size
 
     o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
     native_sparse_attention_varlen(
-        q.view(C_SEQ_LEN, HQ, D),
-        k.view(C_SEQ_LEN, H, D),
-        v.view(C_SEQ_LEN, H, D),
+        q.view(C_SEQ_LEN, HQ, K),
+        k.view(C_SEQ_LEN, H, K),
+        v.view(C_SEQ_LEN, H, K),
         o_slc.view(C_SEQ_LEN, HQ, V),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py` around lines 161 -
174, The code uses an undefined symbol D when calling
native_sparse_attention_varlen; replace D with the correct dimension variable K
(extracted earlier as part of B, C_SEQ_LEN, H, K, V, S = *k.shape ...) so the
q/k/v/o_slc reshapes use K instead of D; update the calls/reshape expressions
around native_sparse_attention_varlen (references: q.view(..., D) k.view(..., D)
v.view(..., D) o_slc.view(..., V)) to use K for the embedding dimension.
examples/deepseek_mhc/example_mhc_post.py (1)

16-45: ⚠️ Potential issue | 🟠 Major

Allocation and loop bounds should derive from tensor-extracted dimensions, not unconstrained scalar parameters.

Tensor shapes are inferred via T.const() (lines 18–22), giving symbolic dimensions (hc_a, hc_c, h, h_d, h_x). However, memory allocations at lines 32–41 and loop bounds at line 45 use scalar function parameters (hc, hidden) instead. Without explicit guards, if the caller passes scalar values differing from actual tensor shapes, allocations will mismatch tensor extents and line 46 indexing into tensor b[i_n, 0, i0_h * h_blk] (where b has shape [n, hc_c, h]) can exceed bounds.

Align all allocation and loop-bound references to use the extracted symbolic dimensions (prefer h over scalar hidden at line 45; anchor hc-based allocations to a specific tensor dimension).

Suggested direction
-        for i0_h in T.Pipelined(T.ceildiv(hidden, h_blk), num_stages=2):
+        for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2):
 def mhc_post(
     x: torch.Tensor,
     residual: torch.Tensor,
     post_layer_mix: torch.Tensor,
     comb_res_mix: torch.Tensor,
 ) -> torch.Tensor:
     out = torch.empty_like(residual)
+    hc, hidden = residual.shape[1], residual.shape[2]
-    mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, residual.shape[-2], residual.shape[-1])
+    mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, hc, hidden)
     return out
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_mhc/example_mhc_post.py` around lines 16 - 45, The kernel
currently uses scalar parameters hc and hidden for allocations and loop bounds,
which can mismatch the tensor-inferred symbolic dims (hc_a, hc_b, hc_c, h, h_d,
h_x) and cause out-of-bounds accesses; update mhc_post_tilelang to derive sizes
from the T.const() symbols: replace uses of hc with the appropriate per-tensor
symbols (e.g., a_local should be (hc_a, hc_b), c_local size hc_c,
x_shared/x_local should use hc_a and the h_x extent, b_shared/b_local should use
hc_c and h, d_shared/d_local should use h_d), compute h_blk = math.gcd(h, h_blk)
and change the loop bound T.ceildiv(hidden, h_blk) to T.ceildiv(h, h_blk) so all
allocations and tiling iterate over the actual tensor dimensions (use
hc_a/hc_b/hc_c/h/h_d/h_x where relevant).
examples/cast/example_group_per_split_token_cast_to_fp8.py (1)

204-211: ⚠️ Potential issue | 🟠 Major

Bug: dtype comparisons use strings but dtype is a TileLang type.

The global dtype = T.bfloat16 (line 8) is a TileLang type object, not a string. The comparisons dtype == "float", dtype == "float16", dtype == "bfloat16" will always be False, causing all valid dtypes to raise ValueError.

Compare with main() (lines 160-167) which correctly uses dtype == T.float, dtype == T.float16, dtype == T.bfloat16.

🐛 Proposed fix
 def run_regression_perf(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
     if batch_sizes is None:
         batch_sizes = [2048, 6144]
-    if dtype == "float":
+    if dtype == T.float:
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
-    elif dtype == "float16":
+    elif dtype == T.float16:
         x = torch.randn(M, N, device="cuda", dtype=torch.float16)
-    elif dtype == "bfloat16":
+    elif dtype == T.bfloat16:
         x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/cast/example_group_per_split_token_cast_to_fp8.py` around lines 204
- 211, The dtype comparisons are using string literals but the module-level
dtype is a TileLang type (e.g., T.bfloat16); change the conditional in the block
that creates x (the branches that set x = torch.randn(...)) to compare against
TileLang constants (dtype == T.float, dtype == T.float16, dtype == T.bfloat16)
and map those to the correct torch dtypes (torch.float32, torch.float16,
torch.bfloat16) so the right branch runs instead of raising the ValueError;
update the conditionals and keep the same torch.randn calls for each matching
TileLang dtype.
examples/dsa_sparse_finetune/sparse_mla_bwd.py (1)

246-288: ⚠️ Potential issue | 🟠 Major

return_kernel is now a no-op.

The wrapper still exposes return_kernel, but this branch disappeared in the refactor and the function always launches preprocess/bwd/postprocess eagerly. That is a silent API regression for callers that previously used return_kernel=True to stage a reusable kernel; either restore the branch or remove the parameter.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/dsa_sparse_finetune/sparse_mla_bwd.py` around lines 246 - 288, The
wrapper sparse_mla_bwd exposes return_kernel but no longer respects it; restore
the original behavior by adding a branch that, when return_kernel=True, returns
the staged kernel objects (the results/closures produced by
preprocess/bwd/postprocess) instead of eagerly running them, or remove the
return_kernel parameter from sparse_mla_bwd and all callers to avoid the silent
API regression; locate the symbols preprocess, bwd, postprocess and the
return_kernel parameter in sparse_mla_bwd to implement the fix (either
reintroduce the conditional that returns kernels when return_kernel is True, or
delete the parameter and update callers).
examples/deepseek_v32/sparse_mla_bwd.py (1)

231-255: ⚠️ Potential issue | 🟠 Major

return_kernel is no longer honored here.

The wrapper still accepts return_kernel, but it now always executes the eager preprocess/bwd/postprocess path and never returns a reusable kernel handle. Keeping the flag while ignoring it is a silent API break; either restore the old behavior or drop the parameter.

🟠 Major comments (13)
examples/deepseek_v32/fp8_lighting_indexer.py-186-190 (1)

186-190: ⚠️ Potential issue | 🟠 Major

Missing bounds check causes potential out-of-bounds write.

When seq_len_kv is not a multiple of block_K (4096), the final iteration writes to indices beyond the tensor bounds.

For example, if seq_len_kv=5000:

  • T.ceildiv(5000, 4096) = 2 loop iterations
  • Iteration 1 computes idx values from 4096 to 8191
  • Writes to Logits[bx, idx] for idx >= 5000 are out-of-bounds

Add a bounds check before the write.

Proposed fix
         for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)):
             for k_i in T.serial(block_K // threads):
                 idx = n_i * block_K + k_i * threads + tx
-                if idx < cu_k_s or idx >= cu_k_e:
-                    Logits[bx, idx] = -T.infinity(T.float32)
+                if idx < seq_len_kv:
+                    if idx < cu_k_s or idx >= cu_k_e:
+                        Logits[bx, idx] = -T.infinity(T.float32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/fp8_lighting_indexer.py` around lines 186 - 190, The
loop that writes Logits[bx, idx] can write past seq_len_kv when seq_len_kv is
not a multiple of block_K; update the inner loop (the block using T.Pipelined /
T.ceildiv with variables block_K, threads, tx computing idx) to include an
additional bounds check such as "if idx < seq_len_kv and idx >= 0 and idx >=
cu_k_s and idx < cu_k_e" (or simply add "and idx < seq_len_kv" to the existing
cu_k_s/cu_k_e condition) before performing Logits[bx, idx] =
-T.infinity(T.float32) so no out-of-bounds writes occur.
examples/attention_sink/example_gqa_sink_fwd_varlen.py-101-123 (1)

101-123: ⚠️ Potential issue | 🟠 Major

Clamp empty KV windows before entering T.Pipelined.

Line 119 can go negative on batches shorter than max_seqlen_q when window_size is set, because the extra bx tiles may compute start > end. That makes the pipelined K loop invalid for ragged inputs.

🩹 Minimal fix
-        loop_range = end - start
+        loop_range = T.max(0, end - start)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/attention_sink/example_gqa_sink_fwd_varlen.py` around lines 101 -
123, The pipelined K-loop can receive a negative range when start > end (ragged
inputs with window_size), so clamp the window before calling T.Pipelined by
computing loop_range = max(0, end - start) and only enter the T.Pipelined loop
when loop_range > 0; keep the existing start/end usage for computing actual_k
and the K_unpad slice (symbols: start, end, loop_range, T.Pipelined, actual_k,
K_unpad).
examples/deepseek_v32/sparse_mla_fwd_seesaw.py-91-98 (1)

91-98: ⚠️ Potential issue | 🟠 Major

Initialize lse for the queries the grid skips.

When CP0 is true, the x-grid intentionally skips the first kv_stride - 1 queries, but the post-fixup only repairs out. The returned lse slice for that prefix is never written, so callers get garbage values whenever kv_stride > 1.

Also applies to: 505-508

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/sparse_mla_fwd_seesaw.py` around lines 91 - 98, The
kernel skips the first (kv_stride - 1) queries when CP0 is true but never
initializes the corresponding lse entries, leaving garbage returned to callers;
inside the T.Kernel that creates (bx,by,bz) (the CP0 path), set the prefix
length p = max(0, kv_stride - 1) (or compute p = max(0, seq_len - kv_stride + 1)
depending on the existing grid math) and explicitly initialize lse[:p] to the
correct neutral/underflow value used elsewhere (e.g., -inf or the same filled
value used when masking attention) before any writes to out/other buffers, and
apply the same initialization fix to the other analogous kernel instance flagged
in the comment (the later T.Kernel usage around the second occurrence).
examples/deepseek_mla/example_mla_decode.py-245-255 (1)

245-255: ⚠️ Potential issue | 🟠 Major

Both correctness checks still use the stale six-argument reference.

Lines 255 and 290 validate the eager (Q, Q_pe, KV, K_pe) kernel against ref_program(q, q_pe, kv, k_pe, glse, Output_partial). Either slim ref_program down to four inputs or pass a 4-arg wrapper at both call sites, otherwise main() and run_regression_perf() stay out of sync with the refactor.

Also applies to: 280-290

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_mla/example_mla_decode.py` around lines 245 - 255, The
correctness checks call profiler.assert_allclose(ref_program, ...) but
ref_program still expects six args (q, q_pe, kv, k_pe, glse, Output_partial);
make the calls and reference consistent by either slimming ref_program to accept
only the four inputs used by the eager kernel (Q, Q_pe, KV, K_pe) or add a thin
4-arg wrapper (e.g., ref_program4) that forwards those four arguments to the
existing ref_program with suitable defaults for glse and Output_partial, and
then use that 4-arg function at both profiler.assert_allclose call sites (the
checks invoked from main() and run_regression_perf()) so the kernel and
reference signatures match.
examples/deepseek_mla/example_mla_decode_ws.py-569-579 (1)

569-579: ⚠️ Potential issue | 🟠 Major

Keep the reference callable aligned with the compiled entrypoint.

Line 579 still hands assert_allclose the pre-refactor six-argument ref_program, even though the compiled kernel now exposes only (Q, Q_pe, KV, K_pe). The validation step stays stale until those unused parameters are removed or shimmed away here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_mla/example_mla_decode_ws.py` around lines 569 - 579, The
reference callable passed to profiler.assert_allclose is still the old
six-argument ref_program but the compiled entrypoint now exposes only (Q, Q_pe,
KV, K_pe); update the validation to use a reference callable that matches the
compiled signature by either changing ref_program to accept only (Q, Q_pe, KV,
K_pe) or provide a small shim/wrapper (e.g., lambda or def ref_shim(Q, Q_pe, KV,
K_pe): return ref_program(Q, Q_pe, KV, K_pe, <shim_K>, <shim_V>)) and pass that
shim to profiler.assert_allclose so the callable signatures align with the
compiled entrypoint.
examples/deepseek_mla/example_mla_decode_persistent.py-208-219 (1)

208-219: ⚠️ Potential issue | 🟠 Major

Update ref_program to the new eager signature.

Line 219 now validates a kernel whose public tensor inputs are just (Q, Q_pe, KV, K_pe), but ref_program still declares the removed glse and Output_partial parameters. Either drop those params from ref_program or wrap it in a 4-arg shim before calling assert_allclose.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_mla/example_mla_decode_persistent.py` around lines 208 -
219, ref_program's signature still expects the removed glse and Output_partial
parameters but profiler.assert_allclose (via compiled.get_profiler()) now
supplies only (Q, Q_pe, KV, K_pe); update ref_program to match by either
removing the obsolete glse and Output_partial parameters from ref_program's
declaration or create a 4-argument shim that accepts (Q, Q_pe, KV, K_pe) and
calls the old ref_program supplying appropriate default/dummy values for glse
and Output_partial before passing to profiler.assert_allclose; target the
ref_program symbol (or place the shim right before the call to
profiler.assert_allclose) so the profiler and ref_program signatures match.
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py-276-278 (1)

276-278: ⚠️ Potential issue | 🟠 Major

Fix the reference call to match _get_inputs().

profiler._get_inputs() now reflects the eager compiled entrypoint, but Line 278 still splats those tensors into a six-argument ref_program. Call a 4-input reference here, or update ref_program, before this correctness path can work reliably again.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py` around lines
276 - 278, The correctness check is calling ref_program with six args while
profiler._get_inputs() now returns the eager compiled entrypoint inputs; update
the call to match the 4-input reference signature by passing only the first four
tensors from profiler._get_inputs() (use profiler._get_inputs() -> input_tensors
and call ref_program with those four), or alternatively update ref_program to
accept the full set if that is intended; adjust either the ref_program signature
or the caller at the site where compiled(*input_tensors) and ref_program(...)
are invoked so both use the same 4-input contract (referencing
profiler._get_inputs(), compiled, and ref_program).
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py-269-275 (1)

269-275: ⚠️ Potential issue | 🟠 Major

--autotune no longer changes what gets benchmarked.

Lines 274-275 unconditionally compile the fixed (BLOCK_N, BLOCK_H, num_split, threads) config after the autotune branch, so the later profiler path ignores the autotuned result. With --autotune, the script now does extra work and still reports the hard-coded kernel.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py` around lines
269 - 275, The profiling currently always compiles the hard-coded (BLOCK_N,
BLOCK_H, num_split, threads) config, ignoring the autotuned run; change the
compile call to match the invocation: if enable_autotune is true, call
flashmla_decode.compile(Q, Q_pe, KV, K_pe) (no fixed params) so the compiled
kernel matches the autotuned result and avoid the extra fixed-params compile,
otherwise compile with the explicit block_N, block_H, num_split, threads
arguments as before; ensure the profiler is taken from that
conditionally-selected compiled object (use compiled.get_profiler()).
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py-388-397 (1)

388-397: ⚠️ Potential issue | 🟠 Major

This perf harness is timing setup and an empty workload.

The timed closure recreates SparseFlashAttn, regenerates Q/K/V, and leaves block_indices_tensor entirely -1. Even with the combine guard fixed, that still measures setup cost and a no-op path instead of sparse decode throughput.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py`
around lines 388 - 397, The timed harness in run_kernel_only recreates
SparseFlashAttn and regenerates Q/K/V each run and passes an all -1
block_indices_tensor, so you end up timing setup and a no-op path; to fix, move
creation of SparseFlashAttn, Q, K, V, cache_seqlens and block_indices_tensor
outside the timed closure so they are allocated once, populate
block_indices_tensor with realistic non -1 block indices (or a mix
representative of sparse_ratio) for each batch/head_kv instead of leaving it all
-1, and then call sparse_kernel(...) repeatedly inside the timed loop to measure
real sparse decode throughput (refer to run_kernel_only, SparseFlashAttn,
block_indices_tensor, cache_seqlens, Q/K/V).
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py-410-419 (1)

410-419: ⚠️ Potential issue | 🟠 Major

Move setup out of the timed closure.

run_kernel_only() recreates the module, allocates/fills every CUDA tensor, and benchmarks an all-false mask on each sample. That makes the reported number mostly reflect setup work and the empty-path behavior rather than sparse decode performance.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py`
around lines 410 - 419, The timed closure run_kernel_only is recreating the
SparseFlashAttn model and allocating/filling CUDA tensors (Q, K, V,
cache_seqlens, block_mask_tensor, num_blocks) on each invocation, skewing the
benchmark; move the setup out of run_kernel_only so the model, dtype, Q, K, V,
cache_seqlens, num_blocks and block_mask_tensor are created once before timing
and run_kernel_only only calls model(Q, K, V, block_mask_tensor, cache_seqlens)
(and if you need per-iteration variability, update only the minimal tensors
inside the timed loop).
examples/convolution/example_convolution_autotune.py-77-95 (1)

77-95: ⚠️ Potential issue | 🟠 Major

dtype/accum_dtype are public knobs, but the kernel hard-codes fp16/fp32.

Lines 94-95 bind the inputs with the caller-provided types, then Lines 100-101 overwrite those variables before the shared/output tensors are created. A non-default call will therefore mix caller-typed inputs with fp16/fp32 internals. Either honor the parameters end-to-end or remove them from the signature.

Also applies to: 100-101

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/convolution/example_convolution_autotune.py` around lines 77 - 95,
The convolution function currently binds the input tensors to the
caller-provided dtype/accum_dtype but then overwrites those types with
hard-coded fp16/fp32 (reassignments of dtype and accum_dtype), causing
mixed-type internals; update the function so it either removes dtype/accum_dtype
from the signature or — preferably — stop overriding the passed-in dtype and
accum_dtype and use the provided dtype/accum_dtype when declaring shared/output
tensors and any internal casts (look for the places where dtype and accum_dtype
are reassigned to T.float16/T.float32 and remove those reassignments so that
data/kernel_weight and tensors created later respect the original dtype and
accum_dtype parameters).
examples/blocksparse_gemm/example_blocksparse_gemm.py-125-132 (1)

125-132: ⚠️ Potential issue | 🟠 Major

Generate BlockMask with the same tiling the kernel is using.

Lines 125-127, 144-145, and 180-181 size the mask with DEFAULT_BLOCK_* and floor division, while the kernel indexes BlockMask[by, bx, k] using the active block_M/N/K and a ceildiv grid. Any autotune candidate that changes tile sizes—or any input with remainder tiles—will misalign the mask and can read the wrong sparsity decisions. Please either rebuild the mask per config, or keep tile sizes fixed and size the tensor with ceiling division/assert divisibility.

Also applies to: 143-158, 179-193

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/blocksparse_gemm/example_blocksparse_gemm.py` around lines 125 -
132, The mask is built using DEFAULT_BLOCK_M/N/K and floor division (mask_shape
= (M // block_M, N // block_N, K // block_K)) but the kernel and autotuner use
the active tile sizes and ceildiv when indexing BlockMask[by, bx, k], causing
misalignment for different tile candidates or remainder tiles; fix by
generating/rebuilding block_mask using the kernel's actual tile sizes (use
ceildiv(M, tile_M), ceildiv(N, tile_N), ceildiv(K, tile_K)) taken from the
compiled kernel or best_config (e.g., after blocksparse_matmul.compile/getting
best_config) or alternatively enforce fixed tiles and size the mask with ceiling
division or assert exact divisibility so BlockMask indexing aligns with the
kernel's block_M/block_N/block_K.
examples/convolution/example_convolution_autotune.py-92-99 (1)

92-99: ⚠️ Potential issue | 🟠 Major

Guard rectangular filters or compute OW from KW.

Lines 97-99 collapse the filter shape to K = KH and use that single extent for both output dimensions. After the tensor-first API change, kernel_weight can be rectangular, so KH != KW will produce the wrong width/output indexing.

💡 One safe option
-    K = KH
-    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
-    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
+    assert KH == KW, "Only square kernels are supported"
+    OH = (H + 2 * P - D * (KH - 1) - 1) // S + 1
+    OW = (W + 2 * P - D * (KW - 1) - 1) // S + 1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/convolution/example_convolution_autotune.py` around lines 92 - 99,
The code incorrectly collapses the kernel shape by setting K = KH and using it
for both spatial dims, which breaks rectangular kernels; change the logic to
keep distinct kernel heights and widths (use KH and KW separately), compute OH
using KH (or K_h) and OW using KW (or K_w), and update any downstream uses of K
to use the appropriate per-dimension symbol (e.g., replace references to K when
computing OW/indexing with KW or K_w); ensure kernel_weight is treated as [KH,
KW, C, F] when computing output sizes and indexing.
🟡 Minor comments (9)
examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py-517-518 (1)

517-518: ⚠️ Potential issue | 🟡 Minor

Fix acc_s_cast allocation to use fragment instead of shared memory.

Line 518 allocates acc_s_cast with T.alloc_shared, but all other NSA examples (example_tilelang_nsa_fwd.py, example_tilelang_nsa_fwd_varlen.py, example_tilelang_nsa_decode.py, example_tilelang_nsa_bwd.py) use T.alloc_fragment. For consistency and performance (fragments stay in registers), change to T.alloc_fragment([G, BS], dtype).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py` around lines 517 - 518,
The allocation for acc_s_cast is using T.alloc_shared but should use a fragment
like the other NSA examples; update the acc_s_cast allocation to
T.alloc_fragment([G, BS], dtype) (matching acc_s's shape [G, BS] and using
dtype) so acc_s_cast is allocated as a fragment (register-backed) instead of
shared memory for consistency and performance.
examples/attention_sink/example_mha_sink_bwd_bhsd.py-135-139 (1)

135-139: ⚠️ Potential issue | 🟡 Minor

Rename the new O identifiers before lint fails.

Lines 135-139 and Line 490 use O, which Ruff flags as ambiguous (E741). Renaming those new identifiers to out or output should clear the lint regression.

Also applies to: 490-490

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/attention_sink/example_mha_sink_bwd_bhsd.py` around lines 135 - 139,
The identifier O in flashattn_bwd_preprocess should be renamed to a
non-ambiguous name (e.g., out or output) to avoid Ruff E741; update the function
signature and its typed annotation T.Tensor[[batch, heads, seq_len, dim],
T.float16] to use the new name and change all internal references to O
accordingly, and apply the same renaming for the other occurrence at the site
referenced around line 490 so both locations consistently use the new identifier
(also verify any related uses like dO remain correct).
examples/flash_attention/example_mha_bwd_bhsd.py-88-91 (1)

88-91: ⚠️ Potential issue | 🟡 Minor

Rename O before Ruff fails this helper.

Lines 88-91 introduce O, which Ruff flags as ambiguous (E741). Renaming it to out keeps the new eager-style preprocess helper lint-clean.

🩹 Minimal fix
-def flashattn_bwd_preprocess(O, dO):
+def flashattn_bwd_preprocess(out, dO):
     batch, heads, seq_len, dim = T.const("batch heads seq_len dim")
-    O: T.Tensor[[batch, heads, seq_len, dim], T.float16]
+    out: T.Tensor[[batch, heads, seq_len, dim], T.float16]
     dO: T.Tensor[[batch, heads, seq_len, dim], T.float16]
     Delta = T.empty([batch, heads, seq_len], T.float32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_mha_bwd_bhsd.py` around lines 88 - 91, The
helper function flashattn_bwd_preprocess declares a parameter named O which
triggers Ruff E741; rename the parameter O to out (and update the corresponding
type annotation and any internal uses) so the signature becomes
flashattn_bwd_preprocess(out, dO) and the annotation uses out: T.Tensor[[batch,
heads, seq_len, dim], T.float16]; keep dO unchanged and ensure all references
inside the function and any callers are updated to use out instead of O.
examples/amd/example_amd_flash_attn_bwd.py-238-245 (1)

238-245: ⚠️ Potential issue | 🟡 Minor

Rename O before Ruff fails this helper.

Line 238 introduces O, which Ruff flags as ambiguous (E741). Renaming the parameter and its annotation to out keeps the new eager helper lint-clean.

🩹 Minimal fix
-@tilelang.jit
-def flashattn_bwd_preprocess(O, dO, batch: int = 1, heads: int = 8, seq_len: int = 1024, dim: int = 64):
+@tilelang.jit
+def flashattn_bwd_preprocess(out, dO, batch: int = 1, heads: int = 8, seq_len: int = 1024, dim: int = 64):
     batch, heads, seq_len, dim = T.const("batch heads seq_len dim")
     dtype = T.float16
     accum_dtype = T.float32
     blk = 32

-    O: T.Tensor[[batch, seq_len, heads, dim], dtype]
+    out: T.Tensor[[batch, seq_len, heads, dim], dtype]
     dO: T.Tensor[[batch, seq_len, heads, dim], dtype]
     Delta = T.empty([batch, heads, seq_len], accum_dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/amd/example_amd_flash_attn_bwd.py` around lines 238 - 245, The
helper function flashattn_bwd_preprocess uses the single-letter parameter name O
which triggers Ruff E741; rename the parameter and its annotated binding from O
to a clearer name (e.g., out) and update all references and the type annotation
(change "O" to "out" in the parameter list and the "O: T.Tensor[[...], dtype]"
line) so the function signature and internal annotations remain consistent and
lint-clean while preserving dO and other symbols.
examples/flash_attention/example_gqa_bwd_tma_reduce.py-350-353 (1)

350-353: ⚠️ Potential issue | 🟡 Minor

Inconsistent copy pattern for dk vs dv.

Lines 350-351 copy dv to dv_shared then dv_shared to dV, but lines 352-353 copy dk to dk_shared then copy dk (not dk_shared) directly to dK. This appears inconsistent and may be a copy-paste error.

🔧 Suggested fix for consistency
         T.copy(dv, dv_shared)
         T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
         T.copy(dk, dk_shared)
-        T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
+        T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_gqa_bwd_tma_reduce.py` around lines 350 -
353, The copy sequence is inconsistent: after copying dk into dk_shared you then
copy dk (not dk_shared) into dK; change the second copy to use dk_shared so it
mirrors the dv -> dv_shared -> dV pattern. Update the T.copy that currently uses
dk as source to instead use dk_shared for writing into dK (the statements
involving dk, dk_shared, and dK with indices bx, groups, bz, by, block_M).
examples/flash_attention/example_gqa_bwd.py-340-343 (1)

340-343: ⚠️ Potential issue | 🟡 Minor

Same inconsistent copy pattern for dk as in example_gqa_bwd_tma_reduce.py.

Lines 340-341 copy dv to dv_shared then dv_shared to dV, but lines 342-343 copy dk to dk_shared then copy dk (not dk_shared) directly to dK. This is the same inconsistency noted in the TMA reduce variant.

🔧 Suggested fix for consistency
         T.copy(dv, dv_shared)
         T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
         T.copy(dk, dk_shared)
-        T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
+        T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_gqa_bwd.py` around lines 340 - 343, The copy
for dk is inconsistent: after copying dk into dk_shared you must copy dk_shared
into dK (just like dv -> dv_shared -> dV). Update the second dk copy so that
T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx //
groups, :]) is used instead of copying dk directly; this keeps the pattern for
dk consistent with dv and matches the behavior in functions using dv/dv_shared
and dV.
examples/attention_sink/example_gqa_sink_bwd_bhsd.py-69-76 (1)

69-76: ⚠️ Potential issue | 🟡 Minor

Fragment allocation size mismatch for sinks.

The sinks fragment is allocated with size [heads] but the parallel loop at lines 75-76 iterates over block_M elements. Since you're copying a single value Sinks[by] to all block_M positions, the fragment should be allocated as [block_M] to match the loop bounds and consistent with other position-level fragments.

🔧 Suggested fix
-        sinks = T.alloc_fragment([heads], dtype)
+        sinks = T.alloc_fragment([block_M], dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/attention_sink/example_gqa_sink_bwd_bhsd.py` around lines 69 - 76,
The sinks fragment is allocated with the wrong size—T.alloc_fragment([heads],
dtype) but the loop for i in T.Parallel(block_M) writes sinks[i]; change the
allocation to match the loop by allocating sinks with size [block_M] (i.e., use
T.alloc_fragment([block_M], dtype)) so that sinks, the T.alloc_fragment call,
and the parallel loop filling sinks from Sinks[by] are consistent.
examples/deepseek_mla/example_mla_decode_paged.py-26-26 (1)

26-26: ⚠️ Potential issue | 🟡 Minor

Make the softmax_scale annotation match its None default.

Line 26 advertises softmax_scale as a plain float, but the function immediately treats None as a valid value. Update the annotation to float | None or Optional[float], or drop it if tilelang.jit requires the scalar annotation to remain concrete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_mla/example_mla_decode_paged.py` at line 26, The
softmax_scale parameter is annotated as float but defaults to None; update the
function signature for the parameter softmax_scale in example_mla_decode_paged
(the function taking softmax_scale: float = None) to use a nullable type (e.g.,
softmax_scale: float | None or softmax_scale: Optional[float]) and add the
necessary typing import if you choose Optional; alternatively, if tilelang.jit
requires a concrete scalar annotation, remove the default None or adjust the
call sites to pass a concrete float so the annotation and default are
consistent.
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py-219-223 (1)

219-223: ⚠️ Potential issue | 🟡 Minor

Import autotune explicitly instead of using a star import.

Line 219 uses the @autotune decorator, but the file relies on from tilelang.autotuner import * at line 3. This causes a Ruff F405 undefined name warning. Since only autotune is used from this import (the get_configs() function is defined locally), explicit import is preferred.

🧹 Minimal cleanup
-from tilelang.autotuner import *
+from tilelang.autotuner import autotune
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py` around lines 219
- 223, Replace the star import from tilelang.autotuner with an explicit import
for autotune (or add an explicit "from tilelang.autotuner import autotune") so
the `@autotune` decorator used on the kernel function is defined; update the
import near the top of the file where tilelang.autotuner is currently imported
with a wildcard and leave get_configs() as-is since it is defined locally.
🧹 Nitpick comments (14)
examples/deepseek_mhc/example_mhc_bwd.py (1)

69-76: Unused constants n_s and n_s2 create potential confusion.

The constants n_s and n_s2 extracted via T.const() are only used in the tensor annotations (lines 70-71) but never referenced in the kernel logic. All allocations and loop bounds use n_stream instead:

  • Line 73: tensor_shape = [seqlen, n_stream, n_stream]
  • Line 100: T.alloc_fragment([tilesize, n_stream, n_stream], ...)
  • Line 80: T.Parallel(tilesize, n_stream, n_stream)

If the input tensors' actual shape dimensions differ from n_stream, there could be a silent shape mismatch between what's annotated versus what's allocated. Consider either:

  1. Using n_s/n_s2 consistently throughout the kernel, or
  2. Removing them if n_stream is the authoritative compile-time shape
♻️ Option 1: Use extracted constants consistently
-    seqlen, n_s, n_s2 = T.const("seqlen n_s n_s2")
+    seqlen, n_s, _ = T.const("seqlen n_s n_s2")
     out: T.Tensor[[seqlen, n_s, n_s2], T.float32]
     dout: T.Tensor[[seqlen, n_s, n_s2], T.float32]

-    tensor_shape = [seqlen, n_stream, n_stream]
+    tensor_shape = [seqlen, n_s, n_s]
     dtype = T.float32

     res = T.empty(tensor_shape, dtype)

Then update all n_stream references in macros and allocations to n_s.

♻️ Option 2: Simplify by removing unused constants
-    seqlen, n_s, n_s2 = T.const("seqlen n_s n_s2")
-    out: T.Tensor[[seqlen, n_s, n_s2], T.float32]
-    dout: T.Tensor[[seqlen, n_s, n_s2], T.float32]
+    seqlen = T.const("seqlen")
+    out: T.Tensor[[seqlen, n_stream, n_stream], T.float32]
+    dout: T.Tensor[[seqlen, n_stream, n_stream], T.float32]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_mhc/example_mhc_bwd.py` around lines 69 - 76, The code
defines compile-time constants n_s and n_s2 via T.const but never uses them in
kernel logic (allocations and loops use n_stream), causing silent shape mismatch
risk; either (A) replace n_stream with n_s/n_s2 across the kernel (update
tensor_shape, T.alloc_fragment calls, and T.Parallel bounds) so annotations and
allocations match the T.const annotations, or (B) remove the unused T.const("n_s
n_s2") and the n_s/n_s2 symbols (and their annotations in out/dout) and keep
n_stream as the single authoritative compile-time size; update the tensor_shape,
res allocation, and any T.alloc_fragment / T.Parallel usages accordingly and
ensure the out/dout annotations match the chosen shape variables.
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)

140-140: Unused parameter in_dtype.

The in_dtype parameter is passed to assert_tl_gemm_correctness but is never used within the function. The input dtype is determined by A_fp8.dtype implicitly. Consider removing this parameter to avoid confusion.

🧹 Proposed fix
-def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype):
+def assert_tl_gemm_correctness(M, N, K, block_N, out_dtype, accum_dtype):

And update the call sites in main() and the __main__ block accordingly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py` at line 140, The
function assert_tl_gemm_correctness currently accepts an unused parameter
in_dtype; remove in_dtype from its signature and from any internal references,
then update all call sites (notably in main() and the __main__ block) to stop
passing that argument so the parameter list matches the new signature; search
for assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype,
accum_dtype) and replace calls to assert_tl_gemm_correctness(M, N, K, block_N,
out_dtype, accum_dtype) (or reorder to match the updated signature) to ensure
types still align with how A_fp8.dtype is used.
examples/deepseek_nsa/example_tilelang_nsa_bwd.py (1)

2-6: Duplicate torch import.

torch is imported twice (lines 2 and 6).

♻️ Remove duplicate import
 # ruff: noqa
 import torch
 from typing import Optional, Union
 from packaging.version import parse
 
-import torch
 import triton
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_nsa/example_tilelang_nsa_bwd.py` around lines 2 - 6, The
file contains a duplicate import of the torch module; remove the redundant
import statement so only a single "import torch" remains at the top of the
module (look for the two occurrences of the "import torch" statement among the
top-level imports alongside "from typing import Optional, Union" and "from
packaging.version import parse")—delete one of the identical imports to avoid
the duplication.
examples/deepseek_v32/inference/kernel.py (1)

147-147: Inconsistent 1D shape tuple syntax.

(block_M) evaluates to just block_M (an integer), not a single-element tuple. For consistency with other 1D allocations in this file (e.g., line 65: (blk_m,)), consider using a trailing comma.

Suggested fix
-        Scale_C_shared = T.alloc_shared((block_M), FP32)
+        Scale_C_shared = T.alloc_shared((block_M,), FP32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/inference/kernel.py` at line 147, The allocation call
for Scale_C_shared uses (block_M) which is an int, not a 1-tuple; update the
T.alloc_shared invocation for Scale_C_shared to use a single-element tuple
syntax (block_M,) to match other 1D allocations (e.g., at blk_m,) and ensure the
shape is interpreted as a 1D array by the allocator.
examples/deepseek_mhc/example_mhc_pre.py (1)

35-35: Unused symbolic dimension declaration.

This line creates two symbolic constants named literally "_" that are immediately discarded. If the intent was to skip dimensions from a tensor shape, there's no corresponding tensor type annotation that uses "_ _". This appears to be dead code that can be removed.

🧹 Suggested fix
     n_splits_d, num_tokens, hc_mult3_d = T.const("n_splits_d num_tokens hc_mult3_d")
-    _, _ = T.const("_ _")
     hc_scale_len = T.const("hc_scale_len")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_mhc/example_mhc_pre.py` at line 35, The line calling
T.const("_ _") creates unused symbolic constants and should be removed; delete
the expression "_, _ = T.const(\"_ _\")" from example_mhc_pre.py (or replace it
with the correct tensor type annotation if you actually intended to declare
anonymous dimensions), ensuring no dead code remains and that any needed
symbolic dimensions are declared where a tensor shape/type uses them (e.g., in
the relevant T.tensor/T.type annotation) instead of as a discarded const call.
examples/deepseek_v32/topk_selector.py (1)

69-72: Scan only the active [start, end) window.

Line 69 and Line 97 still iterate over ceildiv(seq_len, BLOCK_SIZE), so short rows pay O(max_seq_len) work and just branch away most lanes. Basing both stage-1 loops on l_end_idx - l_start_idx would make the new varlen path proportional to each row’s actual length.

Also applies to: 97-100

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/topk_selector.py` around lines 69 - 72, The loop over
T.serial currently uses ceildiv(seq_len, BLOCK_SIZE) causing all rows to iterate
to max length; change both stage-1 loops that use T.serial(T.ceildiv(seq_len,
BLOCK_SIZE)) so they instead iterate over T.ceildiv(l_end_idx - l_start_idx,
BLOCK_SIZE) (i.e., base the iteration count on the active window length), and
compute input_idx = s * BLOCK_SIZE + tx + l_start_idx (or otherwise offset by
l_start_idx) before the existing bounds checks; update references in the same
block that load input (convert_to_uint16(input[bx, input_idx])) so they use the
new input_idx/offset logic to make work proportional to each row’s actual
length.
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)

135-159: Same ambiguous O parameter name as other files.

Consider renaming O to Out or Output to address the E741 static analysis warning, consistent with any changes made in other files.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/attention_sink/example_gqa_sink_bwd_bhsd.py` around lines 135 - 159,
Rename the ambiguous parameter O in the function flashattn_bwd_preprocess to a
clearer name (e.g., Output or Out) and update every reference inside the
function (parameter list, T.copy sources, T.alloc_fragment uses that reference,
and any other occurrences like the copy from O to o) to use the new identifier;
keep dO and Delta unchanged and ensure the new name matches the convention used
in other files to resolve the E741 static analysis warning.
examples/flash_attention/example_gqa_bwd_tma_reduce.py (2)

359-367: Unused shape unpacking variables.

The variables BATCH, N_CTX, H, D_HEAD_QK are unpacked but never used. Consider using underscore prefix per static analysis suggestion.

💡 Optional fix
     `@staticmethod`
     def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
-        BATCH, N_CTX, H, D_HEAD_QK = q.shape
+        _BATCH, _N_CTX, _H, _D_HEAD_QK = q.shape
         block_M = 128
         block_N = 64

Or simply remove the unpacking if not needed:

     `@staticmethod`
     def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
-        BATCH, N_CTX, H, D_HEAD_QK = q.shape
         block_M = 128
         block_N = 64
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_gqa_bwd_tma_reduce.py` around lines 359 -
367, In the forward function, the shape unpacking line "BATCH, N_CTX, H,
D_HEAD_QK = q.shape" creates unused variables; either remove the unpacking
entirely or rename them to unused placeholders (e.g., _batch, _n_ctx, _h,
_d_head_qk) to satisfy static analysis. Update the forward function (around the
call to flashattn_fwd and ctx.save_for_backward) so no unused local names remain
while preserving q, k, v, causal, groups and use_atomic behavior.

437-438: Unnecessary dummy tensor allocations in non-atomic path.

When using the split path, flashattn_bwd_postprocess(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) allocates zero tensors for dK and dV just to satisfy the function signature, then immediately discards them. This wastes GPU memory.

Consider either:

  1. Overloading flashattn_bwd_postprocess to accept only dQ for this use case
  2. Or documenting this is intentional for API consistency
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_gqa_bwd_tma_reduce.py` around lines 437 -
438, The call to flashattn_bwd_postprocess in the split/non-atomic path is
creating throwaway zero tensors for dK and dV (torch.zeros_like(k,
dtype=torch.float32)) which wastes GPU memory; update flashattn_bwd_postprocess
to accept optional dK/dV (e.g., default None) or add an overload that only takes
dq, and change the call in this file to pass only dq (remove the
torch.zeros_like arguments) so the function avoids allocating zeros when dK/dV
are not needed; ensure flashattn_bwd_postprocess checks for None and
computes/returns dK/dV appropriately or returns placeholders without allocating
memory, and update any other callers to the new signature if necessary.
examples/flash_attention/example_mha_bwd_bshd.py (1)

85-109: Consider renaming O parameter to avoid ambiguity.

Static analysis flags O as an ambiguous variable name (E741) since it can be confused with 0. While O is conventional in attention literature for "Output", consider renaming to Out or Output for clarity.

💡 Optional rename suggestion
-def flashattn_bwd_preprocess(O, dO):
+def flashattn_bwd_preprocess(Out, dO):
     batch, seq_len, heads, dim = T.const("batch seq_len heads dim")
-    O: T.Tensor[[batch, seq_len, heads, dim], T.float16]
+    Out: T.Tensor[[batch, seq_len, heads, dim], T.float16]
     dO: T.Tensor[[batch, seq_len, heads, dim], T.float16]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_mha_bwd_bshd.py` around lines 85 - 109,
Rename the ambiguous parameter O in function flashattn_bwd_preprocess to a
clearer name (e.g., Out or Output) and update all internal references
accordingly (the function signature, uses in T.copy slices and any variable
names that shadow it like the local fragment o if you prefer to keep that).
Specifically update the parameter O -> Out (and adjust any code that references
O[bz, ...] to Out[bz, ...]) while leaving dO and local fragments (o, do, acc,
delta, Delta) unchanged; run tests/lint to ensure the E741 warning is resolved.
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (2)

151-192: Same O parameter naming consideration.

This varlen preprocess has the same E741 warning for ambiguous O variable name.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py` around lines
151 - 192, Rename the ambiguous parameter O in flashattn_bwd_preprocess to a
clearer name (e.g., output or O_tensor) and update its type annotation and every
usage inside the function (all occurrences of O[...] -> output[...]). Ensure the
function signature, the typed tensor declaration (previously "O:
T.Tensor[...]"), and all reads that index O (used when filling o[i, j]) are
updated to the new identifier so there are no remaining references to the old
name.

623-624: Same dummy tensor allocation pattern in non-atomic path.

The non-atomic path allocates dummy zero tensors for dK and dV just to call flashattn_bwd_postprocess, then discards the results. This is consistent with the non-varlen version but wastes memory.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py` around lines
623 - 624, The non-atomic path currently creates fresh dummy zero tensors for
dK/dV just to call flashattn_bwd_postprocess(dq, ..., ...), wasting memory; fix
this by reusing a single preallocated pair of dummy tensors (e.g., dummy_dk,
dummy_dv) allocated once with torch.zeros_like(k, dtype=torch.float32) and
torch.zeros_like(v, dtype=torch.float32) and use those in the non-atomic branch
when calling flashattn_bwd_postprocess(dq, dummy_dk, dummy_dv), or alternatively
update flashattn_bwd_postprocess to accept None for dK/dV and handle that case
so you only pass dq (avoid per-iteration allocations); ensure references to dq,
dk, dv, k, v and flashattn_bwd_postprocess are used so reviewers can locate the
change.
examples/attention_sink/example_gqa_sink_bwd_varlen.py (1)

221-262: Same O parameter naming applies here.

The preprocess kernel uses the ambiguous O variable name flagged by E741.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/attention_sink/example_gqa_sink_bwd_varlen.py` around lines 221 -
262, Rename the ambiguous tensor parameter O in flashattn_bwd_preprocess to a
more descriptive name (e.g., output_tensor or outputs) and update all uses
inside the function (alloc/loads from O and the indexing like O[q_start_idx +
...]) to the new name; ensure the function signature, type annotation line "O:
T.Tensor[[UQ, heads, dim], T.float16]" and every reference (currently O[...] and
the variable passed in when this function is called) are consistently updated to
the chosen descriptive identifier to avoid the E741 ambiguous name warning.
examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1)

104-105: Use torch.randint instead of torch.rand for integer-typed tensors.

Line 104 uses torch.rand(..., dtype=getattr(torch, in_dtype)), which works but is semantically confusing when in_dtype is T.int8. The test exercises this code path with in_dtype=T.int8, and while torch.rand will generate values and cast them to int8, using torch.randint is clearer and more idiomatic for integer tensor initialization.

💡 Suggested fix
+    torch_dtype = getattr(torch, in_dtype)
+    if str(in_dtype) == "int8":
+        A = torch.randint(-128, 128, (M, K), device="cuda", dtype=torch_dtype)
+    else:
+        A = torch.rand(M, K, device="cuda", dtype=torch_dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/dequantize_gemm/example_dequant_gemm_fine_grained.py` around lines
104 - 105, Replace the use of torch.rand for integer-typed inputs with
torch.randint: update the tensor creation for A (currently A = torch.rand(...,
dtype=getattr(torch, in_dtype))) to use torch.randint with an appropriate
integer range and dtype determined by in_dtype (e.g., for int8 use a signed
range like -128 to 127, and for unsigned types use 0..max), so A is created
idiomatically as an integer tensor when getattr(torch, in_dtype) is an integer
dtype.

Comment on lines +69 to +76
sinks = T.alloc_fragment([heads], dtype)

T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Allocate sinks per row, not per head count.

Line 69 sizes sinks as [heads], but Lines 75-76 index it with i over block_M. With the current defaults, block_M is often larger than heads, so this writes past the fragment bounds.

🩹 Minimal fix
-        sinks = T.alloc_fragment([heads], dtype)
+        sinks = T.alloc_fragment([block_M], dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
sinks = T.alloc_fragment([heads], dtype)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
sinks = T.alloc_fragment([block_M], dtype)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/attention_sink/example_mha_sink_bwd_bhsd.py` around lines 69 - 76,
The sinks fragment is allocated with the wrong size: T.alloc_fragment([heads],
dtype) but it's indexed by i in the T.Parallel(block_M) loop, causing
out-of-bounds when block_M > heads; update the allocation so sinks is sized per
row (use block_M as the dimension) — i.e., replace the T.alloc_fragment
allocation for sinks to allocate length block_M (keeping the same dtype) so the
subsequent loop assigning sinks[i] = Sinks[by] is safe.

Comment on lines +89 to +104
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))

T.copy(BlockSparseMask[bz, by, bx, :], block_mask)

loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)

for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k] != 0:
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)

for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]

T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Restore the running max before rescaling.

reduce_max(acc_s, scores_max) now gives only the current block max, but the online softmax update still treats scores_max as the max over all processed blocks. If a later sparse block has a smaller row max than an earlier one, prior contributions are rescaled against the wrong baseline and the output drifts.

🩹 Proposed fix
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 # To do causal softmax, we need to set the scores_max to 0 if it is -inf
                 # This process is called Check_inf in FlashAttention3 code, and it only need to be done
                 # in the first ceil_div(kBlockM, kBlockN) steps.
                 # for i in T.Parallel(block_M):
                 #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
                 for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/blocksparse_attention/example_tilelang_block_sparse_attn.py` around
lines 89 - 104, The running max must be restored before rescaling: after
T.reduce_max(acc_s, scores_max, dim=1, clear=False) compute the elementwise
running max = max(scores_max_prev, scores_max) (update either scores_max or
scores_max_prev to hold this combined max) and then use that combined max when
computing scores_scale and when subtracting for acc_s; i.e., replace the current
use of scores_max_prev and scores_max with the single restored running max
(using symbols scores_max_prev, scores_max, scores_scale, acc_s) so prior-block
contributions are rescaled against the true running maximum before
T.reduce_sum(scores_sum).

Comment on lines +137 to +160
max_split = T.alloc_var(T.int32)

T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_local_split = glse[bz, by, k]
if lse_local_split != 0:
max_split = k
lse_max_local = T.max(lse_max_local, glse[bz, by, k])

for k in T.Pipelined(num_split, num_stages=1):
if k <= max_split:
lse_local_split = glse[bz, by, k]
lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
if k <= max_split:
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split = glse[bz, by, k]
if lse_local_split != 0:
max_split = k
lse_max_local = T.max(lse_max_local, glse[bz, by, k])

for k in T.Pipelined(num_split, num_stages=1):
if k <= max_split:
lse_local_split = glse[bz, by, k]
lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
for k in T.serial(num_split):
if k <= max_split:
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split = glse[bz, by, k]
scale_local = T.exp2(lse_local_split - lse_logsum_local)
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]

return main
scale_local = T.exp2(lse_local_split - lse_logsum_local)
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Guard the no-valid-split path before reading max_split.

If a head has no valid indices, the first loop never assigns max_split, but the later k <= max_split checks still read it. run_regression_perf() currently builds exactly that all--1 input, so this path is live.

🩹 Proposed fix
         scale_local = T.alloc_var(accum_dtype)
         max_split = T.alloc_var(T.int32)

         T.clear(lse_logsum_local)
         T.clear(o_accum_local)
         lse_max_local = -T.infinity(accum_dtype)
+        max_split = -1
         for k in T.serial(num_split):
             lse_local_split = glse[bz, by, k]
             if lse_local_split != 0:
                 max_split = k
                 lse_max_local = T.max(lse_max_local, glse[bz, by, k])

-        for k in T.Pipelined(num_split, num_stages=1):
-            if k <= max_split:
-                lse_local_split = glse[bz, by, k]
-                lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
-        lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
-        for k in T.serial(num_split):
-            if k <= max_split:
-                for i in T.Parallel(dim_v):
-                    po_local[i] = Output_partial[bz, by, k, i]
-                lse_local_split = glse[bz, by, k]
-                scale_local = T.exp2(lse_local_split - lse_logsum_local)
-                for i in T.Parallel(dim_v):
-                    o_accum_local[i] += po_local[i] * scale_local
+        if max_split >= 0:
+            for k in T.Pipelined(num_split, num_stages=1):
+                if k <= max_split:
+                    lse_local_split = glse[bz, by, k]
+                    lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
+            lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
+            for k in T.serial(num_split):
+                if k <= max_split:
+                    for i in T.Parallel(dim_v):
+                        po_local[i] = Output_partial[bz, by, k, i]
+                    lse_local_split = glse[bz, by, k]
+                    scale_local = T.exp2(lse_local_split - lse_logsum_local)
+                    for i in T.Parallel(dim_v):
+                        o_accum_local[i] += po_local[i] * scale_local
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py`
around lines 137 - 160, The code reads max_split later even when no split was
assigned; initialize and guard its use: set max_split to -1 before the first
loop (currently max_split = T.alloc_var...) and then before the pipelined and
serial loops that check "k <= max_split" add a guard "if max_split >= 0" (or
wrap both loops in a single conditional) so you skip lse_logsum_local /
o_accum_local updates when there are no valid splits; reference symbols:
max_split, glse[bz, by, k], lse_max_local, lse_logsum_local, po_local,
Output_partial and o_accum_local.

Comment on lines +110 to +122
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
if i < valid_block_H:
Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]

# combine
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_var(accum_dtype)
lse_logsum_local = T.alloc_var(accum_dtype)
lse_max_local = T.alloc_var(accum_dtype)
scale_local = T.alloc_var(accum_dtype)
max_split = T.alloc_var(T.int32)

T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local = -T.infinity(accum_dtype)
for k in T.serial(num_split):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale

# TODO(lei): Support T.Parallel(valid_block_H)
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
for i, j in T.Parallel(block_H, dim_v):
if i < valid_block_H:
Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

max_split is not a safe validity proxy for an arbitrary block mask.

This kernel writes zeroed glse/partials for empty splits, then the combine pass folds every split k <= max_split as if it were valid. With a boolean block_mask, empty splits can appear before later non-empty ones, so those holes corrupt the normalization; the all-false case never initializes max_split at all. Use an explicit per-split validity marker, or write -inf for empty splits and gate each combine loop on that instead.

Also applies to: 124-157

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py`
around lines 110 - 122, The current use of max_split/valid_block_H as a proxy
for validity is unsafe because empty splits can appear before later non-empty
ones (and can be all-false), causing incorrect glse/partial folding; change the
kernel to record an explicit per-split validity marker (e.g., valid_split[bid,
hid, i]) when writing glse/Output_partial or write -inf into glse for empty
splits, then update the combine/reader loops (the sections writing/reading glse
and Output_partial—look for uses of valid_block_H, max_split, glse,
Output_partial, acc_o, logsum) to gate on that per-split marker (or check for
-inf) instead of relying on max_split so only truly valid splits are combined.

Comment on lines +203 to +226
BlockMask: T.Tensor[block_mask_shape, T.int32]
num_threads = 32
print("NV", NV, "NS", NS, "B", B, "H", H)

@T.prim_func
def flash_bwd_dkv(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(k_shape, dtype),
V: T.Tensor(v_shape, dtype),
LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
DO_slc: T.Tensor(do_slc_shape, dtype),
DK: T.Tensor(dk_shape, dtype),
DV: T.Tensor(dv_shape, dtype),
BlockMask: T.Tensor(block_mask_shape, T.int32),
):
with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
Q_shared = T.alloc_shared([G, BK], dtype)
qkT = T.alloc_fragment([BS, G], accum_dtype)
qkT_cast = T.alloc_fragment([BS, G], dtype)
dsT = T.alloc_fragment([BS, G], accum_dtype)
dsT_cast = T.alloc_fragment([BS, G], dtype)
lse_shared = T.alloc_shared([G], accum_dtype)
delta = T.alloc_shared([G], accum_dtype)

do = T.alloc_shared([G, BV], dtype)
dv = T.alloc_fragment([BS, BV], accum_dtype)
dk = T.alloc_fragment([BS, BK], accum_dtype)
dq = T.alloc_fragment([BS, G], accum_dtype)

dv_shared = T.alloc_shared([BS, BV], dtype)
dk_shared = T.alloc_shared([BS, BK], dtype)

i_b, i_h = i_bh // H, i_bh % H

T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared)
T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared)

# [BS, BK]
T.clear(dk)
# [BS, BV]
T.clear(dv)

loop_st = i_s * BS
loop_ed = seq_len
for i in T.Pipelined(
start=loop_st,
stop=loop_ed,
num_stages=0,
):
b_m_slc = BlockMask[i_b, i, i_h, i_s]
if b_m_slc != 0:
# [G, BK]
T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared)
T.clear(qkT)
# [BS, BK] @ [G, BK] -> [BS, G]
T.gemm(
K_shared,
Q_shared,
qkT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
# [G]
T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared)

for _i, _j in T.Parallel(BS, G):
qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j])

for _i, _j in T.Parallel(BS, G):
qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0)

# [G, BV]
T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do)
T.clear(dsT)
# [BS, BV] @ [G, BV] -> [BS, G]
T.gemm(
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(qkT, qkT_cast)
# [BS, G] @ [G, BV] -> [BS, BV]
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
# [G]
T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta)
for i, j in T.Parallel(BS, G):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale

# [BS, G] @ [G, BK] -> [BS, BK]
T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow)

T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV])
T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK])

return flash_bwd_dkv
with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
Q_shared = T.alloc_shared([G, BK], dtype)
qkT = T.alloc_fragment([BS, G], accum_dtype)
qkT_cast = T.alloc_fragment([BS, G], dtype)
dsT = T.alloc_fragment([BS, G], accum_dtype)
dsT_cast = T.alloc_fragment([BS, G], dtype)
lse_shared = T.alloc_shared([G], accum_dtype)
delta = T.alloc_shared([G], accum_dtype)

do = T.alloc_shared([G, BV], dtype)
dv = T.alloc_fragment([BS, BV], accum_dtype)
dk = T.alloc_fragment([BS, BK], accum_dtype)
dq = T.alloc_fragment([BS, G], accum_dtype)

dv_shared = T.alloc_shared([BS, BV], dtype)
dk_shared = T.alloc_shared([BS, BK], dtype)

i_b, i_h = i_bh // H, i_bh % H
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Debug print statement with undefined variable H.

Line 205 contains a debug print statement that references H, which is undefined in this scope. The kernel defines heads_kv via T.const, not H. This will cause a NameError. Additionally, line 226 uses H for computing i_h which has the same issue.

🐛 Proposed fix - remove debug print and fix H references
     block_mask_shape = [batch, seq_len, heads_kv, NS]
     BlockMask: T.Tensor[block_mask_shape, T.int32]
     num_threads = 32
-    print("NV", NV, "NS", NS, "B", B, "H", H)
 
-    with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
+    with T.Kernel(NV, NS, B * heads_kv, threads=num_threads) as (i_v, i_s, i_bh):
         K_shared = T.alloc_shared([BS, BK], dtype)
         ...
 
-        i_b, i_h = i_bh // H, i_bh % H
+        i_b, i_h = i_bh // heads_kv, i_bh % heads_kv
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_nsa/example_tilelang_nsa_bwd.py` around lines 203 - 226,
The debug print uses an undefined H and the kernel's split of i_bh uses H as
well; replace or remove the print and change references of H to the actual
constant name heads_kv (the T.const used to define number of heads for KV), e.g.
remove the print("NV", NV, "NS", NS, "B", B, "H", H) and update the split i_b,
i_h = i_bh // H, i_bh % H to use heads_kv (i_b, i_h = i_bh // heads_kv, i_bh %
heads_kv) so all references match the defined symbol like heads_kv in the
kernel.

Comment on lines +115 to +128
if l_new_topk <= 0:
break

r_idx = round % 2
l_start_pos = topk - l_new_topk

T.sync_threads()
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
input_idx = s * BLOCK_SIZE + tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
inval_int16 = convert_to_uint16(input[bx, input_idx])
T.atomic_add(s_histogram[inval_int16], 1)
T.fill(s_histogram, 0)
if tx == 0:
s_num_input[r_idx ^ 1] = 0
T.sync_threads()

l_num_input = s_num_input[r_idx]
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Accept the whole threshold bucket when it already fits.

Line 115 only exits on l_new_topk <= 0, but l_num_input <= l_new_topk is also terminal: every buffered candidate already belongs in the output. Running another radix pass in that state leaves no valid threshold bin, so s_threshold_bin_id[0] stays stale and the selection becomes undefined. This shows up for rows with exactly topk live elements and for any round where the current bucket size exactly matches the remaining budget.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/topk_selector.py` around lines 115 - 128, Before
performing another radix pass, handle the terminal case when the current bucket
already fits: if l_num_input <= l_new_topk (in the loop using l_new_topk,
l_num_input, r_idx, s_num_input and s_threshold_bin_id), treat this as a
terminal acceptance of the whole threshold bucket — update the shared selection
state (set s_threshold_bin_id[0] to the current bucket indicator / accepted
sentinel and ensure any per-round counters like s_num_input[r_idx ^ 1] are set
appropriately) and break out of the loop so no further radix passes run and
s_threshold_bin_id[0] cannot remain stale.

Comment on lines +175 to 193
for i_i in T.Pipelined(NS, num_stages=num_stages):
# Check which indices are valid
for bi_i in T.Parallel(BS):
mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)

T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
# Compute attention scores
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype))

for h_i, bi_i in T.Parallel(padded_H, BS):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale
# Load KV, V for this block of indices
for bi_i, d_i in T.Parallel(BS, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i]

T.copy(acc_dp, dP_shared_cast)
T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
for bi_i, d_i in T.Parallel(BS, D_tail):
KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i]
T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Masking does not protect these KV loads.

The kernel marks invalid entries in mask, but Line 186 and Line 191 still dereference Indices[...] unconditionally. With padded slots stored as S/-1, this can walk outside the valid KV range before the softmax mask suppresses the lane.

Comment on lines +115 to +122
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)

for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Don't index KV with masked-out entries.

The -1/causal checks only affect mask; Line 120 and Line 122 still dereference every sparse slot. Any padded entry (-1 or the current S sentinel left by the generator) can read outside the valid sequence range before acc_s is masked.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/dsa_sparse_finetune/sparse_mla_fwd.py` around lines 115 - 122, The
code currently always dereferences KV using Indices even for masked-out slots;
update the KV loads inside the pipelined loop (the assignments to KV_shared and
K_tail_shared) to first fetch the index (e.g., idx = Indices[bos + s_i, g_i, i_i
* BI + bi_i]) and compute a validity predicate (the same check used for mask:
idx != -1 and idx <= max_kv_i), then perform a guarded load: if valid assign
KV_shared[bi_i,d_i] = KV[bos + idx, g_i, d_i] (or K_tail_shared[...] = KV[bos +
idx, g_i, D + d_i]) else assign a safe default (e.g., zero) so you never index
KV with out-of-range/padded indices; use the same loop scopes (for bi_i, d_i in
T.Parallel(BI, D) and for bi_i, d_i in T.Parallel(BI, D_tail)) and reuse
Indices, mask, KV_shared, and K_tail_shared names to locate the changes.

Comment on lines +111 to +118
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)

for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Guard the KV loads with a safe index.

mask is computed here, but Line 116 and Line 118 still index KV for every lane. If Indices contains -1 or any pruned entry beyond max_kv_i, this kernel can read invalid memory before the masked softmax zeroes that lane out.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py` around lines 111 -
118, The KV loads into KV_shared and K_tail_shared in the nested loops use
Indices directly even when mask is false; change those loads in the loops that
write KV_shared and K_tail_shared (the loops referencing KV[bos + Indices[bos +
s_i, g_i, i_i * BI + bi_i], g_i, d_i] and KV[..., D + d_i]) to guard with the
computed mask by either computing a safe_index (e.g., replace out-of-range or -1
indices with a valid dummy index) or conditionally loading/storing only when
mask[bi_i] is true and writing zeros otherwise so no invalid KV memory is read
when Indices is -1 or > max_kv_i.

Comment on lines +194 to +197
result = flashattn(q, k, v, is_causal=is_causal)
best_latency = result.latency
best_config = result.config
ref_latency = result.ref_latency
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Pass groups through the tune path.

Line 194 currently falls back to groups=1. For GQA inputs, that makes the tuned kernel index K and V with the query-head id instead of by // groups, so tune mode can walk past the KV-head dimension and benchmark a different kernel than the non-tuned path.

🩹 Minimal fix
-        result = flashattn(q, k, v, is_causal=is_causal)
+        result = flashattn(q, k, v, is_causal=is_causal, groups=groups)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
result = flashattn(q, k, v, is_causal=is_causal)
best_latency = result.latency
best_config = result.config
ref_latency = result.ref_latency
result = flashattn(q, k, v, is_causal=is_causal, groups=groups)
best_latency = result.latency
best_config = result.config
ref_latency = result.ref_latency
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention/example_gqa_fwd_bshd.py` around lines 194 - 197, The
call to flashattn(q, k, v, is_causal=is_causal) is omitting the groups argument
causing the tuning path to assume groups=1; update the tuned-call site to pass
the runtime groups variable (e.g., flashattn(q, k, v, is_causal=is_causal,
groups=groups)) so that the tuning selection uses the same groups value as the
non-tuned path (ensure any other tune-related invocations that construct or
compare configs also propagate the groups field).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant