[Example] Flash Attention SM100#1910
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
|
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the repository settings, or use the following commands to manage reviews:
Use the checkboxes below for quick actions:
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID:
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough

Adds Blackwell (SM100) FlashAttention forward and backward implementations (GQA and MHA) in BSHD layout, preprocessing/postprocessing helpers, CPU references and CLI harnesses, plus a fix to tcgen05 tensor-memory copy offset computation.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant Forward as Forward Kernels
    participant Preproc as Preprocess (Delta)
    participant Backward as Backward Kernels
    participant Postproc as Postprocess (dQ Layout)
    participant Memory as Global/Shared Memory
    User->>Forward: provide Q, K, V
    Forward->>Memory: store O, LSE
    Forward-->>User: return O, LSE
    User->>Preproc: provide O, dO
    Preproc->>Memory: store Delta
    Preproc-->>User: return Delta
    User->>Backward: Q, K, V, dO, Delta
    Backward->>Memory: accumulate dK, dV (atomic if groups>1)
    Backward->>Memory: compute dQ blocks
    Backward-->>User: return dQ, dK, dV
    User->>Postproc: dQ (accumulated layout)
    Postproc->>Memory: convert to standard layout
    Postproc-->>User: return dQ (final layout)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing Touches: 🧪 Generate unit tests (beta)
Actionable comments posted: 8
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/flash_attention_sm100/gqa_bwd_bshd.py`:
- Around line 24-31: The code computes head_kv = heads // groups without
validating groups which can be zero or non-dividing; add a guard in
flashattn_fwd, flashattn_bwd and main before computing head_kv to ensure groups
is a positive integer and that heads % groups == 0 (e.g., raise/throw an error
or assert if groups <= 0 or heads % groups != 0) so kv_shape and subsequent
bx//groups indexing cannot underflow or index past KV heads; reference the
head_kv computation sites in flashattn_fwd, flashattn_bwd and the main function
and perform the same validation there.
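The guard described in this comment could look like the following sketch; `validate_groups` is a hypothetical helper name introduced here for illustration, whereas the real examples compute `head_kv` inline in `flashattn_fwd`, `flashattn_bwd`, and `main`:

```python
def validate_groups(heads: int, groups: int) -> int:
    """Return head_kv = heads // groups after validating `groups`.

    Guards the two failure modes called out above: groups <= 0
    (divide-by-zero or negative results) and a non-dividing groups,
    which would silently truncate head_kv and let the bx // groups
    indexing reach past the KV heads.
    """
    if not isinstance(groups, int) or groups <= 0:
        raise ValueError(f"groups must be a positive integer, got {groups!r}")
    if heads % groups != 0:
        raise ValueError(f"heads ({heads}) must be divisible by groups ({groups})")
    return heads // groups
```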
- Around line 259-274: ref_program() implements only the forward reference, but
main() still runs backward kernels and unconditionally prints "correctness run
OK." without verifying dQ_out, dK, or dV; either implement a matching backward
reference or stop claiming success for backward runs. Fix by one of two options:
(A) implement a CPU reference backward (e.g., add ref_program_backward or extend
ref_program to return gradients for dQ_out, dK, dV using numerical/analytical
accumulation matching the grouped atomic behavior) and compare those to the
kernel outputs (dQ_out, dK, dV) before printing success; or (B) if you cannot
verify backward, guard the backward execution in main() and skip the backward
comparisons/printing (or print a warning) when ref_program does not provide
gradients. Locate and change the logic around ref_program, main, and the
variables dQ_out, dK, dV to ensure backward results are actually validated
before emitting "correctness run OK."
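Option (A) above can be sketched roughly as follows. The `ref_program` here is a self-contained float32 reference (mirroring the one quoted later in this review), the tiny problem size is chosen for speed, and the dO-weighted loss is an assumption made so autograd's gradients correspond to a backward kernel fed with `dO`:

```python
import torch
import torch.nn.functional as F

def ref_program(Q, K, V, is_causal):
    # Float32 reference forward in BSHD layout (batch, seq, heads, dim).
    dim = Q.size(-1)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q.float(), K.float()) / dim**0.5
    if is_causal:
        s = Q.size(1)
        keep = torch.tril(torch.ones(s, s, dtype=torch.bool))
        scores = scores.masked_fill(~keep, float("-inf"))
    P = F.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", P, V.float())

# Small case for speed; the real harness would use the kernel's sizes.
batch, seq_len, heads, dim = 1, 8, 2, 16
Q = torch.randn(batch, seq_len, heads, dim, requires_grad=True)
K = torch.randn(batch, seq_len, heads, dim, requires_grad=True)
V = torch.randn(batch, seq_len, heads, dim, requires_grad=True)
dO = torch.randn(batch, seq_len, heads, dim)

O_ref = ref_program(Q, K, V, is_causal=True)
# Weighting by dO makes autograd produce the gradients a backward
# kernel fed with dO should reproduce.
(O_ref * dO).sum().backward()
dQ_ref, dK_ref, dV_ref = Q.grad, K.grad, V.grad

# In main(), compare the kernel outputs before printing success, e.g.:
# torch.testing.assert_close(dQ_out, dQ_ref, rtol=1e-2, atol=1e-2)
```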
In `@examples/flash_attention_sm100/gqa_fwd_bshd.py`:
- Around line 24-39: Validate the groups parameter before computing head_kv to
prevent divide-by-zero and non-divisor truncation: in flashattn_ss (and
similarly in flashattn_ts and main) add a guard that ensures groups is a
positive integer and that heads % groups == 0 before doing head_kv = heads //
groups; if the check fails, raise or return a clear error (e.g., ValueError)
explaining that groups must be >0 and evenly divide heads so the by // groups
indexing for K/V heads is safe.
In `@examples/flash_attention_sm100/mha_bwd_bshd.py`:
- Around line 254-268: The test never verifies gradients: modify the test to
compute reference gradients from ref_program(Q,K,V,is_causal) by setting
Q,K,V.requires_grad_=True, running ref_program to produce O_ref, calling a
simple scalar loss (e.g., O_ref.sum()) and backward() to capture ref dQ_ref,
dK_ref, dV_ref, then compare those to the kernel outputs dQ_out, dK, dV (using
torch.allclose or appropriate atol/rtol) and assert failures; update both the
backward-check in main() and the similar check around the other block (the spot
referenced at lines ~293-300) to perform the same gradient computation and
comparison against the kernel results.
In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 674-676: Re-enable the validation by restoring the equality check:
call torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) (using the
existing ref_program/Q/K/V result) and only print "Correctness check passed."
after the assert completes successfully; ensure the assert compares out and ref
on the same device (ref_program(...).to(out.device)) and remove the stray
unconditional print so failures aren’t masked.
- Around line 494-523: Insert waits on the GEMM2 completion barrier so softmax
and the epilogue don't read O_tmem before TCGEN05MMA finishes: in the softmax
warp, call T.mbarrier_wait_parity(mbar_bmm2_full[stage_id], parity) immediately
before the softmax code that reads O_tmem (i.e., prior to the current wait on
mbar_softmax_empty and the load of O_tmem), and in the epilogue replace or add
the wait on mbar_correction_full[0] with
T.mbarrier_wait_parity(mbar_bmm2_full[0], parity) (or wait on both if needed) so
the epilogue reads O_tmem only after mbar_bmm2_full signals completion.
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py`:
- Around line 78-80: The script computes ref_c but the equality check is
commented out causing false "All checks passed" results; restore the correctness
gate by re-enabling torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
(or replace with an equivalent assertion) before the final print, and ensure c
and ref_c are the tensors compared (variables named c and ref_c in the diff) so
the script fails when outputs differ instead of always printing success.
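Re-enabling the gate is a one-line change; a minimal standalone illustration of the failure mode it prevents (`c` and `ref_c` below are toy stand-ins, not the kernel outputs):

```python
import torch

def check(c, ref_c):
    # Re-enabled correctness gate: raises AssertionError on mismatch
    # instead of unconditionally printing success.
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("All checks passed.")

c = torch.ones(4)
check(c, c.clone())      # identical tensors pass the gate

failed = False
try:
    check(c, c + 1.0)    # off by 1.0, far outside rtol/atol
except AssertionError:
    failed = True
assert failed, "assert_close did not fire on mismatched outputs"
```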
- Around line 69-71: The chosen tile and staging sizes (block_M, block_N,
block_K, num_stages) allocate ~448KiB of shared memory for
A_shared+B_shared+C_shared which exceeds the per-CTA shared-memory limit; adjust
these constants to reduce shared-memory usage so the kernel can compile/launch.
Locate the declarations of block_M, block_N, block_K and num_stages and either
reduce block_N (e.g., 256 -> 128) or reduce num_stages (e.g., 4 -> 2), or choose
a smaller combination (for example block_M=128, block_N=128, block_K=128,
num_stages=2) until A_shared+B_shared+C_shared fits within the target per-block
shared-memory budget. Ensure any dependent launch/configuration logic that
assumes the previous tile shapes is updated accordingly.
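The ~448 KiB figure can be reproduced with a little arithmetic, assuming fp16 operands (2 bytes each) and that A_shared/B_shared are staged `num_stages` deep while C_shared holds a single tile; these buffering assumptions are inferred from the numbers in the comment, not read from the kernel source:

```python
def smem_bytes(block_M, block_N, block_K, num_stages, elem_bytes=2):
    # A and B tiles are multi-buffered across num_stages; C holds one tile.
    a_shared = block_M * block_K * elem_bytes * num_stages
    b_shared = block_N * block_K * elem_bytes * num_stages
    c_shared = block_M * block_N * elem_bytes
    return a_shared + b_shared + c_shared

# Flagged configuration: 448 KiB, beyond the per-CTA shared-memory limit.
assert smem_bytes(128, 256, 128, 4) == 448 * 1024
# Suggested reduction (block_N=128, num_stages=2): 160 KiB.
assert smem_bytes(128, 128, 128, 2) == 160 * 1024
```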
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5373cfdb-330a-4478-b77a-da3d1063462c
📒 Files selected for processing (7)
examples/flash_attention_sm100/gqa_bwd_bshd.py
examples/flash_attention_sm100/gqa_fwd_bshd.py
examples/flash_attention_sm100/mha_bwd_bshd.py
examples/flash_attention_sm100/mha_fwd_bshd.py
examples/flash_attention_sm100/src.cu
examples/gemm_sm100/gemm_tcgen5mma.py
examples/gemm_sm100/gemm_tcgen5mma_ws.py
```python
def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    Q_f = Q.cpu().float()
    K_f = K.cpu().float()
    V_f = V.cpu().float()
    scores = torch.einsum("bqhd,bkhd->bhqk", Q_f, K_f)
    scores = scores / (dim**0.5)
    if is_causal:
        seq_len = Q_f.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    P = F.softmax(scores, dim=-1)
    O = torch.einsum("bhqk,bkhd->bqhd", P, V_f)
    return O.to(Q.dtype)
```
The backward example never checks any gradients.
ref_program() only computes a forward output, and main() just runs the kernels then prints a success message. dQ_out, dK, and dV are never compared against a reference, so silent gradient bugs in this new kernel will look like passes.
Also applies to: 293-300
🧰 Tools
🪛 Ruff (0.15.4)
[error] 267-267: Ambiguous variable name: O
(E741)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/flash_attention_sm100/mha_bwd_bshd.py` around lines 254 - 268, The
test never verifies gradients: modify the test to compute reference gradients
from ref_program(Q,K,V,is_causal) by setting Q,K,V.requires_grad_=True, running
ref_program to produce O_ref, calling a simple scalar loss (e.g., O_ref.sum())
and backward() to capture ref dQ_ref, dK_ref, dV_ref, then compare those to the
kernel outputs dQ_out, dK, dV (using torch.allclose or appropriate atol/rtol)
and assert failures; update both the backward-check in main() and the similar
check around the other block (the spot referenced at lines ~293-300) to perform
the same gradient computation and comparison against the kernel results.
…sert_close, docstrings
Actionable comments posted: 1
♻️ Duplicate comments (2)
examples/flash_attention_sm100/mha_bwd_bshd.py (1)
241-289: ⚠️ Potential issue | 🟠 Major

Add a real gradient oracle for the backward example.

`ref_program()` is still forward-only, and `main()` never compares `dQ_out`, `dK`, or `dV` to anything. The updated log line is more honest, but this path still only proves the kernels run, not that the backward math is correct. At minimum, run a small autograd reference and assert the gradients before treating this example as validated.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention_sm100/mha_bwd_bshd.py` around lines 241 - 289, The test only runs kernels but doesn't verify backward math; add a gradient oracle using autograd and compare it to kernel outputs: use ref_program(Q,K,V,is_causal) (or recompute a differentiable CPU/float32 forward) with requires_grad=True on Q,K,V, run a small randomized case (reduce batch/seq_len/heads/dim for speed), call torch.autograd.grad to get ref dQ_ref,dK_ref,dV_ref, then run kernel_fwd/kernel_prep/kernel_bwd/kernel_post as before to produce dQ_out,dK_out,dV_out and assert they are close (torch.allclose with appropriate rtol/atol). Update main() to perform this check (and still keep the runtime smoke test) and reference the symbols kernel_fwd, kernel_prep, kernel_bwd, kernel_post, ref_program, and the dQ/dK/dV variables so the changes are local and discoverable.

examples/flash_attention_sm100/gqa_bwd_bshd.py (1)
270-322: ⚠️ Potential issue | 🟠 Major

Please verify the grouped backward path against a reference.

This still only checks that the kernels execute. `ref_program()` is forward-only, `dQ_out` is discarded, and the grouped `atomic_add` path for `dK`/`dV` is never compared against autograd or another oracle, so silent gradient regressions will slip through.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention_sm100/gqa_bwd_bshd.py` around lines 270 - 322, The test only runs the grouped backward kernel but never verifies its dQ/dK/dV outputs against a reference; use ref_program to create a differentiable forward reference and compute autograd gradients for Q, K, V, then compare those to the outputs from kernel_bwd. Concretely: make Q,K,V require_grad (float), compute out_ref = ref_program(Q, K, V, is_causal, groups), compute scalar loss = (out_ref.to(dO.dtype) * dO).sum() or loss = (out_ref * dO.float()).sum(), call torch.autograd.grad to get ref_dQ, ref_dK, ref_dV, and then assert close(ref_dQ, dQ_out_from kernel_post or dQ), ref_dK vs dK, ref_dV vs dV within tolerances; keep using kernel_fwd/kernel_prep/kernel_bwd/kernel_post flow but add this autograd reference path and comparisons to detect regressions in the grouped atomic_add path.
🧹 Nitpick comments (1)
examples/flash_attention_sm100/mha_fwd_bshd.py (1)
563-564: Make `flashattn_warp` mean the same thing across the SM100 examples.

Here the alias points to `flashattn_wasp`, while `examples/flash_attention_sm100/gqa_fwd_bshd.py:309` points the same name at `flashattn_ts`. That makes shared imports/tooling select materially different pipelines under one symbol. Prefer exporting the concrete variant name or aligning the alias target across both modules.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 563 - 564, The alias flashattn_warp is inconsistent across SM100 examples (here it's set to flashattn_wasp while another module sets flashattn_warp = flashattn_ts); update this module so flashattn_warp points to the same concrete implementation used in the other SM100 examples (or remove the alias and export the concrete name directly) by changing the alias assignment from flashattn_wasp to the agreed concrete symbol (e.g., flashattn_ts) or by replacing usages to import/export the concrete symbol instead of flashattn_warp (adjust references to flashattn_warp, flashattn_wasp, and flashattn_ts accordingly).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 389-395: The code only primes barrier index 0 before entering the
WASP loop, causing later iterations to wait on unprimed stages and potentially
deadlock when num_stages>1; change the priming logic in the is_bmm_warp block to
loop over all stages (0..num_stages-1) and call T.mbarrier_arrive for each of
mbar_dma1_empty, mbar_dma2_empty, mbar_bmm1_empty, and mbar_softmax_empty (use
the same indexing style as mbar_*[stage_id]) so every pipeline stage is primed
before the WASP loop.
---
Duplicate comments:
In `@examples/flash_attention_sm100/gqa_bwd_bshd.py`:
- Around line 270-322: The test only runs the grouped backward kernel but never
verifies its dQ/dK/dV outputs against a reference; use ref_program to create a
differentiable forward reference and compute autograd gradients for Q, K, V,
then compare those to the outputs from kernel_bwd. Concretely: make Q,K,V
require_grad (float), compute out_ref = ref_program(Q, K, V, is_causal, groups),
compute scalar loss = (out_ref.to(dO.dtype) * dO).sum() or loss = (out_ref *
dO.float()).sum(), call torch.autograd.grad to get ref_dQ, ref_dK, ref_dV, and
then assert close(ref_dQ, dQ_out_from kernel_post or dQ), ref_dK vs dK, ref_dV
vs dV within tolerances; keep using
kernel_fwd/kernel_prep/kernel_bwd/kernel_post flow but add this autograd
reference path and comparisons to detect regressions in the grouped atomic_add
path.
In `@examples/flash_attention_sm100/mha_bwd_bshd.py`:
- Around line 241-289: The test only runs kernels but doesn't verify backward
math; add a gradient oracle using autograd and compare it to kernel outputs: use
ref_program(Q,K,V,is_causal) (or recompute a differentiable CPU/float32 forward)
with requires_grad=True on Q,K,V, run a small randomized case (reduce
batch/seq_len/heads/dim for speed), call torch.autograd.grad to get ref
dQ_ref,dK_ref,dV_ref, then run kernel_fwd/kernel_prep/kernel_bwd/kernel_post as
before to produce dQ_out,dK_out,dV_out and assert they are close (torch.allclose
with appropriate rtol/atol). Update main() to perform this check (and still keep
the runtime smoke test) and reference the symbols kernel_fwd, kernel_prep,
kernel_bwd, kernel_post, ref_program, and the dQ/dK/dV variables so the changes
are local and discoverable.
---
Nitpick comments:
In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 563-564: The alias flashattn_warp is inconsistent across SM100
examples (here it's set to flashattn_wasp while another module sets
flashattn_warp = flashattn_ts); update this module so flashattn_warp points to
the same concrete implementation used in the other SM100 examples (or remove the
alias and export the concrete name directly) by changing the alias assignment
from flashattn_wasp to the agreed concrete symbol (e.g., flashattn_ts) or by
replacing usages to import/export the concrete symbol instead of flashattn_warp
(adjust references to flashattn_warp, flashattn_wasp, and flashattn_ts
accordingly).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8f52ee5a-fc74-43ac-99fd-70ea9e11aa6c
📒 Files selected for processing (4)
examples/flash_attention_sm100/gqa_bwd_bshd.py
examples/flash_attention_sm100/gqa_fwd_bshd.py
examples/flash_attention_sm100/mha_bwd_bshd.py
examples/flash_attention_sm100/mha_fwd_bshd.py
```python
# Prime empty barriers so first iteration can proceed (phase 1 for parity_inv=1 at k=0)
if is_bmm_warp:
    T.mbarrier_arrive(mbar_dma1_empty[0])
    T.mbarrier_arrive(mbar_dma2_empty[0])
    T.mbarrier_arrive(mbar_bmm1_empty[0])
    T.mbarrier_arrive(mbar_softmax_empty[0])
```
Prime every pipeline stage before entering the WASP loop.
Only stage 0 is marked empty here, but Line 413, Line 425, Line 437, and Line 492 all wait on stage_id. With the default num_stages=2, the k=1 iteration waits on stage 1 with parity 1 even though nothing has ever arrived there, so the pipeline can deadlock as soon as loop_range > 1.
🔧 Minimal fix

```diff
 # Prime empty barriers so first iteration can proceed (phase 1 for parity_inv=1 at k=0)
 if is_bmm_warp:
-    T.mbarrier_arrive(mbar_dma1_empty[0])
-    T.mbarrier_arrive(mbar_dma2_empty[0])
-    T.mbarrier_arrive(mbar_bmm1_empty[0])
-    T.mbarrier_arrive(mbar_softmax_empty[0])
+    for stage in range(num_stages):
+        T.mbarrier_arrive(mbar_dma1_empty[stage])
+        T.mbarrier_arrive(mbar_dma2_empty[stage])
+        T.mbarrier_arrive(mbar_bmm1_empty[stage])
+        T.mbarrier_arrive(mbar_softmax_empty[stage])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 389 - 395, The
code only primes barrier index 0 before entering the WASP loop, causing later
iterations to wait on unprimed stages and potentially deadlock when
num_stages>1; change the priming logic in the is_bmm_warp block to loop over all
stages (0..num_stages-1) and call T.mbarrier_arrive for each of mbar_dma1_empty,
mbar_dma2_empty, mbar_bmm1_empty, and mbar_softmax_empty (use the same indexing
style as mbar_*[stage_id]) so every pipeline stage is primed before the WASP
loop.
…> 128 threads or > 1 WG, merged flash_atten ts & ss kernels into one, specify by extra argument
🧹 Nitpick comments (2)
examples/flash_attention_sm100/gqa_bwd_bshd.py (1)
127-131: Consider renaming ambiguous variable `l` → `seq` in layout lambda.

The variable `l` (line 130) can be confused with `1`. Ruff flags this as E741.

Suggested fix

```diff
 def make_dq_layout(dQ):
     return T.Layout(
         dQ.shape,
-        lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2],
+        lambda b, seq, h, d: [b, seq // 8, h, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2],
     )
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention_sm100/gqa_bwd_bshd.py` around lines 127 - 131, Rename the ambiguous lambda parameter l to seq inside make_dq_layout to avoid confusion with the digit 1 and satisfy Ruff E741: update the lambda signature from "lambda b, l, h, d" to "lambda b, seq, h, d" and replace all uses of l in the lambda body (e.g., "l // 8" and "4 * (l % 8)") with seq so the layout construction in make_dq_layout continues to compute correctly.

examples/flash_attention_sm100/gqa_fwd_bshd.py (1)
197-198: Consider renaming ambiguous variable `O` → `out`.

The variable name `O` can be confused with `0`. This is flagged by Ruff (E741).

Suggested fix

```diff
-    O = torch.einsum("bhqk,bkhd->bqhd", P, V_f)
-    return O.to(Q.dtype)
+    out = torch.einsum("bhqk,bkhd->bqhd", P, V_f)
+    return out.to(Q.dtype)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention_sm100/gqa_fwd_bshd.py` around lines 197 - 198, Rename the ambiguous variable O to a clearer name like out in the einsum assignment and in the return; specifically change the line "O = torch.einsum(\"bhqk,bkhd->bqhd\", P, V_f)" to assign to out and update "return O.to(Q.dtype)" to "return out.to(Q.dtype)" so the variable isn't confused with the digit zero (affects the local scope where P, V_f, and Q.dtype are used).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@examples/flash_attention_sm100/gqa_bwd_bshd.py`:
- Around line 127-131: Rename the ambiguous lambda parameter l to seq inside
make_dq_layout to avoid confusion with the digit 1 and satisfy Ruff E741: update
the lambda signature from "lambda b, l, h, d" to "lambda b, seq, h, d" and
replace all uses of l in the lambda body (e.g., "l // 8" and "4 * (l % 8)") with
seq so the layout construction in make_dq_layout continues to compute correctly.
In `@examples/flash_attention_sm100/gqa_fwd_bshd.py`:
- Around line 197-198: Rename the ambiguous variable O to a clearer name like
out in the einsum assignment and in the return; specifically change the line "O
= torch.einsum(\"bhqk,bkhd->bqhd\", P, V_f)" to assign to out and update "return
O.to(Q.dtype)" to "return out.to(Q.dtype)" so the variable isn't confused with
the digit zero (affects the local scope where P, V_f, and Q.dtype are used).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 55f64213-597b-499f-b32b-7714b1258a1c
📒 Files selected for processing (4)
examples/flash_attention_sm100/gqa_bwd_bshd.py
examples/flash_attention_sm100/gqa_fwd_bshd.py
examples/flash_attention_sm100/mha_fwd_bshd.py
src/op/copy.cc
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/flash_attention_sm100/mha_fwd_bshd.py
Actionable comments posted: 4
♻️ Duplicate comments (1)
examples/flash_attention_sm100/mha_fwd_bshd.py (1)
257-265: ⚠️ Potential issue | 🔴 Critical

Prime the WASP `*_empty` barriers before `k == 0`.

Line 264, Line 275, Line 286, and Line 339 all wait on `*_empty[stage_id]`, but there is no visible pre-loop `mbarrier_arrive` for any stage. If these barriers are not implicitly initialized to the waited parity, the first DMA/BMM/softmax handoff can deadlock immediately.

🔧 Minimal fix

```diff
 loop_range = (
     T.min(
         T.ceildiv(seq_len, block_N),
         T.ceildiv((bx + 1) * block_M, block_N),
     )
     if is_causal
     else T.ceildiv(seq_len, block_N)
 )
+
+if tid < 128:
+    for stage in T.serial(num_stages):
+        T.mbarrier_arrive(mbar_bmm1_empty[stage])
+elif tid >= 160 and tid < 192:
+    for stage in T.serial(num_stages):
+        T.mbarrier_arrive(mbar_dma1_empty[stage])
+        T.mbarrier_arrive(mbar_dma2_empty[stage])
+        T.mbarrier_arrive(mbar_softmax_empty[stage])
 for k in T.serial(loop_range):
```

Run this read-only check to confirm there is no priming block before the first waits:

```shell
#!/bin/bash
sed -n '248,345p' examples/flash_attention_sm100/mha_fwd_bshd.py
printf '\n-- *_empty waits/arrives --\n'
rg -n -C1 'mbarrier_(wait_parity|arrive)\(mbar_(dma1_empty|dma2_empty|bmm1_empty|softmax_empty)\[' examples/flash_attention_sm100/mha_fwd_bshd.py
```

Expected result: at least one `mbarrier_arrive(..._empty[stage])` block should appear before `for k in T.serial(loop_range):`. Right now all visible `*_empty` arrives are inside the loop, after the first waits.

Also applies to: 275-286, 339-340
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 257 - 265, The waits on the WASP barriers (calls to mbarrier_wait_parity using mbar_dma1_empty, mbar_dma2_empty, mbar_bmm1_empty, mbar_softmax_empty inside the for k in T.serial(loop_range) loop) must be primed before entering the loop to avoid deadlock; add a pre-loop block that calls mbarrier_arrive on each *_empty barrier for every stage_id (or at least for the stages used) with the parity expected by the first waits (the inverse parity used in the existing mbarrier_wait_parity calls) so the initial waited parity is satisfied when k==0 (use the same stage_id calculation/num_stages logic as in the loop and mirror parity/parity_inv computations).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/flash_attention_sm100/gqa_fwd_bshd.py`:
- Around line 178-179: The exported flashattn_ts currently aliases flashattn (so
it inherits the default variant="ss"); replace the direct assignment with a thin
wrapper named flashattn_ts that calls flashattn but forces variant="ts" (e.g.,
accept *args/**kwargs and pass through while setting variant="ts" or overriding
any variant in kwargs) and use functools.wraps to preserve metadata; keep
flashattn_ss as the plain alias to flashattn for the "ss" variant.
- Around line 262-270: The loop deadlocks because the mbarrier objects
(mbar_dma1_empty, mbar_dma0_empty, mbar_clear_accum) are initialized at parity 0
but the first iteration immediately waits for parity 1; to fix this, add a
pre-loop priming step that calls T.mbarrier_arrive_parity(...) with parity 1 for
each barrier that the loop will wait on (mirror the same tid-range conditionals
used inside the for k in T.serial(loop_range) loop), i.e. before the loop
execute the appropriate T.mbarrier_arrive_parity(mbar_dma1_empty, 1),
T.mbarrier_arrive_parity(mbar_dma0_empty, 1), and
T.mbarrier_arrive_parity(mbar_clear_accum, 1) under the same tid checks so the
initial T.mbarrier_wait_parity calls (in the loop) will find parity 1 and not
stall.
In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 441-463: The current branch always calls flashattn_wasp for any
non-"ss"/"ts" variant; change the control flow so you handle three cases: if
variant in ("ss","ts") call flashattn(..., variant=variant); elif variant ==
"wasp" attempt to build kernel = flashattn_wasp(...) inside a try/except and on
exception fall back to kernel = flashattn(..., variant="ts") (preserve the same
block_M/block_N/threads/num_stages args when attempting WASP), and else raise a
clear ValueError for unknown variant values so typos don't default to WASP.
- Around line 178-179: The current aliases flashattn_ss = flashattn and
flashattn_ts = flashattn simply re-export the same function so calling
flashattn_ts(...) still builds the SS kernel; replace these aliases with thin
wrapper functions named flashattn_ss(...) and flashattn_ts(...) that call the
original flashattn(...) while forcing variant="ss" and variant="ts" respectively
(pass through all other args/kwargs) so each wrapper reliably selects the
intended kernel variant.
---
Duplicate comments:
In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 257-265: The waits on the WASP barriers (calls to
mbarrier_wait_parity using mbar_dma1_empty, mbar_dma2_empty, mbar_bmm1_empty,
mbar_softmax_empty inside the for k in T.serial(loop_range) loop) must be primed
before entering the loop to avoid deadlock; add a pre-loop block that calls
mbarrier_arrive on each *_empty barrier for every stage_id (or at least for the
stages used) with the parity expected by the first waits (the inverse parity
used in the existing mbarrier_wait_parity calls) so the initial waited parity is
satisfied when k==0 (use the same stage_id calculation/num_stages logic as in
the loop and mirror parity/parity_inv computations).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 80f39a49-cae2-4ba3-ba28-8529b3ccd031
📒 Files selected for processing (2)
examples/flash_attention_sm100/gqa_fwd_bshd.py
examples/flash_attention_sm100/mha_fwd_bshd.py
```python
flashattn_ss = flashattn
flashattn_ts = flashattn
```
flashattn_ts still exports the SS configuration.
flashattn_ts = flashattn preserves the default variant="ss", so direct callers of flashattn_ts(...) silently get the wrong kernel unless they override the variant themselves. A small wrapper that pins variant="ts" avoids that trap.
🔧 Minimal fix

```diff
-flashattn_ss = flashattn
-flashattn_ts = flashattn
+def flashattn_ss(batch, heads, seq_len, dim, is_causal, groups=1, block_M=128, block_N=128):
+    return flashattn(
+        batch,
+        heads,
+        seq_len,
+        dim,
+        is_causal,
+        groups=groups,
+        block_M=block_M,
+        block_N=block_N,
+        variant="ss",
+    )
+
+
+def flashattn_ts(batch, heads, seq_len, dim, is_causal, groups=1, block_M=128, block_N=128):
+    return flashattn(
+        batch,
+        heads,
+        seq_len,
+        dim,
+        is_causal,
+        groups=groups,
+        block_M=block_M,
+        block_N=block_N,
+        variant="ts",
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/flash_attention_sm100/gqa_fwd_bshd.py` around lines 178 - 179, The
exported flashattn_ts currently aliases flashattn (so it inherits the default
variant="ss"); replace the direct assignment with a thin wrapper named
flashattn_ts that calls flashattn but forces variant="ts" (e.g., accept
*args/**kwargs and pass through while setting variant="ts" or overriding any
variant in kwargs) and use functools.wraps to preserve metadata; keep
flashattn_ss as the plain alias to flashattn for the "ss" variant.
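The `*args`/`**kwargs` wrapper style the prompt suggests could be sketched like this; `flashattn` below is a toy stand-in for the example's kernel builder, kept minimal so the override behavior is visible:

```python
import functools

def flashattn(*args, variant="ss", **kwargs):
    # Toy stand-in for the kernel builder in gqa_fwd_bshd.py.
    return f"kernel(variant={variant})"

flashattn_ss = flashattn  # plain alias keeps the default "ss" variant

@functools.wraps(flashattn)
def flashattn_ts(*args, **kwargs):
    # Force the TS variant regardless of caller-supplied kwargs.
    kwargs["variant"] = "ts"
    return flashattn(*args, **kwargs)
```

With this shape, `flashattn_ts(...)` can never silently build the SS kernel, while `flashattn_ss` stays a zero-cost alias.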
```python
for k in T.serial(loop_range):
    parity = (k // num_stages) & 1
    parity_inv = parity ^ 1
    stage_id = k % num_stages
    is_clear_accum = k == 0

    if tid >= 128 and tid < 160:
        T.mbarrier_wait_parity(mbar_dma1_empty[stage_id], parity_inv)
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
sed -n '253,350p' examples/flash_attention_sm100/gqa_fwd_bshd.py
printf '\n-- *_empty waits/arrives --\n'
rg -n -C1 'mbarrier_(wait_parity|arrive)\(mbar_(dma1_empty|dma2_empty|bmm1_empty|softmax_empty)\[' examples/flash_attention_sm100/gqa_fwd_bshd.py

Repository: tile-ai/tilelang
🏁 Script executed:
# Search for alloc_barrier function and barrier initialization
rg -n 'alloc_barrier' examples/flash_attention_sm100/gqa_fwd_bshd.py
🏁 Script executed:
# Also search for where the *_empty barriers are created/allocated
rg -n 'mbar_(dma1_empty|dma2_empty|bmm1_empty|softmax_empty)' examples/flash_attention_sm100/gqa_fwd_bshd.py | head -50
🏁 Script executed:
# Look at the full context around barrier allocations
sed -n '1,260p' examples/flash_attention_sm100/gqa_fwd_bshd.py | tail -100
🏁 Script executed:
# Search for alloc_barrier implementation or documentation
fd -t f '\.py$' | xargs rg -l 'alloc_barrier' | head -20
🏁 Script executed:
# Look for any documentation or implementation of alloc_barrier
rg -A 5 'def alloc_barrier' --type pyRepository: tile-ai/tilelang
Length of output: 443
🏁 Script executed:
# Check if there's any other initialization happening around barrier creation
rg -B 2 -A 2 'alloc_barrier' examples/flash_attention_sm100/ --type pyRepository: tile-ai/tilelang
Length of output: 3940
🏁 Script executed:
# Check the full alloc_barrier implementation
sed -n '/def alloc_barrier/,/^[^ ]/p' tilelang/language/allocate.pyRepository: tile-ai/tilelang
Length of output: 1181
🏁 Script executed:
# Search for any barrier initialization or parity semantics in the codebase
rg -i 'barrier.*parity|parity.*init' --type py | head -20Repository: tile-ai/tilelang
Length of output: 2328
🏁 Script executed:
# Look at how barriers are used in other examples to see if priming is done elsewhere
rg -B 3 'for.*T.serial.*loop_range' examples/flash_attention_sm100/*.py | head -40Repository: tile-ai/tilelang
Length of output: 652
🏁 Script executed:
# Look for any barrier initialization semantics or comments
rg -B 3 -A 3 'barrier_init' tilelang/language/allocate.pyRepository: tile-ai/tilelang
Length of output: 712
🏁 Script executed:
# Check if there's a lowering pass that handles barrier initialization
rg -l 'barrier_init|mbarrier' tilelang/compiler/passes/*.py 2>/dev/null | head -10Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for any comments or documentation about barrier parity initialization
rg -i 'initial.*parity|parity.*0|parity.*1|barrier.*state' examples/ --type py | head -20Repository: tile-ai/tilelang
Length of output: 2745
🏁 Script executed:
# Look at the warp_specialize examples for barrier usage patterns
cat examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py | grep -A 50 'alloc_barrier' | head -80Repository: tile-ai/tilelang
Length of output: 1787
🏁 Script executed:
# Look for any compiler pass that handles barrier initialization
find tilelang -name "*.py" -o -name "*.cc" -o -name "*.h" | xargs rg -l 'mbarrier.*init|initialize.*barrier' 2>/dev/null | head -10Repository: tile-ai/tilelang
Length of output: 148
🏁 Script executed:
# Check if there's documentation in the allocate.py file about initial barrier state
cat tilelang/language/allocate.py | head -100Repository: tile-ai/tilelang
Length of output: 4125
🏁 Script executed:
# Look for test cases involving barriers and parity to understand expected behavior
rg -B 5 -A 10 'mbarrier_wait_parity.*0\)' testing/ --type py | head -60Repository: tile-ai/tilelang
Length of output: 5792
🏁 Script executed:
# Search for ptx_init_barrier_thread_count or ptx_fence_barrier_init in the flash attention files
rg -n 'ptx_init_barrier_thread_count|ptx_fence_barrier_init' examples/flash_attention_sm100/Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Check if there's barrier initialization happening in the kernel
sed -n '230,270p' examples/flash_attention_sm100/gqa_fwd_bshd.pyRepository: tile-ai/tilelang
Length of output: 1713
🏁 Script executed:
# Look for any compiler transformation that automatically injects barrier init
rg -l 'ptx_init_barrier|ptx_fence_barrier_init' tilelang/ --type pyRepository: tile-ai/tilelang
Length of output: 143
🏁 Script executed:
# Search for barrier initialization in the compiler
rg -n 'ptx_init_barrier|ptx_fence_barrier_init' tilelang/language/tir/Repository: tile-ai/tilelang
Length of output: 774
🏁 Script executed:
# Check the IR/op definitions for barrier init
rg -B 2 -A 5 'def ptx_init_barrier_thread_count|def ptx_fence_barrier_init' tilelang/Repository: tile-ai/tilelang
Length of output: 1675
🏁 Script executed:
# Look for compiler passes that might auto-inject barrier initialization
fd -t f '\.(py|cc|h)$' tilelang/compiler | head -20Repository: tile-ai/tilelang
Length of output: 164
🏁 Script executed:
# Check if there's a specific pass file for barriers
find tilelang -name "*.py" | xargs rg -l 'barrier.*init.*pass|init.*barrier.*inject' 2>/dev/nullRepository: tile-ai/tilelang
Length of output: 44
🏁 Script executed:
# Search for how ptx_init_barrier_thread_count is used in the codebase
rg -B 3 -A 3 'ptx_init_barrier_thread_count' examples/ testing/Repository: tile-ai/tilelang
Length of output: 7342
🏁 Script executed:
# Check if gqa_fwd_bshd.py has any use of ptx_init_barrier or ptx_fence_barrier_init
rg 'ptx_init_barrier|ptx_fence_barrier_init' examples/flash_attention_sm100/gqa_fwd_bshd.pyRepository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Look for the lowering pass that handles barriers
find tilelang -name "*.py" -path "*/passes/*" | xargs cat 2>/dev/null | grep -l "barrier\|lower" | head -10Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for references to barrier initialization in C++ or compiler code
find tilelang -name "*.cc" -o -name "*.h" 2>/dev/null | xargs rg -l 'barrier.*init' 2>/dev/null | head -5Repository: tile-ai/tilelang
Length of output: 219
🏁 Script executed:
# Check the lower_shared_barrier.cc file
cat src/transform/lower_shared_barrier.cc | head -150Repository: tile-ai/tilelang
Length of output: 5382
🏁 Script executed:
# Look for barrier initialization logic
rg -B 5 -A 10 'ptx_init_barrier|barrier_init' src/transform/lower_shared_barrier.ccRepository: tile-ai/tilelang
Length of output: 2985
🏁 Script executed:
# Check if there's automatic initialization of barriers in the lower_hopper_intrin.cc
rg -B 5 -A 10 'barrier.*init|alloc_barrier' src/transform/lower_hopper_intrin.cc | head -80Repository: tile-ai/tilelang
Length of output: 1752
🏁 Script executed:
# Look at how the warp specialize example handles the barrier priming issue
rg -B 10 -A 10 'for ko in T.Pipelined' examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py | head -60Repository: tile-ai/tilelang
Length of output: 977
🏁 Script executed:
# Check the NVIDIA PTX documentation reference - what does mbarrier.init set the parity to?
rg -B 5 -A 5 'mbarrier.*init|initial.*parity' src/Repository: tile-ai/tilelang
Length of output: 6294
🏁 Script executed:
# Verify by checking if other gqa or mha examples have this priming
cat examples/flash_attention_sm100/mha_fwd_bshd.py | sed -n '240,280p'Repository: tile-ai/tilelang
Length of output: 1684
Add pre-loop barrier priming to prevent pipeline stall on first iteration.
The *_empty barriers are initialized to parity 0 by the compiler, but the main loop immediately waits on parity 1 on the first iteration (lines 269, 280, 291, 344). Without prior arrives to establish parity 1, the pipeline deadlocks before it starts.
🔧 Minimal fix
 loop_range = (
     T.min(
         T.ceildiv(seq_len, block_N),
         T.ceildiv((bx + 1) * block_M, block_N),
     )
     if is_causal
     else T.ceildiv(seq_len, block_N)
 )
+
+ if tid < 128:
+     for stage in T.serial(num_stages):
+         T.mbarrier_arrive(mbar_bmm1_empty[stage])
+ elif tid >= 160 and tid < 192:
+     for stage in T.serial(num_stages):
+         T.mbarrier_arrive(mbar_dma1_empty[stage])
+         T.mbarrier_arrive(mbar_dma2_empty[stage])
+         T.mbarrier_arrive(mbar_softmax_empty[stage])
 for k in T.serial(loop_range):

Also applies to: 280-291, 344-345
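The `loop_range` expression in the context lines above caps the K/V tile loop at the causal diagonal; a plain-Python rehearsal with illustrative sizes (`ceildiv` stands in for `T.ceildiv`):

```python
def ceildiv(a, b):
    # Ceiling division, matching T.ceildiv semantics for positive ints.
    return -(-a // b)

seq_len, block_M, block_N = 1024, 128, 128

def loop_range(bx, is_causal):
    full = ceildiv(seq_len, block_N)
    if is_causal:
        # Q block bx covers rows up to (bx + 1) * block_M - 1, so it only
        # needs K/V tiles up to that diagonal.
        return min(full, ceildiv((bx + 1) * block_M, block_N))
    return full

ranges = [loop_range(bx, True) for bx in range(4)]
# ranges == [1, 2, 3, 4]; without causal masking every block runs 8 tiles
```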
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/flash_attention_sm100/gqa_fwd_bshd.py` around lines 262 - 270, the
main loop deadlocks because the *_empty mbarriers (mbar_dma1_empty,
mbar_dma2_empty, mbar_bmm1_empty, mbar_softmax_empty) are initialized at parity
0 while the first iteration immediately waits on parity 1; fix this by adding a
pre-loop priming step that issues T.mbarrier_arrive(...) once per stage for each
barrier the loop will wait on, mirroring the same tid-range conditionals used
inside the `for k in T.serial(loop_range)` loop, so the initial
T.mbarrier_wait_parity calls observe parity 1 and do not stall.
flashattn_ss = flashattn
flashattn_ts = flashattn
flashattn_ts still builds the SS kernel.
Because flashattn_ts is only a plain alias, calling flashattn_ts(...) without an explicit variant="ts" still returns the 128-thread SS path. Export thin wrappers that pin the variant instead of re-exporting flashattn directly.
🔧 Minimal fix
-flashattn_ss = flashattn
-flashattn_ts = flashattn
+def flashattn_ss(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128):
+    return flashattn(
+        batch,
+        heads,
+        seq_len,
+        dim,
+        is_causal,
+        block_M=block_M,
+        block_N=block_N,
+        variant="ss",
+    )
+
+
+def flashattn_ts(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128):
+    return flashattn(
+        batch,
+        heads,
+        seq_len,
+        dim,
+        is_causal,
+        block_M=block_M,
+        block_N=block_N,
+        variant="ts",
+    )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 178 - 179, The
current aliases flashattn_ss = flashattn and flashattn_ts = flashattn simply
re-export the same function so calling flashattn_ts(...) still builds the SS
kernel; replace these aliases with thin wrapper functions named
flashattn_ss(...) and flashattn_ts(...) that call the original flashattn(...)
while forcing variant="ss" and variant="ts" respectively (pass through all other
args/kwargs) so each wrapper reliably selects the intended kernel variant.
if variant in ("ss", "ts"):
    kernel = flashattn(
        batch,
        heads,
        seq_len,
        dim,
        is_causal,
        block_M=128,
        block_N=128,
        variant=variant,
    )
else:
    kernel = flashattn_wasp(
        batch,
        heads,
        seq_len,
        dim,
        is_causal,
        block_M=128,
        block_N=128,
        threads=256,
        num_stages=2,
    )
The advertised WASP→TS fallback is not wired up.
This branch always builds flashattn_wasp(...) directly, so variant="wasp" still bubbles up layout/JIT failures even though the module docstring and CLI help say it should fall back to TS. Make this an explicit elif variant == "wasp" branch, fall back to flashattn(..., variant="ts") on WASP build failure, and raise on unknown variants instead of routing every typo to WASP.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 441 - 463, The
current branch always calls flashattn_wasp for any non-"ss"/"ts" variant; change
the control flow so you handle three cases: if variant in ("ss","ts") call
flashattn(..., variant=variant); elif variant == "wasp" attempt to build kernel
= flashattn_wasp(...) inside a try/except and on exception fall back to kernel =
flashattn(..., variant="ts") (preserve the same
block_M/block_N/threads/num_stages args when attempting WASP), and else raise a
clear ValueError for unknown variant values so typos don't default to WASP.
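The control flow the prompt asks for might be sketched like this (with hypothetical `make_base`/`make_wasp` stand-ins for the real `flashattn`/`flashattn_wasp` builders):

```python
def build_kernel(variant, make_base, make_wasp):
    # make_base(variant): hypothetical builder for the SS/TS single-path kernels.
    # make_wasp(): hypothetical builder for the warp-specialized kernel.
    if variant in ("ss", "ts"):
        return make_base(variant)
    elif variant == "wasp":
        try:
            return make_wasp()
        except Exception:
            # Advertised fallback: a WASP build failure degrades to TS.
            return make_base("ts")
    # Unknown variants (including typos) fail loudly instead of routing to WASP.
    raise ValueError(f"unknown variant: {variant!r}")

def failing_wasp():
    raise RuntimeError("layout inference failed")

picked = build_kernel("wasp", lambda v: f"kernel-{v}", failing_wasp)
# picked == "kernel-ts" because the WASP build raised
```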
Performance Report
Hardware: NVIDIA Blackwell (Drive Thor)

MHA Forward

GQA Forward
|
|
Hi @LeiWang1999 , @Rachmanino I implemented these flash attention SM100 examples; the functionality was verified and the performance is roughly as expected (see the table above). I noticed that some related PRs are ongoing, and I acknowledge we are still some distance from achieving SOL, specifically:
I suppose this is what we can do currently; please take a review at your convenience, thanks ~ |
|
@Hale423 Thanks for your contribution! I'll take a look then |
|
Yeah, currently we can only use explicit warp specialization to utilize TMA for TCGEN5MMA. Besides, 2-CTA tcgen05mma will be supported soon! |
…le/flash-atten-sm100
Blackwell (SM100) Flash Attention examples
Summary
Add Flash Attention examples for Blackwell (SM100) using TCGEN05MMA + TMEM, covering MHA/GQA forward and backward and replacing the Hopper WGMMA approach. All kernels use the BSHD layout and support causal masking.
Files in this PR
mha_fwd_bshd.py variants
SS (--variant ss): 128 threads, both GEMMs use tcgen05mma_ss (shared → TMEM), single-path pipeline.
TS (--variant ts): 256 threads, single-path; GEMM 2 uses mma_ts (P_tmem × V_shared → D_tmem), reducing shared-memory traffic.
WASP (--variant wasp): 256 threads, warp-specialized (softmax / DMA / BMM warps), GEMM 2 uses mma_ts, double-buffered K/V. Automatically falls back to ts if layout inference (or similar) fails.

How to run
Technical notes
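The GQA mapping `head_kv = heads // groups` means each group of consecutive query heads shares one KV head; a minimal sketch of that mapping, assuming the usual convention that consecutive query heads map to the same KV head (sizes here are illustrative):

```python
heads, groups = 8, 4           # 8 query heads, 4 query heads per group
head_kv = heads // groups      # -> 2 KV heads

def kv_head_for(q_head: int) -> int:
    # Consecutive runs of `groups` query heads share one KV head.
    return q_head // groups

mapping = [kv_head_for(h) for h in range(heads)]
# mapping == [0, 0, 0, 0, 1, 1, 1, 1]
```

This is also why the backward pass accumulates dK/dV with atomic adds: several query heads contribute gradients to the same KV head.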
head_kv = heads // groups; backward dK/dV use atomic_add to aggregate over multiple Q heads.

Current status / limitations
assert_close) and tightened later.

Follow-ups
--variant ss (and optionally ts).

Summary by CodeRabbit
New Features
Bug Fixes