[Example] Flash Attention SM100 #1910

Merged
LeiWang1999 merged 8 commits into tile-ai:main from Hale423:example/flash-atten-sm100 on Mar 22, 2026

Conversation

@Hale423 (Contributor) commented Mar 8, 2026

Blackwell (SM100) Flash Attention examples

Summary

Add Flash Attention examples for Blackwell (SM100) using TCGEN05MMA + TMEM, covering MHA/GQA forward and backward and replacing the Hopper WGMMA approach. All kernels use the BSHD layout and support causal masking.

Files in this PR

| File | Description |
| --- | --- |
| mha_fwd_bshd.py | MHA forward: three variants ss / ts / wasp with optional fallback |
| gqa_fwd_bshd.py | GQA forward: ss (default pipeline), ts (256 threads) |
| mha_bwd_bshd.py | MHA backward: fwd for LSE + bwd, ss / ts |
| gqa_bwd_bshd.py | GQA backward: head_kv indexing, dK/dV via atomic_add, ss / ts |

mha_fwd_bshd.py variants

  • ss (--variant ss): 128 threads, both GEMMs use tcgen05mma_ss (shared → TMEM), single-path pipeline.
  • ts (--variant ts): 256 threads, single-path; GEMM 2 uses mma_ts (P_tmem × V_shared → D_tmem), reducing shared-memory traffic.
  • wasp (--variant wasp): 256 threads, warp-specialized (softmax / DMA / BMM warps), GEMM 2 mma_ts, double-buffered K/V. Automatically falls back to ts if layout inference (or similar) fails; a sketch of this fallback follows below.
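
For illustration, a minimal sketch of how this selection-plus-fallback could be wired at the call site is given below; build_kernel is a hypothetical helper, and the builders passed to it are assumed to follow the flashattn / flashattn_wasp interfaces described above.

```python
# Hedged sketch: pick a kernel builder and fall back from wasp to ts if the wasp build fails.
def build_kernel(variant, build_default, build_wasp, **kw):
    if variant in ("ss", "ts"):
        return build_default(variant=variant, **kw)
    if variant == "wasp":
        try:
            return build_wasp(**kw)
        except Exception as err:  # e.g. layout inference failure during the JIT build
            print(f"WASP build failed ({err}); falling back to the ts variant")
            return build_default(variant="ts", **kw)
    raise ValueError(f"unknown variant: {variant!r}")

# Example call, assuming the builders from mha_fwd_bshd.py are in scope:
# kernel = build_kernel("wasp", flashattn, flashattn_wasp,
#                       batch=2, heads=4, seq_len=256, dim=128, is_causal=False)
```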

How to run

# MHA forward (default: ss)
python examples/flash_attention_sm100/mha_fwd_bshd.py [--variant ss|ts|wasp] [--is_causal] [--batch 2 --heads 4 --seq_len 256 --dim 128]

# GQA forward
python examples/flash_attention_sm100/gqa_fwd_bshd.py [--variant ss|ts] [--groups N] ...

# MHA backward
python examples/flash_attention_sm100/mha_bwd_bshd.py [--variant ss|ts] ...

# GQA backward
python examples/flash_attention_sm100/gqa_bwd_bshd.py [--variant ss|ts] [--groups N] ...

Technical notes

  • Data flow: GEMM1 Q@K^T → S_tmem → S_reg → online softmax → P_tmem; GEMM2 P_tmem@V → O_tmem; O is rescaled and normalized by logsum in registers, then written out (a numerical sketch of the online-softmax rescaling follows after this list).
  • GQA: K/V are indexed with head_kv = heads // groups; backward dK/dV use atomic_add to aggregate over multiple Q heads.
  • wasp: Three warp roles (tid 0–127 softmax, 128–159 DMA, 160–191 BMM), double-buffered K_shared_0/1 and V_shared_0/1, barrier sync; epilogue (logsum normalization + write Output) runs in the last k iteration.
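
To make the rescaling step concrete, below is a minimal NumPy sketch of the online-softmax bookkeeping for a single query row: a running max, a running denominator (the logsum), and an output accumulator that is corrected whenever the max changes. It mirrors the math only; TMEM staging, tiling, causal masking, and dtype handling in the actual kernels are not represented, and all names are illustrative.

```python
import numpy as np

def online_softmax_attention(q, k, v, block_n=64):
    """Streaming attention for one query row, visiting K/V in blocks of block_n."""
    seq_len, dim = k.shape
    scale = 1.0 / np.sqrt(dim)
    m_run = -np.inf                 # running max of the scores seen so far
    denom = 0.0                     # running sum of exp(score - m_run), i.e. the logsum
    acc = np.zeros(dim)             # unnormalized output accumulator
    for start in range(0, seq_len, block_n):
        s = (q @ k[start:start + block_n].T) * scale                      # GEMM1 for this K block
        m_new = max(m_run, s.max())
        corr = np.exp(m_run - m_new)                                      # rescale older contributions
        denom = denom * corr + np.exp(s - m_new).sum()
        acc = acc * corr + np.exp(s - m_new) @ v[start:start + block_n]   # GEMM2
        m_run = m_new
    return acc / denom                                                    # normalize by logsum at the end

# Sanity check against a one-shot softmax.
rng = np.random.default_rng(0)
q, k, v = rng.standard_normal(64), rng.standard_normal((256, 64)), rng.standard_normal((256, 64))
scores = (q @ k.T) / np.sqrt(64)
p = np.exp(scores - scores.max())
assert np.allclose(online_softmax_attention(q, k, v), (p / p.sum()) @ v)
```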

Current status / limitations

  • wasp depends on layout inference support for fragments in warp branches (e.g. scores_max, logsum). Without the related fixes (Fill InferLayout, step0b default layout, ReduceOp placeholder overwrite), wasp falls back to ts.
  • Correctness: ref vs kernel checks can be relaxed for now (e.g. comment out assert_close) and tightened later.
  • block_M / block_N / num_stages and related tuning are not yet done.

Follow-ups

  • Remove or make optional the wasp→ts fallback once layout inference is stable for wasp.
  • Add CI that builds and runs --variant ss (and optionally ts); a minimal smoke-test sketch follows below.
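
A minimal sketch of such a smoke test is shown here; it relies only on the CLI documented under "How to run", and the test name, timeout, and any GPU-gating marker are assumptions left to the project's existing CI conventions.

```python
# Hedged sketch: run the ss variant end to end and fail if the script exits non-zero.
import subprocess
import sys

def test_mha_fwd_bshd_ss_smoke():
    cmd = [
        sys.executable,
        "examples/flash_attention_sm100/mha_fwd_bshd.py",
        "--variant", "ss",
        "--batch", "1", "--heads", "2", "--seq_len", "256", "--dim", "128",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    assert result.returncode == 0, result.stderr
```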

Summary by CodeRabbit

  • New Features

    • Added multiple FlashAttention example implementations targeting Blackwell SM100 GPUs, including forward and backward flows for multi-head and grouped-query attention with causal masking, CPU reference checks, CLI benchmarks, and alternative execution variants for performance tuning.
  • Bug Fixes

    • Corrected tensor memory copy offset computation in the lowering path to align per-thread-group processing.

@github-actions Bot commented Mar 8, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai Bot commented Mar 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 233a6c38-72cb-48c5-9d39-4cae38d97a9e

📥 Commits

Reviewing files that changed from the base of the PR and between 41d9f9f and c7fb9b8.

📒 Files selected for processing (1)
  • src/op/copy.cc
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/op/copy.cc

📝 Walkthrough

Walkthrough

Adds Blackwell (SM100) FlashAttention forward and backward implementations (GQA and MHA) in BSHD layout, preprocessing/postprocessing helpers, CPU references and CLI harnesses, plus a fix to tcgen05 tensor-memory copy offset computation in src/op/copy.cc.

Changes

| Cohort | File(s) | Summary |
| --- | --- | --- |
| GQA Forward | examples/flash_attention_sm100/gqa_fwd_bshd.py | New forward kernels (variants: ss, ts, wasp/warp) with block tiling, grouping, causal support, CPU reference, correctness checks, and performance benchmarking CLI. |
| GQA Backward | examples/flash_attention_sm100/gqa_bwd_bshd.py | New backward flow including forward kernel, preprocess (Delta), main backward (dQ/dK/dV) with pipeline/warp variants and atomic updates for grouped reductions, postprocess to reorder dQ, layout helper, and end-to-end CLI. |
| MHA Forward | examples/flash_attention_sm100/mha_fwd_bshd.py | New forward variants using TCGEN05MMA/TMEM (ss, ts, wasp/warp), multi-stage warp pipeline, reference implementation, benchmarking harness, and PASS_CFG constants. |
| MHA Backward | examples/flash_attention_sm100/mha_bwd_bshd.py | New backward implementation with forward kernel, preprocess/postprocess, main backward pipeline/warp variants, custom dQ layout, and CLI-driven end-to-end test. |
| Copy Operation Fix | src/op/copy.cc | Adjusts tcgen05 LowerTmemCopy lowering to use an effective_chunks value when computing col_offset (accounts for pack/unpack), and reorders computation so effective_chunks is derived before relative_wg_idx/col_offset. |

Sequence Diagram

sequenceDiagram
    actor User
    participant Forward as Forward Kernels
    participant Preproc as Preprocess (Delta)
    participant Backward as Backward Kernels
    participant Postproc as Postprocess (dQ Layout)
    participant Memory as Global/Shared Memory

    User->>Forward: provide Q, K, V
    Forward->>Memory: store O, LSE
    Forward-->>User: return O, LSE

    User->>Preproc: provide O, dO
    Preproc->>Memory: store Delta
    Preproc-->>User: return Delta

    User->>Backward: Q, K, V, dO, Delta
    Backward->>Memory: accumulate dK, dV (atomic if groups>1)
    Backward->>Memory: compute dQ blocks
    Backward-->>User: return dQ, dK, dV

    User->>Postproc: dQ (accumulated layout)
    Postproc->>Memory: convert to standard layout
    Postproc-->>User: return dQ (final layout)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • Rachmanino
  • chengyupku

Poem

🐇 I nibble code in moonlit rows,

Kernels hum where softmax grows.
Forward, backward—threads align,
Blackwell gates in tiled design.
Hop, test, and benchmark — joy in each line.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 45.24%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title '[Example] Flash Attention SM100' directly and concisely describes the main addition: Flash Attention example implementations for NVIDIA Blackwell (SM100), which is the core purpose of the changeset. |


@coderabbitai Bot left a comment

Actionable comments posted: 8

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/flash_attention_sm100/gqa_bwd_bshd.py`:
- Around line 24-31: The code computes head_kv = heads // groups without
validating groups which can be zero or non-dividing; add a guard in
flashattn_fwd, flashattn_bwd and main before computing head_kv to ensure groups
is a positive integer and that heads % groups == 0 (e.g., raise/throw an error
or assert if groups <= 0 or heads % groups != 0) so kv_shape and subsequent
bx//groups indexing cannot underflow or index past KV heads; reference the
head_kv computation sites in flashattn_fwd, flashattn_bwd and the main function
and perform the same validation there.
- Around line 259-274: ref_program() implements only the forward reference, but
main() still runs backward kernels and unconditionally prints "correctness run
OK." without verifying dQ_out, dK, or dV; either implement a matching backward
reference or stop claiming success for backward runs. Fix by one of two options:
(A) implement a CPU reference backward (e.g., add ref_program_backward or extend
ref_program to return gradients for dQ_out, dK, dV using numerical/analytical
accumulation matching the grouped atomic behavior) and compare those to the
kernel outputs (dQ_out, dK, dV) before printing success; or (B) if you cannot
verify backward, guard the backward execution in main() and skip the backward
comparisons/printing (or print a warning) when ref_program does not provide
gradients. Locate and change the logic around ref_program, main, and the
variables dQ_out, dK, dV to ensure backward results are actually validated
before emitting "correctness run OK."

In `@examples/flash_attention_sm100/gqa_fwd_bshd.py`:
- Around line 24-39: Validate the groups parameter before computing head_kv to
prevent divide-by-zero and non-divisor truncation: in flashattn_ss (and
similarly in flashattn_ts and main) add a guard that ensures groups is a
positive integer and that heads % groups == 0 before doing head_kv = heads //
groups; if the check fails, raise or return a clear error (e.g., ValueError)
explaining that groups must be >0 and evenly divide heads so the by // groups
indexing for K/V heads is safe.

In `@examples/flash_attention_sm100/mha_bwd_bshd.py`:
- Around line 254-268: The test never verifies gradients: modify the test to
compute reference gradients from ref_program(Q,K,V,is_causal) by setting
Q,K,V.requires_grad_=True, running ref_program to produce O_ref, calling a
simple scalar loss (e.g., O_ref.sum()) and backward() to capture ref dQ_ref,
dK_ref, dV_ref, then compare those to the kernel outputs dQ_out, dK, dV (using
torch.allclose or appropriate atol/rtol) and assert failures; update both the
backward-check in main() and the similar check around the other block (the spot
referenced at lines ~293-300) to perform the same gradient computation and
comparison against the kernel results.

In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 674-676: Re-enable the validation by restoring the equality check:
call torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) (using the
existing ref_program/Q/K/V result) and only print "Correctness check passed."
after the assert completes successfully; ensure the assert compares out and ref
on the same device (ref_program(...).to(out.device)) and remove the stray
unconditional print so failures aren’t masked.
- Around line 494-523: Insert waits on the GEMM2 completion barrier so softmax
and the epilogue don't read O_tmem before TCGEN05MMA finishes: in the softmax
warp, call T.mbarrier_wait_parity(mbar_bmm2_full[stage_id], parity) immediately
before the softmax code that reads O_tmem (i.e., prior to the current wait on
mbar_softmax_empty and the load of O_tmem), and in the epilogue replace or add
the wait on mbar_correction_full[0] with
T.mbarrier_wait_parity(mbar_bmm2_full[0], parity) (or wait on both if needed) so
the epilogue reads O_tmem only after mbar_bmm2_full signals completion.

In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py`:
- Around line 78-80: The script computes ref_c but the equality check is
commented out causing false "All checks passed" results; restore the correctness
gate by re-enabling torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
(or replace with an equivalent assertion) before the final print, and ensure c
and ref_c are the tensors compared (variables named c and ref_c in the diff) so
the script fails when outputs differ instead of always printing success.
- Around line 69-71: The chosen tile and staging sizes (block_M, block_N,
block_K, num_stages) allocate ~448KiB of shared memory for
A_shared+B_shared+C_shared which exceeds the per-CTA shared-memory limit; adjust
these constants to reduce shared-memory usage so the kernel can compile/launch.
Locate the declarations of block_M, block_N, block_K and num_stages and either
reduce block_N (e.g., 256 -> 128) or reduce num_stages (e.g., 4 -> 2), or choose
a smaller combination (for example block_M=128, block_N=128, block_K=128,
num_stages=2) until A_shared+B_shared+C_shared fits within the target per-block
shared-memory budget. Ensure any dependent launch/configuration logic that
assumes the previous tile shapes is updated accordingly.
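
As a concrete illustration of the groups validation requested for the GQA examples above, here is a hedged sketch; validate_groups is a hypothetical helper (the example kernels compute head_kv = heads // groups inline), shown only to pin down the intended guard.

```python
def validate_groups(heads: int, groups: int) -> int:
    """Return head_kv, raising if groups cannot evenly partition the query heads."""
    if groups <= 0:
        raise ValueError(f"groups must be a positive integer, got {groups}")
    if heads % groups != 0:
        raise ValueError(f"heads ({heads}) must be divisible by groups ({groups})")
    return heads // groups

# e.g. 8 query heads with groups=4 -> 2 KV heads; query head i maps to KV head i // groups.
assert validate_groups(heads=8, groups=4) == 2
```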

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5373cfdb-330a-4478-b77a-da3d1063462c

📥 Commits

Reviewing files that changed from the base of the PR and between 9bae56f and 3f1a185.

📒 Files selected for processing (7)
  • examples/flash_attention_sm100/gqa_bwd_bshd.py
  • examples/flash_attention_sm100/gqa_fwd_bshd.py
  • examples/flash_attention_sm100/mha_bwd_bshd.py
  • examples/flash_attention_sm100/mha_fwd_bshd.py
  • examples/flash_attention_sm100/src.cu
  • examples/gemm_sm100/gemm_tcgen5mma.py
  • examples/gemm_sm100/gemm_tcgen5mma_ws.py

Comment thread examples/flash_attention_sm100/gqa_bwd_bshd.py
Comment thread examples/flash_attention_sm100/gqa_bwd_bshd.py Outdated
Comment thread examples/flash_attention_sm100/gqa_fwd_bshd.py Outdated
Comment on lines +254 to +268
def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    Q_f = Q.cpu().float()
    K_f = K.cpu().float()
    V_f = V.cpu().float()
    scores = torch.einsum("bqhd,bkhd->bhqk", Q_f, K_f)
    scores = scores / (dim**0.5)
    if is_causal:
        seq_len = Q_f.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    P = F.softmax(scores, dim=-1)
    O = torch.einsum("bhqk,bkhd->bqhd", P, V_f)
    return O.to(Q.dtype)

⚠️ Potential issue | 🟠 Major

The backward example never checks any gradients.

ref_program() only computes a forward output, and main() just runs the kernels then prints a success message. dQ_out, dK, and dV are never compared against a reference, so silent gradient bugs in this new kernel will look like passes.

Also applies to: 293-300

🧰 Tools
🪛 Ruff (0.15.4)

[error] 267-267: Ambiguous variable name: O

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/mha_bwd_bshd.py` around lines 254 - 268, The
test never verifies gradients: modify the test to compute reference gradients
from ref_program(Q,K,V,is_causal) by setting Q,K,V.requires_grad_=True, running
ref_program to produce O_ref, calling a simple scalar loss (e.g., O_ref.sum())
and backward() to capture ref dQ_ref, dK_ref, dV_ref, then compare those to the
kernel outputs dQ_out, dK, dV (using torch.allclose or appropriate atol/rtol)
and assert failures; update both the backward-check in main() and the similar
check around the other block (the spot referenced at lines ~293-300) to perform
the same gradient computation and comparison against the kernel results.

Comment thread examples/flash_attention_sm100/mha_fwd_bshd.py Outdated
Comment thread examples/flash_attention_sm100/mha_fwd_bshd.py
Comment thread examples/gemm_sm100/gemm_tcgen5mma_ws.py Outdated
Comment thread examples/gemm_sm100/gemm_tcgen5mma_ws.py
@coderabbitai Bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
examples/flash_attention_sm100/mha_bwd_bshd.py (1)

241-289: ⚠️ Potential issue | 🟠 Major

Add a real gradient oracle for the backward example.

ref_program() is still forward-only, and main() never compares dQ_out, dK, or dV to anything. The updated log line is more honest, but this path still only proves the kernels run, not that the backward math is correct. At minimum, run a small autograd reference and assert the gradients before treating this example as validated.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/mha_bwd_bshd.py` around lines 241 - 289, The
test only runs kernels but doesn't verify backward math; add a gradient oracle
using autograd and compare it to kernel outputs: use
ref_program(Q,K,V,is_causal) (or recompute a differentiable CPU/float32 forward)
with requires_grad=True on Q,K,V, run a small randomized case (reduce
batch/seq_len/heads/dim for speed), call torch.autograd.grad to get ref
dQ_ref,dK_ref,dV_ref, then run kernel_fwd/kernel_prep/kernel_bwd/kernel_post as
before to produce dQ_out,dK_out,dV_out and assert they are close (torch.allclose
with appropriate rtol/atol). Update main() to perform this check (and still keep
the runtime smoke test) and reference the symbols kernel_fwd, kernel_prep,
kernel_bwd, kernel_post, ref_program, and the dQ/dK/dV variables so the changes
are local and discoverable.
examples/flash_attention_sm100/gqa_bwd_bshd.py (1)

270-322: ⚠️ Potential issue | 🟠 Major

Please verify the grouped backward path against a reference.

This still only checks that the kernels execute. ref_program() is forward-only, dQ_out is discarded, and the grouped atomic_add path for dK/dV is never compared against autograd or another oracle, so silent gradient regressions will slip through.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/gqa_bwd_bshd.py` around lines 270 - 322, The
test only runs the grouped backward kernel but never verifies its dQ/dK/dV
outputs against a reference; use ref_program to create a differentiable forward
reference and compute autograd gradients for Q, K, V, then compare those to the
outputs from kernel_bwd. Concretely: make Q,K,V require_grad (float), compute
out_ref = ref_program(Q, K, V, is_causal, groups), compute scalar loss =
(out_ref.to(dO.dtype) * dO).sum() or loss = (out_ref * dO.float()).sum(), call
torch.autograd.grad to get ref_dQ, ref_dK, ref_dV, and then assert close(ref_dQ,
dQ_out_from kernel_post or dQ), ref_dK vs dK, ref_dV vs dV within tolerances;
keep using kernel_fwd/kernel_prep/kernel_bwd/kernel_post flow but add this
autograd reference path and comparisons to detect regressions in the grouped
atomic_add path.
🧹 Nitpick comments (1)
examples/flash_attention_sm100/mha_fwd_bshd.py (1)

563-564: Make flashattn_warp mean the same thing across the SM100 examples.

Here the alias points to flashattn_wasp, while examples/flash_attention_sm100/gqa_fwd_bshd.py:309 points the same name at flashattn_ts. That makes shared imports/tooling select materially different pipelines under one symbol. Prefer exporting the concrete variant name or aligning the alias target across both modules.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 563 - 564, The
alias flashattn_warp is inconsistent across SM100 examples (here it's set to
flashattn_wasp while another module sets flashattn_warp = flashattn_ts); update
this module so flashattn_warp points to the same concrete implementation used in
the other SM100 examples (or remove the alias and export the concrete name
directly) by changing the alias assignment from flashattn_wasp to the agreed
concrete symbol (e.g., flashattn_ts) or by replacing usages to import/export the
concrete symbol instead of flashattn_warp (adjust references to flashattn_warp,
flashattn_wasp, and flashattn_ts accordingly).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 389-395: The code only primes barrier index 0 before entering the
WASP loop, causing later iterations to wait on unprimed stages and potentially
deadlock when num_stages>1; change the priming logic in the is_bmm_warp block to
loop over all stages (0..num_stages-1) and call T.mbarrier_arrive for each of
mbar_dma1_empty, mbar_dma2_empty, mbar_bmm1_empty, and mbar_softmax_empty (use
the same indexing style as mbar_*[stage_id]) so every pipeline stage is primed
before the WASP loop.

---

Duplicate comments:
In `@examples/flash_attention_sm100/gqa_bwd_bshd.py`:
- Around line 270-322: The test only runs the grouped backward kernel but never
verifies its dQ/dK/dV outputs against a reference; use ref_program to create a
differentiable forward reference and compute autograd gradients for Q, K, V,
then compare those to the outputs from kernel_bwd. Concretely: make Q,K,V
require_grad (float), compute out_ref = ref_program(Q, K, V, is_causal, groups),
compute scalar loss = (out_ref.to(dO.dtype) * dO).sum() or loss = (out_ref *
dO.float()).sum(), call torch.autograd.grad to get ref_dQ, ref_dK, ref_dV, and
then assert close(ref_dQ, dQ_out_from kernel_post or dQ), ref_dK vs dK, ref_dV
vs dV within tolerances; keep using
kernel_fwd/kernel_prep/kernel_bwd/kernel_post flow but add this autograd
reference path and comparisons to detect regressions in the grouped atomic_add
path.

In `@examples/flash_attention_sm100/mha_bwd_bshd.py`:
- Around line 241-289: The test only runs kernels but doesn't verify backward
math; add a gradient oracle using autograd and compare it to kernel outputs: use
ref_program(Q,K,V,is_causal) (or recompute a differentiable CPU/float32 forward)
with requires_grad=True on Q,K,V, run a small randomized case (reduce
batch/seq_len/heads/dim for speed), call torch.autograd.grad to get ref
dQ_ref,dK_ref,dV_ref, then run kernel_fwd/kernel_prep/kernel_bwd/kernel_post as
before to produce dQ_out,dK_out,dV_out and assert they are close (torch.allclose
with appropriate rtol/atol). Update main() to perform this check (and still keep
the runtime smoke test) and reference the symbols kernel_fwd, kernel_prep,
kernel_bwd, kernel_post, ref_program, and the dQ/dK/dV variables so the changes
are local and discoverable.

---

Nitpick comments:
In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 563-564: The alias flashattn_warp is inconsistent across SM100
examples (here it's set to flashattn_wasp while another module sets
flashattn_warp = flashattn_ts); update this module so flashattn_warp points to
the same concrete implementation used in the other SM100 examples (or remove the
alias and export the concrete name directly) by changing the alias assignment
from flashattn_wasp to the agreed concrete symbol (e.g., flashattn_ts) or by
replacing usages to import/export the concrete symbol instead of flashattn_warp
(adjust references to flashattn_warp, flashattn_wasp, and flashattn_ts
accordingly).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8f52ee5a-fc74-43ac-99fd-70ea9e11aa6c

📥 Commits

Reviewing files that changed from the base of the PR and between 3f1a185 and f05479c.

📒 Files selected for processing (4)
  • examples/flash_attention_sm100/gqa_bwd_bshd.py
  • examples/flash_attention_sm100/gqa_fwd_bshd.py
  • examples/flash_attention_sm100/mha_bwd_bshd.py
  • examples/flash_attention_sm100/mha_fwd_bshd.py

Comment on lines +389 to +395
# Prime empty barriers so first iteration can proceed (phase 1 for parity_inv=1 at k=0)
if is_bmm_warp:
    T.mbarrier_arrive(mbar_dma1_empty[0])
    T.mbarrier_arrive(mbar_dma2_empty[0])
    T.mbarrier_arrive(mbar_bmm1_empty[0])
    T.mbarrier_arrive(mbar_softmax_empty[0])


⚠️ Potential issue | 🔴 Critical

Prime every pipeline stage before entering the WASP loop.

Only stage 0 is marked empty here, but Line 413, Line 425, Line 437, and Line 492 all wait on stage_id. With the default num_stages=2, the k=1 iteration waits on stage 1 with parity 1 even though nothing has ever arrived there, so the pipeline can deadlock as soon as loop_range > 1.

🔧 Minimal fix
             # Prime empty barriers so first iteration can proceed (phase 1 for parity_inv=1 at k=0)
             if is_bmm_warp:
-                T.mbarrier_arrive(mbar_dma1_empty[0])
-                T.mbarrier_arrive(mbar_dma2_empty[0])
-                T.mbarrier_arrive(mbar_bmm1_empty[0])
-                T.mbarrier_arrive(mbar_softmax_empty[0])
+                for stage in range(num_stages):
+                    T.mbarrier_arrive(mbar_dma1_empty[stage])
+                    T.mbarrier_arrive(mbar_dma2_empty[stage])
+                    T.mbarrier_arrive(mbar_bmm1_empty[stage])
+                    T.mbarrier_arrive(mbar_softmax_empty[stage])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 389 - 395, The
code only primes barrier index 0 before entering the WASP loop, causing later
iterations to wait on unprimed stages and potentially deadlock when
num_stages>1; change the priming logic in the is_bmm_warp block to loop over all
stages (0..num_stages-1) and call T.mbarrier_arrive for each of mbar_dma1_empty,
mbar_dma2_empty, mbar_bmm1_empty, and mbar_softmax_empty (use the same indexing
style as mbar_*[stage_id]) so every pipeline stage is primed before the WASP
loop.

…> 128 threads or > 1 WG, merged flash_atten ts & ss kernels into one, specify by extra argument
@coderabbitai Bot left a comment

🧹 Nitpick comments (2)
examples/flash_attention_sm100/gqa_bwd_bshd.py (1)

127-131: Consider renaming the ambiguous variable l to seq in the layout lambda.

The variable l (line 130) can be confused with 1. Ruff flags this as E741.

Suggested fix
 def make_dq_layout(dQ):
     return T.Layout(
         dQ.shape,
-        lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2],
+        lambda b, seq, h, d: [b, seq // 8, h, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2],
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/gqa_bwd_bshd.py` around lines 127 - 131,
Rename the ambiguous lambda parameter l to seq inside make_dq_layout to avoid
confusion with the digit 1 and satisfy Ruff E741: update the lambda signature
from "lambda b, l, h, d" to "lambda b, seq, h, d" and replace all uses of l in
the lambda body (e.g., "l // 8" and "4 * (l % 8)") with seq so the layout
construction in make_dq_layout continues to compute correctly.
examples/flash_attention_sm100/gqa_fwd_bshd.py (1)

197-198: Consider renaming the ambiguous variable O to out.

The variable name O can be confused with 0. This is flagged by Ruff (E741).

Suggested fix
-    O = torch.einsum("bhqk,bkhd->bqhd", P, V_f)
-    return O.to(Q.dtype)
+    out = torch.einsum("bhqk,bkhd->bqhd", P, V_f)
+    return out.to(Q.dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/gqa_fwd_bshd.py` around lines 197 - 198,
Rename the ambiguous variable O to a clearer name like out in the einsum
assignment and in the return; specifically change the line "O =
torch.einsum(\"bhqk,bkhd->bqhd\", P, V_f)" to assign to out and update "return
O.to(Q.dtype)" to "return out.to(Q.dtype)" so the variable isn't confused with
the digit zero (affects the local scope where P, V_f, and Q.dtype are used).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@examples/flash_attention_sm100/gqa_bwd_bshd.py`:
- Around line 127-131: Rename the ambiguous lambda parameter l to seq inside
make_dq_layout to avoid confusion with the digit 1 and satisfy Ruff E741: update
the lambda signature from "lambda b, l, h, d" to "lambda b, seq, h, d" and
replace all uses of l in the lambda body (e.g., "l // 8" and "4 * (l % 8)") with
seq so the layout construction in make_dq_layout continues to compute correctly.

In `@examples/flash_attention_sm100/gqa_fwd_bshd.py`:
- Around line 197-198: Rename the ambiguous variable O to a clearer name like
out in the einsum assignment and in the return; specifically change the line "O
= torch.einsum(\"bhqk,bkhd->bqhd\", P, V_f)" to assign to out and update "return
O.to(Q.dtype)" to "return out.to(Q.dtype)" so the variable isn't confused with
the digit zero (affects the local scope where P, V_f, and Q.dtype are used).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 55f64213-597b-499f-b32b-7714b1258a1c

📥 Commits

Reviewing files that changed from the base of the PR and between f05479c and 1fdec16.

📒 Files selected for processing (4)
  • examples/flash_attention_sm100/gqa_bwd_bshd.py
  • examples/flash_attention_sm100/gqa_fwd_bshd.py
  • examples/flash_attention_sm100/mha_fwd_bshd.py
  • src/op/copy.cc
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/flash_attention_sm100/mha_fwd_bshd.py

@coderabbitai Bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (1)
examples/flash_attention_sm100/mha_fwd_bshd.py (1)

257-265: ⚠️ Potential issue | 🔴 Critical

Prime the WASP *_empty barriers before k == 0.

Line 264, Line 275, Line 286, and Line 339 all wait on *_empty[stage_id], but there is no visible pre-loop mbarrier_arrive for any stage. If these barriers are not implicitly initialized to the waited parity, the first DMA/BMM/softmax handoff can deadlock immediately.

🔧 Minimal fix
             loop_range = (
                 T.min(
                     T.ceildiv(seq_len, block_N),
                     T.ceildiv((bx + 1) * block_M, block_N),
                 )
                 if is_causal
                 else T.ceildiv(seq_len, block_N)
             )
+
+            if tid < 128:
+                for stage in T.serial(num_stages):
+                    T.mbarrier_arrive(mbar_bmm1_empty[stage])
+            elif tid >= 160 and tid < 192:
+                for stage in T.serial(num_stages):
+                    T.mbarrier_arrive(mbar_dma1_empty[stage])
+                    T.mbarrier_arrive(mbar_dma2_empty[stage])
+                    T.mbarrier_arrive(mbar_softmax_empty[stage])
 
             for k in T.serial(loop_range):

Run this read-only check to confirm there is no priming block before the first waits:

#!/bin/bash
sed -n '248,345p' examples/flash_attention_sm100/mha_fwd_bshd.py
printf '\n-- *_empty waits/arrives --\n'
rg -n -C1 'mbarrier_(wait_parity|arrive)\(mbar_(dma1_empty|dma2_empty|bmm1_empty|softmax_empty)\[' examples/flash_attention_sm100/mha_fwd_bshd.py

Expected result: at least one mbarrier_arrive(..._empty[stage]) block should appear before for k in T.serial(loop_range):. Right now all visible *_empty arrives are inside the loop, after the first waits.

Also applies to: 275-286, 339-340

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 257 - 265, The
waits on the WASP barriers (calls to mbarrier_wait_parity using mbar_dma1_empty,
mbar_dma2_empty, mbar_bmm1_empty, mbar_softmax_empty inside the for k in
T.serial(loop_range) loop) must be primed before entering the loop to avoid
deadlock; add a pre-loop block that calls mbarrier_arrive on each *_empty
barrier for every stage_id (or at least for the stages used) with the parity
expected by the first waits (the inverse parity used in the existing
mbarrier_wait_parity calls) so the initial waited parity is satisfied when k==0
(use the same stage_id calculation/num_stages logic as in the loop and mirror
parity/parity_inv computations).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/flash_attention_sm100/gqa_fwd_bshd.py`:
- Around line 178-179: The exported flashattn_ts currently aliases flashattn (so
it inherits the default variant="ss"); replace the direct assignment with a thin
wrapper named flashattn_ts that calls flashattn but forces variant="ts" (e.g.,
accept *args/**kwargs and pass through while setting variant="ts" or overriding
any variant in kwargs) and use functools.wraps to preserve metadata; keep
flashattn_ss as the plain alias to flashattn for the "ss" variant.
- Around line 262-270: The loop deadlocks because the mbarrier objects
(mbar_dma1_empty, mbar_dma0_empty, mbar_clear_accum) are initialized at parity 0
but the first iteration immediately waits for parity 1; to fix this, add a
pre-loop priming step that calls T.mbarrier_arrive_parity(...) with parity 1 for
each barrier that the loop will wait on (mirror the same tid-range conditionals
used inside the for k in T.serial(loop_range) loop), i.e. before the loop
execute the appropriate T.mbarrier_arrive_parity(mbar_dma1_empty, 1),
T.mbarrier_arrive_parity(mbar_dma0_empty, 1), and
T.mbarrier_arrive_parity(mbar_clear_accum, 1) under the same tid checks so the
initial T.mbarrier_wait_parity calls (in the loop) will find parity 1 and not
stall.

In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 441-463: The current branch always calls flashattn_wasp for any
non-"ss"/"ts" variant; change the control flow so you handle three cases: if
variant in ("ss","ts") call flashattn(..., variant=variant); elif variant ==
"wasp" attempt to build kernel = flashattn_wasp(...) inside a try/except and on
exception fall back to kernel = flashattn(..., variant="ts") (preserve the same
block_M/block_N/threads/num_stages args when attempting WASP), and else raise a
clear ValueError for unknown variant values so typos don't default to WASP.
- Around line 178-179: The current aliases flashattn_ss = flashattn and
flashattn_ts = flashattn simply re-export the same function so calling
flashattn_ts(...) still builds the SS kernel; replace these aliases with thin
wrapper functions named flashattn_ss(...) and flashattn_ts(...) that call the
original flashattn(...) while forcing variant="ss" and variant="ts" respectively
(pass through all other args/kwargs) so each wrapper reliably selects the
intended kernel variant.

---

Duplicate comments:
In `@examples/flash_attention_sm100/mha_fwd_bshd.py`:
- Around line 257-265: The waits on the WASP barriers (calls to
mbarrier_wait_parity using mbar_dma1_empty, mbar_dma2_empty, mbar_bmm1_empty,
mbar_softmax_empty inside the for k in T.serial(loop_range) loop) must be primed
before entering the loop to avoid deadlock; add a pre-loop block that calls
mbarrier_arrive on each *_empty barrier for every stage_id (or at least for the
stages used) with the parity expected by the first waits (the inverse parity
used in the existing mbarrier_wait_parity calls) so the initial waited parity is
satisfied when k==0 (use the same stage_id calculation/num_stages logic as in
the loop and mirror parity/parity_inv computations).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 80f39a49-cae2-4ba3-ba28-8529b3ccd031

📥 Commits

Reviewing files that changed from the base of the PR and between 1fdec16 and 41d9f9f.

📒 Files selected for processing (2)
  • examples/flash_attention_sm100/gqa_fwd_bshd.py
  • examples/flash_attention_sm100/mha_fwd_bshd.py

Comment on lines +178 to +179
flashattn_ss = flashattn
flashattn_ts = flashattn

⚠️ Potential issue | 🟠 Major

flashattn_ts still exports the SS configuration.

flashattn_ts = flashattn preserves the default variant="ss", so direct callers of flashattn_ts(...) silently get the wrong kernel unless they override the variant themselves. A small wrapper that pins variant="ts" avoids that trap.

🔧 Minimal fix
-flashattn_ss = flashattn
-flashattn_ts = flashattn
+def flashattn_ss(batch, heads, seq_len, dim, is_causal, groups=1, block_M=128, block_N=128):
+    return flashattn(
+        batch,
+        heads,
+        seq_len,
+        dim,
+        is_causal,
+        groups=groups,
+        block_M=block_M,
+        block_N=block_N,
+        variant="ss",
+    )
+
+
+def flashattn_ts(batch, heads, seq_len, dim, is_causal, groups=1, block_M=128, block_N=128):
+    return flashattn(
+        batch,
+        heads,
+        seq_len,
+        dim,
+        is_causal,
+        groups=groups,
+        block_M=block_M,
+        block_N=block_N,
+        variant="ts",
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/gqa_fwd_bshd.py` around lines 178 - 179, The
exported flashattn_ts currently aliases flashattn (so it inherits the default
variant="ss"); replace the direct assignment with a thin wrapper named
flashattn_ts that calls flashattn but forces variant="ts" (e.g., accept
*args/**kwargs and pass through while setting variant="ts" or overriding any
variant in kwargs) and use functools.wraps to preserve metadata; keep
flashattn_ss as the plain alias to flashattn for the "ss" variant.

Comment on lines +262 to +270
for k in T.serial(loop_range):
    parity = (k // num_stages) & 1
    parity_inv = parity ^ 1
    stage_id = k % num_stages
    is_clear_accum = k == 0

    if tid >= 128 and tid < 160:
        T.mbarrier_wait_parity(mbar_dma1_empty[stage_id], parity_inv)


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
sed -n '253,350p' examples/flash_attention_sm100/gqa_fwd_bshd.py
printf '\n-- *_empty waits/arrives --\n'
rg -n -C1 'mbarrier_(wait_parity|arrive)\(mbar_(dma1_empty|dma2_empty|bmm1_empty|softmax_empty)\[' examples/flash_attention_sm100/gqa_fwd_bshd.py

Repository: tile-ai/tilelang

Length of output: 5214


🏁 Script executed:

# Search for alloc_barrier function and barrier initialization
rg -n 'alloc_barrier' examples/flash_attention_sm100/gqa_fwd_bshd.py

Repository: tile-ai/tilelang

Length of output: 841


🏁 Script executed:

# Also search for where the *_empty barriers are created/allocated
rg -n 'mbar_(dma1_empty|dma2_empty|bmm1_empty|softmax_empty)' examples/flash_attention_sm100/gqa_fwd_bshd.py | head -50

Repository: tile-ai/tilelang

Length of output: 965


🏁 Script executed:

# Look at the full context around barrier allocations
sed -n '1,260p' examples/flash_attention_sm100/gqa_fwd_bshd.py | tail -100

Repository: tile-ai/tilelang

Length of output: 3781


🏁 Script executed:

# Search for alloc_barrier implementation or documentation
fd -t f '\.py$' | xargs rg -l 'alloc_barrier' | head -20

Repository: tile-ai/tilelang

Length of output: 1170


🏁 Script executed:

# Look for any documentation or implementation of alloc_barrier
rg -A 5 'def alloc_barrier' --type py

Repository: tile-ai/tilelang

Length of output: 443


🏁 Script executed:

# Check if there's any other initialization happening around barrier creation
rg -B 2 -A 2 'alloc_barrier' examples/flash_attention_sm100/ --type py

Repository: tile-ai/tilelang

Length of output: 3940


🏁 Script executed:

# Check the full alloc_barrier implementation
sed -n '/def alloc_barrier/,/^[^ ]/p' tilelang/language/allocate.py

Repository: tile-ai/tilelang

Length of output: 1181


🏁 Script executed:

# Search for any barrier initialization or parity semantics in the codebase
rg -i 'barrier.*parity|parity.*init' --type py | head -20

Repository: tile-ai/tilelang

Length of output: 2328


🏁 Script executed:

# Look at how barriers are used in other examples to see if priming is done elsewhere
rg -B 3 'for.*T.serial.*loop_range' examples/flash_attention_sm100/*.py | head -40

Repository: tile-ai/tilelang

Length of output: 652


🏁 Script executed:

# Look for any barrier initialization semantics or comments
rg -B 3 -A 3 'barrier_init' tilelang/language/allocate.py

Repository: tile-ai/tilelang

Length of output: 712


🏁 Script executed:

# Check if there's a lowering pass that handles barrier initialization
rg -l 'barrier_init|mbarrier' tilelang/compiler/passes/*.py 2>/dev/null | head -10

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for any comments or documentation about barrier parity initialization
rg -i 'initial.*parity|parity.*0|parity.*1|barrier.*state' examples/ --type py | head -20

Repository: tile-ai/tilelang

Length of output: 2745


🏁 Script executed:

# Look at the warp_specialize examples for barrier usage patterns
cat examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py | grep -A 50 'alloc_barrier' | head -80

Repository: tile-ai/tilelang

Length of output: 1787


🏁 Script executed:

# Look for any compiler pass that handles barrier initialization
find tilelang -name "*.py" -o -name "*.cc" -o -name "*.h" | xargs rg -l 'mbarrier.*init|initialize.*barrier' 2>/dev/null | head -10

Repository: tile-ai/tilelang

Length of output: 148


🏁 Script executed:

# Check if there's documentation in the allocate.py file about initial barrier state
cat tilelang/language/allocate.py | head -100

Repository: tile-ai/tilelang

Length of output: 4125


🏁 Script executed:

# Look for test cases involving barriers and parity to understand expected behavior
rg -B 5 -A 10 'mbarrier_wait_parity.*0\)' testing/ --type py | head -60

Repository: tile-ai/tilelang

Length of output: 5792


🏁 Script executed:

# Search for ptx_init_barrier_thread_count or ptx_fence_barrier_init in the flash attention files
rg -n 'ptx_init_barrier_thread_count|ptx_fence_barrier_init' examples/flash_attention_sm100/

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check if there's barrier initialization happening in the kernel
sed -n '230,270p' examples/flash_attention_sm100/gqa_fwd_bshd.py

Repository: tile-ai/tilelang

Length of output: 1713


🏁 Script executed:

# Look for any compiler transformation that automatically injects barrier init
rg -l 'ptx_init_barrier|ptx_fence_barrier_init' tilelang/ --type py

Repository: tile-ai/tilelang

Length of output: 143


🏁 Script executed:

# Search for barrier initialization in the compiler
rg -n 'ptx_init_barrier|ptx_fence_barrier_init' tilelang/language/tir/

Repository: tile-ai/tilelang

Length of output: 774


🏁 Script executed:

# Check the IR/op definitions for barrier init
rg -B 2 -A 5 'def ptx_init_barrier_thread_count|def ptx_fence_barrier_init' tilelang/

Repository: tile-ai/tilelang

Length of output: 1675


🏁 Script executed:

# Look for compiler passes that might auto-inject barrier initialization
fd -t f '\.(py|cc|h)$' tilelang/compiler | head -20

Repository: tile-ai/tilelang

Length of output: 164


🏁 Script executed:

# Check if there's a specific pass file for barriers
find tilelang -name "*.py" | xargs rg -l 'barrier.*init.*pass|init.*barrier.*inject' 2>/dev/null

Repository: tile-ai/tilelang

Length of output: 44


🏁 Script executed:

# Search for how ptx_init_barrier_thread_count is used in the codebase
rg -B 3 -A 3 'ptx_init_barrier_thread_count' examples/ testing/

Repository: tile-ai/tilelang

Length of output: 7342


🏁 Script executed:

# Check if gqa_fwd_bshd.py has any use of ptx_init_barrier or ptx_fence_barrier_init
rg 'ptx_init_barrier|ptx_fence_barrier_init' examples/flash_attention_sm100/gqa_fwd_bshd.py

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look for the lowering pass that handles barriers
find tilelang -name "*.py" -path "*/passes/*" | xargs cat 2>/dev/null | grep -l "barrier\|lower" | head -10

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for references to barrier initialization in C++ or compiler code
find tilelang -name "*.cc" -o -name "*.h" 2>/dev/null | xargs rg -l 'barrier.*init' 2>/dev/null | head -5

Repository: tile-ai/tilelang

Length of output: 219


🏁 Script executed:

# Check the lower_shared_barrier.cc file
cat src/transform/lower_shared_barrier.cc | head -150

Repository: tile-ai/tilelang

Length of output: 5382


🏁 Script executed:

# Look for barrier initialization logic
rg -B 5 -A 10 'ptx_init_barrier|barrier_init' src/transform/lower_shared_barrier.cc

Repository: tile-ai/tilelang

Length of output: 2985


🏁 Script executed:

# Check if there's automatic initialization of barriers in the lower_hopper_intrin.cc
rg -B 5 -A 10 'barrier.*init|alloc_barrier' src/transform/lower_hopper_intrin.cc | head -80

Repository: tile-ai/tilelang

Length of output: 1752


🏁 Script executed:

# Look at how the warp specialize example handles the barrier priming issue
rg -B 10 -A 10 'for ko in T.Pipelined' examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py | head -60

Repository: tile-ai/tilelang

Length of output: 977


🏁 Script executed:

# Check the NVIDIA PTX documentation reference - what does mbarrier.init set the parity to?
rg -B 5 -A 5 'mbarrier.*init|initial.*parity' src/

Repository: tile-ai/tilelang

Length of output: 6294


🏁 Script executed:

# Verify by checking if other gqa or mha examples have this priming
cat examples/flash_attention_sm100/mha_fwd_bshd.py | sed -n '240,280p'

Repository: tile-ai/tilelang

Length of output: 1684


Add pre-loop barrier priming to prevent pipeline stall on first iteration.

The *_empty barriers are initialized to parity 0 by the compiler, but the main loop immediately waits on parity 1 on the first iteration (lines 269, 280, 291, 344). Without prior arrives to establish parity 1, the pipeline deadlocks before it starts.

🔧 Minimal fix
             loop_range = (
                 T.min(
                     T.ceildiv(seq_len, block_N),
                     T.ceildiv((bx + 1) * block_M, block_N),
                 )
                 if is_causal
                 else T.ceildiv(seq_len, block_N)
             )
+
+            if tid < 128:
+                for stage in T.serial(num_stages):
+                    T.mbarrier_arrive(mbar_bmm1_empty[stage])
+            elif tid >= 160 and tid < 192:
+                for stage in T.serial(num_stages):
+                    T.mbarrier_arrive(mbar_dma1_empty[stage])
+                    T.mbarrier_arrive(mbar_dma2_empty[stage])
+                    T.mbarrier_arrive(mbar_softmax_empty[stage])
 
             for k in T.serial(loop_range):

Also applies to: 280-291, 344-345

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/gqa_fwd_bshd.py` around lines 262 - 270, The
loop deadlocks because the mbarrier objects (mbar_dma1_empty, mbar_dma0_empty,
mbar_clear_accum) are initialized at parity 0 but the first iteration
immediately waits for parity 1; to fix this, add a pre-loop priming step that
calls T.mbarrier_arrive_parity(...) with parity 1 for each barrier that the loop
will wait on (mirror the same tid-range conditionals used inside the for k in
T.serial(loop_range) loop), i.e. before the loop execute the appropriate
T.mbarrier_arrive_parity(mbar_dma1_empty, 1),
T.mbarrier_arrive_parity(mbar_dma0_empty, 1), and
T.mbarrier_arrive_parity(mbar_clear_accum, 1) under the same tid checks so the
initial T.mbarrier_wait_parity calls (in the loop) will find parity 1 and not
stall.

Comment on lines +178 to +179
flashattn_ss = flashattn
flashattn_ts = flashattn

⚠️ Potential issue | 🟠 Major

flashattn_ts still builds the SS kernel.

Because flashattn_ts is only a plain alias, calling flashattn_ts(...) without an explicit variant="ts" still returns the 128-thread SS path. Export thin wrappers that pin the variant instead of re-exporting flashattn directly.

🔧 Minimal fix
-flashattn_ss = flashattn
-flashattn_ts = flashattn
+def flashattn_ss(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128):
+    return flashattn(
+        batch,
+        heads,
+        seq_len,
+        dim,
+        is_causal,
+        block_M=block_M,
+        block_N=block_N,
+        variant="ss",
+    )
+
+
+def flashattn_ts(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128):
+    return flashattn(
+        batch,
+        heads,
+        seq_len,
+        dim,
+        is_causal,
+        block_M=block_M,
+        block_N=block_N,
+        variant="ts",
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 178 - 179, The
current aliases flashattn_ss = flashattn and flashattn_ts = flashattn simply
re-export the same function so calling flashattn_ts(...) still builds the SS
kernel; replace these aliases with thin wrapper functions named
flashattn_ss(...) and flashattn_ts(...) that call the original flashattn(...)
while forcing variant="ss" and variant="ts" respectively (pass through all other
args/kwargs) so each wrapper reliably selects the intended kernel variant.

Comment on lines +441 to +463
if variant in ("ss", "ts"):
kernel = flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
variant=variant,
)
else:
kernel = flashattn_wasp(
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
threads=256,
num_stages=2,
)

⚠️ Potential issue | 🟠 Major

The advertised WASP→TS fallback is not wired up.

This branch always builds flashattn_wasp(...) directly, so variant="wasp" still bubbles up layout/JIT failures even though the module docstring and CLI help say it should fall back to TS. Make this an explicit elif variant == "wasp" branch, fall back to flashattn(..., variant="ts") on WASP build failure, and raise on unknown variants instead of routing every typo to WASP.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/flash_attention_sm100/mha_fwd_bshd.py` around lines 441 - 463, The
current branch always calls flashattn_wasp for any non-"ss"/"ts" variant; change
the control flow so you handle three cases: if variant in ("ss","ts") call
flashattn(..., variant=variant); elif variant == "wasp" attempt to build kernel
= flashattn_wasp(...) inside a try/except and on exception fall back to kernel =
flashattn(..., variant="ts") (preserve the same
block_M/block_N/threads/num_stages args when attempting WASP), and else raise a
clear ValueError for unknown variant values so typos don't default to WASP.

@Hale423 (Contributor, Author) commented Mar 11, 2026

Performance Report

Hardware: NVIDIA Blackwell (Drive Thor)
Problem size: batch=2, heads=4, seq_len=8192, dim=128, non-causal, groups=4 (GQA)

MHA Forward

| Kernel | Arch | Variant | Latency (ms) | TFlops | Speedup vs Hopper |
| --- | --- | --- | --- | --- | --- |
| example_mha_fwd_bshd.py | <=SM90 | baseline | 11.31 | 24.30 | 1.00× |
| mha_fwd_bshd.py | SM100 | ss (tcgen05mma_ss) | 8.05 | 34.13 | 1.40× |
| mha_fwd_bshd.py | SM100 | ts (tcgen05mma_ts) | 5.72 | 48.03 | 1.98× |
| mha_fwd_bshd.py | SM100 | wasp (warp-specialized, TMA disabled) | 9.64 | 28.51 | 1.17× |

GQA Forward

| Kernel | Arch | Variant | Latency (ms) | TFlops | Speedup vs Hopper |
| --- | --- | --- | --- | --- | --- |
| example_gqa_fwd_bshd.py | <=SM90 | baseline | 10.63 | 25.86 | 1.00× |
| gqa_fwd_bshd.py | SM100 | ss (tcgen05mma_ss) | 7.59 | 34.15 | 1.40× |
| gqa_fwd_bshd.py | SM100 | ts (tcgen05mma_ts) | 5.65 | 48.04 | 1.88× |
| gqa_fwd_bshd.py | SM100 | wasp (warp-specialized, TMA disabled) | 10.08 | 27.27 | 1.05× |

@Hale423 (Contributor, Author) commented Mar 11, 2026

Hi @LeiWang1999, @Rachmanino

I implemented these flash attention SM100 examples; the functionality has been verified and the performance is roughly as expected (see the tables above). I noticed that some related PRs are ongoing, and I acknowledge we are still some distance away from SOL. Specifically:

  • TMA lowering is still in progress, so I had to disable TMA loading, which makes wasp considerably slower than the normal pipeline.
  • 2SM/CTA mode is still in progress, so the double-softmax pipeline cannot yet be used to boost performance.
  • Automatic warp specialization (auto wasp), as I understand, is also in progress.

I believe this is what we can do for now; please take a look at your convenience. Thanks!

@LeiWang1999 (Member) commented:

@Hale423 Thanks for your contribution! I'll take a look then

@Rachmanino (Collaborator) commented:

Yeah, currently we can only use explicit warp specialization to utilize TMA for TCGEN5MMA. Besides, 2-CTA tcgen5mma will be supported soon!

@LeiWang1999 merged commit 6d7baa7 into tile-ai:main on Mar 22, 2026
5 of 6 checks passed