
[Feature] Add T.copy_cluster to support TMA multicast and SM-to-SM cluster copy#1908

Open
He-Jingkai wants to merge 62 commits into tile-ai:main from He-Jingkai:t_copy_extend

Conversation


@He-Jingkai He-Jingkai commented Mar 7, 2026

Add multicast copy (tma_load_multicast) and shared-memory-to-shared-memory cluster copy (tma_store_cluster) to the T.copy operator.

Multicast copy:

  • A single CTA issues tma_load_multicast to broadcast a tile to multiple CTAs in the cluster simultaneously; the other CTAs selected by the mask receive passively.
  • The cluster_mask parameter on T.copy_cluster() controls which CTAs participate.
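The mask arithmetic can be sketched in plain Python. This is an illustrative model, not TileLang API; the helper names are hypothetical, and the lowest-set-bit leader convention is an assumption taken from the multicast semantics discussed in the review comments on this PR.

```python
def multicast_leader(cluster_mask: int) -> int:
    """Rank of the CTA assumed to issue tma_load_multicast: the lowest set bit."""
    if cluster_mask == 0:
        raise ValueError("cluster_mask must select at least one CTA")
    # x & -x isolates the lowest set bit; bit_length - 1 converts it to a rank.
    return (cluster_mask & -cluster_mask).bit_length() - 1

def participates(rank: int, cluster_mask: int) -> bool:
    """True if the CTA at `rank` receives the multicast tile."""
    return bool(cluster_mask >> rank & 1)
```

With a 4-CTA cluster and mask 0b1101, CTA 0 would issue the load, and CTAs 0, 2, and 3 receive the tile; CTA 1 is excluded.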

SM-to-SM cluster copy:

  • Fast path via cp.async.bulk.shared::cluster with mbarrier synchronisation (remote_barrier parameter on T.copy_cluster()).
  • Slow path via cooperative_groups::map_shared_rank for element-by-element fallback (dst_block parameter on T.copy_cluster()).

Supporting changes:

  • New builtins: tma_load_multicast, tma_store_cluster.
  • Codegen (codegen_cuda.cc) and device templates (barrier.h, copy_sm90.h) for all new intrinsics; unified mbarrier API via tl::mbarrier_* free functions.
  • inject_tma_barrier: handle tma_load_multicast; distinguish thread guards from cluster-rank conditions for correct expect_tx injection.
  • lower_hopper_intrin: migrate barrier allocation to builtin::create_barriers; hoist user ptx_init_barrier_thread_count calls alongside compiler barriers.
  • warp_specialized_rewriter: support tma_load_multicast as producer op; UserBarrierInitExtractor to separate user barrier inits from body; ragged (dynamic-extent) pipeline prefix counter; fix nested IfThenElse state corruption in shuffle-elect optimisation.
  • multi_version_buffer_rewriter: ragged pipeline support with runtime prefix counter for correct ping-pong buffer versioning.

We will upstream some use cases soon. Stay tuned.


Co-authored-by: Guangda Sun <2012661711@qq.com>

Summary by CodeRabbit

  • New Features
    • TMA multicast broadcasting and SM-to-SM async cluster copy for NVIDIA Hopper (SM90+); cluster query/sync and barrier-style helpers; copy API now accepts destination-block, multicast mask, and remote-barrier options.
  • Documentation
    • New programming guide detailing cluster TMA patterns, synchronization contracts, and split‑K examples.
  • Tests
    • New regression tests demonstrating SM-to-SM cluster copy and TMA multicast behavior.

He-Jingkai and others added 7 commits March 7, 2026 04:07
… copy

Add multicast copy (tma_load_multicast) and shared-memory-to-shared-memory
cluster copy (tma_store_cluster / ptx_cluster_store) to the t.copy operator.

Multicast copy:
- A single CTA issues tma_load_multicast to broadcast a tile to multiple
  CTAs in the cluster simultaneously; other masked CTAs receive passively.
- New `cluster_mask` parameter on T.copy() controls which CTAs participate.

SM-to-SM cluster copy:
- Fast path via cp.async.bulk.shared::cluster with mbarrier synchronisation
  (new `remote_barrier` parameter on T.copy()).
- Slow path via cooperative_groups::map_shared_rank for element-by-element
  fallback (new `dst_block` parameter on T.copy()).

Supporting changes:
- New builtins: tma_load_multicast, tma_store_cluster, ptx_cluster_store,
  cluster_sync, get_cluster_id, get_cluster_block_rank, get_cluster_block_nums,
  mbarrier_arrive (cluster-scope).
- Codegen (codegen_cuda.cc) and device templates (barrier.h, copy_sm90.h)
  for all new intrinsics; unified mbarrier API via tl::mbarrier_* free functions.
- inject_tma_barrier: handle tma_load_multicast; distinguish thread guards
  from cluster-rank conditions for correct expect_tx injection.
- lower_hopper_intrin: migrate barrier allocation to builtin::create_barriers;
  hoist user ptx_init_barrier_thread_count calls alongside compiler barriers.
- warp_specialized_rewriter: support tma_load_multicast as producer op;
  UserBarrierInitExtractor to separate user barrier inits from body;
  ragged (dynamic-extent) pipeline prefix counter; fix nested IfThenElse
  state corruption in shuffle-elect optimisation.
- multi_version_buffer_rewriter: ragged pipeline support with runtime
  prefix counter for correct ping-pong buffer versioning.
- thread_storage_sync: reserve fixed barrier IDs (kProducer/kConsumer)
  for 256-thread CTA warp-specialised split to prevent deadlocks.
- pipeline_planning: defensive handling of non-BufferLoad mbarrier args.
- Shared memory alignment bumped to 128 bytes for TMA.

Co-authored-by: Guangda Sun <2012661711@qq.com>
Add docs/programming_guides/cluster_tma.md covering the two new
T.copy extensions introduced in the t_copy_extend feature branch:

- TMA multicast (cluster_mask): explains how a single TMA transaction
  broadcasts one global tile to multiple CTAs in a cluster simultaneously,
  with API usage, a per-rank behaviour table, and a complete code example.

- SM-to-SM cluster copy (dst_block / remote_barrier): documents the fast
  path (cp.async.bulk.shared::cluster + mbarrier, single-thread async DMA)
  and the slow path (map_shared_rank element-by-element SIMT fallback),
  including the synchronisation contract for source and destination CTAs.

Also covers cluster helper builtins (T.get_cluster_block_rank,
T.cluster_sync, T.mbarrier_init/arrive/wait_parity, etc.) and a
split-K sketch combining both features end-to-end.

github-actions bot commented Mar 7, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Mar 7, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds SM90+ cluster data-movement features: TMA multicast and SM-to-SM cluster copy across TileLang frontend, Tile op, lowering, CUDA codegen, TL templates, transform passes, warp rewriter, tests, and documentation.

Changes

Cohort / File(s) Summary
Documentation
docs/programming_guides/cluster_tma.md
New guide documenting TMA multicast and SM-to-SM cluster copy semantics, examples, and split‑K sketch.
TileLang Builtins
src/op/builtin.h, src/op/builtin.cc
Add intrinsics: mbarrier_arrive, tma_load_multicast, get_cluster_id, get_cluster_block_nums, ptx_cluster_store, tma_store_cluster.
Copy Op & Lowering
src/op/copy.h, src/op/copy.cc, tilelang/language/copy_op.py
Add dst_block field and annotations (cluster_mask, remote_barrier), increase Copy TL op arity (5→6), add LowerClusterCopy fast path (tma_store_cluster) and multicast-aware bulk-copy with SIMT fallback.
CUDA Codegen & TL Templates
src/target/codegen_cuda.cc, src/tl_templates/cuda/barrier.h, src/tl_templates/cuda/cluster.h, src/tl_templates/cuda/copy_sm90.h
Emit tl::mbarrier_*, tl::tma_load_multicast, tl::tma_store_cluster, cluster query intrinsics; add Barrier overloads, cluster_unsupported_trap, and SM90 multicast/cluster-store templates.
Transform Passes
src/transform/inject_tma_barrier.cc, src/transform/lower_hopper_intrin.cc, src/transform/lower_tile_op.cc, src/transform/pipeline_planning.cc
Extend barrier injection and lowering for multicast; reserve barrier ID range, add no-op IfThenElse pruning, and treat multicast like other TMA ops.
Ragged-Pipeline & Warp Rewriter
src/transform/multi_version_buffer_rewriter.cc, src/transform/warp_specialized_rewriter.cc, src/transform/warp_specialized_rewriter.h
Introduce ragged-prefix buffers for dynamic extents, UserBarrierInitExtractor, WSCodeEmitter::Emit, and has_cluster_copy_ detection; integrate multicast/cluster-copy awareness.
Thread Sync / Barrier IDs
src/transform/common/thread_sync_types.h, src/transform/thread_storage_sync.cc
Reserve kProducer/kConsumer barrier IDs, add special-case mapping for 256-thread producer/consumer split, and recognize tma_load_multicast in sync logic.
Engine / Frontend Adjustments
tilelang/engine/phase.py, tilelang/language/builtin.py
Add shared-memory align_bytes to MergeSharedMemoryAllocations; add _get_mbarrier, mbarrier_init, get_cluster_id, cluster_block_nums, and cluster_sync helpers in frontend.
Tests
testing/python/cuda/test_tma_dsmem.py, testing/python/cuda/test_tma_multicast_demo.py
Add SM‑to‑SM cluster‑copy regression test and TMA multicast demo/regression test exercising new intrinsics and sync semantics.

Sequence Diagram(s)

sequenceDiagram
    participant P as Producer CTA (rank 0)
    participant TL as TileLang Kernel / CodeGen
    participant MB as mbarrier (handle)
    participant C as Consumer CTA (rank 1)
    Note right of TL: CodeGen emits tl::tma_store_cluster / tl::tma_load_multicast

    P->>TL: load global -> shared (s_src) via TMA
    TL->>MB: tl::mbarrier_expect_tx(bar, size)
    TL->>TL: tl::tma_store_cluster(dst_shared, s_src, dst_cta=1, size, bar)
    Note over P,C: cp.async.bulk cluster transfer (SM-to-SM) across CTAs
    C->>MB: tl::mbarrier_wait(bar)
    C->>TL: read from shared (s_dst) and write out

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • chengyupku

Poem

🐇
I nibble bytes in shared-mem rows,
multicast whispers where the data goes,
two CTAs in tidy dance,
cp.async sings — give them a chance,
hop, decode, and fast-path shows!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 30.56% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and concisely describes the main feature: adding T.copy_cluster to support TMA multicast and SM-to-SM cluster copy, which aligns with the substantial changes across documentation, APIs, and implementation files.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Tip

CodeRabbit can use OpenGrep to find security vulnerabilities and bugs across 17+ programming languages.

OpenGrep is compatible with Semgrep configurations. Add an opengrep.yml or semgrep.yml configuration file to your project to enable OpenGrep analysis.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 20

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)

459-490: ⚠️ Potential issue | 🟠 Major

Allocate dynamic barrier IDs independently from the fixed producer/consumer IDs.

Once Line 470 or Line 477 inserts the fixed IDs, barrier_id_map_.size() no longer matches the next free numeric barrier slot. If the consumer half is seen first, Line 483 hands out barrier 6 instead of 5, and those gaps make the 16-barrier hardware limit easier to hit. Track a separate next-dynamic ID, and guard against emitting anything above 15.

Suggested fix
-    size_t barrier_id =
-        barrier_id_map_.size() +
-        static_cast<size_t>(ReservedNamedBarriers::kFirstUsedBarrier);
+    size_t barrier_id = next_dynamic_barrier_id_++;
+    ICHECK_LT(barrier_id, 16)
+        << "ThreadPartialSyncRewriter exhausted Hopper named barriers";
size_t next_dynamic_barrier_id_{
    static_cast<size_t>(ReservedNamedBarriers::kFirstUsedBarrier)};
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/transform/thread_storage_sync.cc` around lines 459 - 490, The code
currently derives new dynamic barrier IDs from barrier_id_map_.size(), causing
gaps once fixed producer/consumer IDs
(ReservedNamedBarriers::kConsumer/kProducer) are inserted; introduce a class
member next_dynamic_barrier_id_ initialized to
static_cast<size_t>(ReservedNamedBarriers::kFirstUsedBarrier) and use it instead
of barrier_id_map_.size() when allocating a new barrier in the allocation path
(the place that assigns barrier_id right now and updates barrier_id_map_ and
thread_count_map_); increment next_dynamic_barrier_id_ after assigning, and add
a guard that prevents returning/assigning any barrier_id above the hardware max
(e.g., 15) so you emit an error or handle allocation failure when
next_dynamic_barrier_id_ would exceed the limit.
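The allocator fix proposed above can be modeled in a few lines of Python: dynamic barrier IDs come from their own counter starting at kFirstUsedBarrier, so reserving the fixed producer/consumer IDs never shifts the next free slot, and allocation fails loudly past the 16-barrier hardware limit. The names mirror the C++ identifiers, but the concrete enum values (3/4/5) are assumptions for illustration only.

```python
K_PRODUCER, K_CONSUMER, K_FIRST_USED = 3, 4, 5  # assumed reserved-ID layout
MAX_BARRIERS = 16  # Hopper named-barrier hardware limit

class BarrierAllocator:
    def __init__(self):
        self.barrier_id_map = {}
        self.next_dynamic_id = K_FIRST_USED

    def reserve_fixed(self, key, fixed_id):
        # Reserving a fixed ID must NOT advance the dynamic counter.
        self.barrier_id_map[key] = fixed_id

    def allocate(self, key):
        if key in self.barrier_id_map:
            return self.barrier_id_map[key]
        if self.next_dynamic_id >= MAX_BARRIERS:
            raise RuntimeError("exhausted Hopper named barriers")
        self.barrier_id_map[key] = self.next_dynamic_id
        self.next_dynamic_id += 1
        return self.barrier_id_map[key]
```

Even if the consumer half is seen first and its fixed ID is inserted, the next dynamic allocation still hands out kFirstUsedBarrier rather than skipping a slot.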
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/programming_guides/cluster_tma.md`:
- Around line 103-105: The sentence is misleading about non-issuer behavior;
update the text referencing
cp.async.bulk.tensor.Nd.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster
to state that masked non-issuer CTAs receive the multicast passively (do not
perform an active load) and only CTAs that fall outside the cluster_mask perform
a standard cp.async.bulk.tensor active load; ensure both the multicast path
(passive receipt) and the fallback active load for CTAs outside the mask are
explicitly described and contrasted.
- Around line 6-10: Update all example snippets and code blocks that use the
now-removed public launch argument name cluster_size to the new name
cluster_dims; search for the symbol "cluster_size" in the document and replace
it with "cluster_dims" in every code/example instance (including the multiple
occurrences called out in the review) so pasted examples match the current API
surface.

In `@src/op/copy.cc`:
- Around line 934-936: The Lower() path currently returns LowerClusterCopy(T,
analyzer) whenever dst_block.defined(), which bypasses target capability checks
and may emit tma_store_cluster/ptx_cluster_store on targets that don't support
cluster copy; update the dst_block branch in Lower() to first check the same
Hopper/cluster capability predicate used for the intrinsic lowering (the
capability check used elsewhere for cluster intrinsics) and only call
LowerClusterCopy(T, analyzer) when that capability is present, otherwise return
an early failure/diagnostic explaining cluster copy is unsupported on the
current target; reference dst_block, Lower(), LowerClusterCopy, and the
tma_store_cluster/ptx_cluster_store intrinsics when making the change.
- Around line 1476-1515: The code assumes flat contiguity when building the
tma_store_cluster call; instead implement a contiguity check (e.g., add an
is_contiguous(Buffer&, Array<Range>&) helper that uses compute_linear_offset to
compute min_offset and max_offset (using ranges[i]->min and ranges[i]->min +
ranges[i]->extent - 1) and verifies (max_offset - min_offset + 1 ==
total_elements)). Use that check before taking the fast path that builds
dst_ptr/src_ptr and calls tma_store_cluster; if either src or dst is not
contiguous, fall back to the existing element-wise cluster-store path (same code
path used when not in this fast TMA branch). Make sure to reference
compute_linear_offset, dst_offset, src_offset, total_elements,
tma_store_cluster, and the in_tma_context_/HandleAccessPtrAndOffset behavior
when locating where to insert the check.
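The contiguity test requested above reduces to a small invariant: a sliced region is one flat byte span iff the linear offsets of its lowest and highest elements bracket exactly total_elements positions. Here is an illustrative Python version; `ranges` and `strides` stand in for the Buffer/Range objects in copy.cc, and the function names are hypothetical.

```python
def linear_offset(indices, strides):
    """Row-major linear offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(indices, strides))

def is_contiguous(ranges, strides):
    """ranges: list of (min, extent) per dim; strides: elements per step per dim."""
    lo = linear_offset([mn for mn, _ in ranges], strides)
    hi = linear_offset([mn + ext - 1 for mn, ext in ranges], strides)
    total = 1
    for _, ext in ranges:
        total *= ext
    # Contiguous iff the occupied span has no holes.
    return hi - lo + 1 == total
```

For an 8x8 buffer (strides [8, 1]), a block of full rows passes the check, while a 2x4 sub-tile does not and must take the element-wise fallback path.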

In `@src/tl_templates/cuda/cluster.h`:
- Around line 19-21: Replace all assert(false && "CLUSTER_ENABLED is not
defined") fallbacks and any code paths that assert then return zero with a hard
trap: use asm volatile("trap") (as used in tcgen_05_ld.h) so unsupported cluster
code cannot continue with bogus IDs. Specifically, find occurrences of the
CLUSTER_ENABLED guard and every function/macro in this header that currently
does an assert(false) (and the ones that follow that with return 0) and change
them to invoke asm volatile("trap") before any return; do this for every
assert-based fallback so the code fails hard on unsupported GPUs rather than
becoming a no-op in release builds.

In `@src/transform/inject_tma_barrier.cc`:
- Around line 163-177: The current check sets is_thread_eq true for any EQNode
that mentions thread_var_ (via UsesVar), which is too permissive; instead,
detect only equality predicates that uniquely select a single thread index
(e.g., threadIdx.x == CONST or threadIdx.x +/− CONST == CONST). Replace the
UsesVar-based logic in the op->condition EQNode branch with a structural
analyzer that accepts only EQNode patterns where one side is the thread_var_
(optionally wrapped by a single Add/Sub with an IntImm) and the other side is a
constant IntImm (or equivalent cast to IntImm); reject patterns involving
modulus, bitwise-and, other arithmetic, or non-constant sides. Keep the rest of
the flow (is_thread_eq usage and the guard around expect_tx/mbarrier_expect_tx)
unchanged so injection only happens when the predicate guarantees exactly one
thread value.
- Around line 103-112: The current VisitStmt_(const IfThenElseNode *op)
unconditionally skips visiting the else_case and assumes both arms transfer
equal bytes for mbarrier_expect_tx; instead, compute the transmitted-byte count
for both branches (e.g., by invoking the same byte-accounting visitor on
op->then_case and op->else_case) and compare the results, collapsing to a single
visit only if they are equal; otherwise keep the accounting branch-local by
visiting both branches separately so mbarrier_expect_tx is correct for each
path. Ensure you reference VisitStmt_(const IfThenElseNode *op),
StmtExprVisitor::VisitStmt/VisitStmt_, and the mbarrier_expect_tx accounting
logic when making the change.
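The single-thread predicate check requested above can be sketched structurally. This toy uses nested tuples as a stand-in for TVM expression nodes (it is not TVM's API): it accepts only `thread_var == CONST` or `thread_var ± CONST == CONST`, which select exactly one thread value, and rejects modulus, masking, or non-constant sides.

```python
def selects_single_thread(cond, thread_var):
    """cond: ('eq', lhs, rhs); exprs are thread_var, int, or ('add'|'sub', expr, int)."""
    if not (isinstance(cond, tuple) and cond[0] == "eq"):
        return False
    lhs, rhs = cond[1], cond[2]
    if isinstance(lhs, int):  # normalize the constant to the right-hand side
        lhs, rhs = rhs, lhs
    if not isinstance(rhs, int):
        return False
    if lhs == thread_var:
        return True
    # Allow a single +/- constant wrapper around the thread variable.
    return (isinstance(lhs, tuple) and lhs[0] in ("add", "sub")
            and lhs[1] == thread_var and isinstance(lhs[2], int))
```

Anything like `threadIdx.x % 32 == 0` correctly fails the check, since it selects many threads and would make the expect_tx injection fire more than once.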

In `@src/transform/lower_hopper_intrin.cc`:
- Around line 206-215: The code currently reuses get_mbarrier(i) starting at
zero causing aliasing across multiple create_list_of_mbarrier() calls; fix by
assigning a base offset from the current num_managed_barriers_ (e.g. int base =
num_managed_barriers_), then use get_mbarrier(base + i) when constructing each
mbarrier and push to init_mbarrier_calls_, and finally increment
num_managed_barriers_ by num_barriers (or set it after the loop) so each list
gets its own unique handle slice; update references in this block around
create_list_of_mbarrier(), num_managed_barriers_, get_mbarrier(),
init_mbarrier_calls_, and the local mbarrier variable accordingly.

In `@src/transform/lower_tile_op.cc`:
- Around line 663-671: The current branch sets in_tma_context_ for
tma_store_cluster(), which suppresses the normal shared-layout address remapping
and causes tvm_access_ptr/address_of operands to target wrong byte ranges;
instead, do not flip in_tma_context_ for the tma_store_cluster() path — call
Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op)) without setting
in_tma_context_ so the regular layout-offset rewrite runs for
tvm_access_ptr/address_of, but still ensure has_tma_ is not set for this op
(preserve the original intent of not triggering warp specialization).

In `@src/transform/pipeline_planning.cc`:
- Around line 372-395: The current branch in VisitExpr_ for mbarrier_wait_parity
drops dependency modeling when args[0] is not a BufferLoadNode by calling
StmtExprVisitor::VisitExpr_(op), which loses read/write edges; update the logic
handling mbarrier_wait_parity so that handle-based barriers (e.g.
tl.get_mbarrier(id)) are normalized into the same maps
(chain_builder_.mbar_to_buffer_reads_ / mbar_to_buffer_writes_) or, if the
concrete Buffer cannot be resolved, conservatively attach the wait to the set of
pending async buffers (the same collections used for BufferLoadNode) instead of
falling back to StmtExprVisitor::VisitExpr_. Locate this in the
mbarrier_wait_parity handling block in pipeline_planning.cc and modify the else
path to look up/normalize the handle (or add conservative edges) so reads_ and
writes_ are populated consistently.

In `@src/transform/warp_specialized_rewriter.cc`:
- Around line 1204-1209: The current code wraps the kept pipelined loop into
SeqStmt({result, update}) which hides the For from downstream post-processing
(so num_stages/pipeline annotations and GroupOpRewriter never run) and also
advances ragged_prefix_buf_ even when FilterByRole dropped the loop; instead,
append the ragged-prefix update only when the loop is actually preserved: avoid
creating a SeqStmt that nests the For—keep result as the loop Stmt and attach
the BufferStore update in the same AST level or conditional branch after
loop-preservation checks (use the existing LoadRaggedPrefix(),
ragged_prefix_buf_, and the computed new_prefix and update Stmt) so the For
remains visible to the post-processing and only advance the prefix when the loop
wasn't filtered out.
- Around line 161-164: The stage-local mbarrier rewrite in MbarrierRewriter
currently retargets only tma_load() and tma_load_im2col(), leaving
tma_load_multicast() signaling the original barrier and causing a mismatch with
the consumer; update the MbarrierRewriter rewrite logic to include
tma_load_multicast() alongside tma_load() and tma_load_im2col() so multicast
producers are retargeted to the per-stage barrier as well, ensuring consistent
signaling and waiting between producer and consumer.
- Around line 1477-1503: The extractor currently hoists entire IfThenElseNodes
(VisitStmt_) even when they have no-else or live inside an
Allocate/T.alloc_shared scope, which moves inits out of scope and drops else
branches; change VisitStmt_ in UserBarrierInitExtractor to only extract when the
if has no else branch (check op->else_case is null/empty) and then push the
inner init statement(s) (op->then_case or unwrapped SeqStmt) into init_stmts
instead of the whole IfThenElseNode; also ensure IsOnlyInit still recognizes
nested SeqStmt of Evaluate(CallNode) as before so callers can reinsert the init
inside the allocating scope rather than at block root.

In `@src/transform/warp_specialized_rewriter.h`:
- Around line 42-46: The current check uses the single-bit
detector.has_cluster_copy_ to exempt all mbarrier+TMA cases, which wrongly
preserves warp-specialisation when a kernel mixes regular tma_load* paths with
an unrelated tma_store_cluster; update the detector and the condition to
distinguish cluster-copy-specific TMA paths from regular TMA-load paths (e.g.,
add a flag like has_regular_tma_load_ or has_cluster_copy_path_ that tracks
whether the mbarrier is tied to the cluster-copy path), then change the
condition using detector.has_tma_op_ && detector.has_mbarrier_op_ &&
!detector.has_cluster_copy_ to instead require that no regular TMA-load path
exists (e.g., detector.has_tma_op_ && detector.has_mbarrier_op_ &&
!detector.has_regular_tma_load_) or scope the exemption to the specific
barrier/copy path; apply the same change to the other occurrence referenced
(lines ~81-83).

In `@testing/python/cuda/test_tma_dsmem.py`:
- Around line 65-92: Convert this script into a real pytest test by replacing
main() with a test function (e.g., test_tma_store_cluster_copy) that calls
make_store_cluster_kernel(N) and runs the kernel; use pytest.skip when the
device capability is < 9 (instead of printing/returning), and replace the
print-based pass/fail with an assert that np.allclose(result, expected) (or
assert with a helpful message showing max diff). Ensure the test imports pytest
and references make_store_cluster_kernel and the produced result/expected arrays
so CI will fail on mismatches.

In `@testing/python/cuda/test_tma_multicast_demo.py`:
- Around line 51-70: Guard the test_tma_multicast_demo test to run only on CUDA
devices with Hopper (SM >= 9.0) hardware: at the start of
test_tma_multicast_demo check torch.cuda.is_available() and
torch.cuda.get_device_capability(0) (or
torch.cuda.get_device_properties(0).major) to ensure SM >= 9, and call
pytest.skip(...) (or return) when the checks fail so
make_tma_multicast_demo_kernel, tilelang.compile and the mod(A) launch are not
executed on non-CUDA or pre-SM90 runners.
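The guard logic can be kept as a pure predicate so it is unit-testable off-GPU; the helper name is illustrative. In the actual test one would feed it torch.cuda.is_available() and torch.cuda.get_device_capability(0) and call pytest.skip when it returns False.

```python
def supports_cluster_copy(cuda_available: bool, capability: tuple) -> bool:
    """Hopper (SM 9.0) introduced thread-block clusters and TMA multicast."""
    return cuda_available and capability >= (9, 0)
```

Tuple comparison does the right thing here: (8, 9) (Ada) is below (9, 0), while (9, 0) and above pass.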

In `@tilelang/engine/phase.py`:
- Around line 295-300: shared_align_bytes is currently set to 128 only when
allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target) is true, but
kernels that get TMA-lowered via allow_tma_lower() (e.g. Hopper TMA) should also
receive 128-byte shared alignment; update the condition that computes
shared_align_bytes to check allow_tma_lower(pass_ctx=pass_ctx, target=target)
(or OR it with the existing allow_tma_and_warp_specialized check) so
MergeSharedMemoryAllocations(enable_aggressive_merge=...,
align_bytes=shared_align_bytes) receives 128 for any TMA-lowered kernel.

In `@tilelang/language/builtin.py`:
- Around line 427-443: The wrapper mbarrier_init currently accepts int/PrimExpr
and calls _get_mbarrier, but _get_mbarrier isn't implemented so ID-based inputs
fail at runtime; remove support for integer/PrimExpr IDs for now by deleting the
branch that handles isinstance(mbarrier, (tir.PrimExpr, int)) and its call to
_get_mbarrier, and update the TypeError message in mbarrier_init to only accept
tir.Call, tir.Buffer, or tir.BufferLoad (so callers must provide a real handle
or buffer until _get_mbarrier is implemented); keep the existing tir.BufferLoad
and tir.Call handling and the final tir.call_intrin(invocation) unchanged.

In `@tilelang/language/copy_op.py`:
- Around line 79-85: Update the docstring for the parameters (around dst_block
and cluster_mask) in copy_op.py to correct multicast semantics: state that when
cluster_mask is set the CTA whose rank equals the lowest set bit issues
tma_load_multicast and all other CTAs included in the mask receive the data
passively from that multicast (they must not issue redundant tma_load calls or
additional synchronization), and clarify that cluster_mask=None disables
multicast and causes regular per-CTA TMA loads; reference the cluster_mask
parameter and the description of the multicast leader (lowest set bit) so
readers can locate and edit the docstring.
- Around line 137-138: The code currently overwrites ann["dst_block"] whenever
dst_block is not None, violating the stated precedence that annotations beat
individual kwargs; update the assignment in the function/method that builds ann
(the block handling code where ann, annotations, dst_block are in scope) to only
set ann["dst_block"] when dst_block is not None AND "dst_block" is not already
present in ann (i.e., check 'if dst_block is not None and "dst_block" not in
ann: ann["dst_block"] = dst_block'), matching the existing precedence handling
used for cluster_mask and remote_barrier.
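The precedence rule described above amounts to "fill from the kwarg only when the annotation is absent", applied uniformly to all three keys. A minimal sketch (the function itself is hypothetical; only the key names come from copy_op.py):

```python
def merge_copy_annotations(annotations=None, dst_block=None,
                           cluster_mask=None, remote_barrier=None):
    ann = dict(annotations or {})
    # An explicit annotation wins over the individual keyword argument.
    for key, val in (("dst_block", dst_block),
                     ("cluster_mask", cluster_mask),
                     ("remote_barrier", remote_barrier)):
        if val is not None and key not in ann:
            ann[key] = val
    return ann
```

With this shape, dst_block cannot silently clobber an annotation the caller already set, matching the handling of cluster_mask and remote_barrier.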


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ac4c6e39-f962-4c65-ab08-2abf507a83a7

📥 Commits

Reviewing files that changed from the base of the PR and between 4929ad8 and 5d57c11.

📒 Files selected for processing (23)
  • docs/programming_guides/cluster_tma.md
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/op/copy.cc
  • src/op/copy.h
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/barrier.h
  • src/tl_templates/cuda/cluster.h
  • src/tl_templates/cuda/copy_sm90.h
  • src/transform/common/thread_sync_types.h
  • src/transform/inject_tma_barrier.cc
  • src/transform/lower_hopper_intrin.cc
  • src/transform/lower_tile_op.cc
  • src/transform/multi_version_buffer_rewriter.cc
  • src/transform/pipeline_planning.cc
  • src/transform/thread_storage_sync.cc
  • src/transform/warp_specialized_rewriter.cc
  • src/transform/warp_specialized_rewriter.h
  • testing/python/cuda/test_tma_dsmem.py
  • testing/python/cuda/test_tma_multicast_demo.py
  • tilelang/engine/phase.py
  • tilelang/language/builtin.py
  • tilelang/language/copy_op.py

@LeiWang1999 LeiWang1999 requested a review from Rachmanino March 7, 2026 10:05
@LeiWang1999
Member

awesome contributions! cc @chengyupku @Rachmanino

@He-Jingkai
Author

@Rachmanino It seems that this PR and #1874 have introduced some APIs with the same functionality. We will unify them into the #1874 API.

@Rachmanino
Collaborator

Rachmanino commented Mar 9, 2026

Thanks for your awesome contribution! Left some comments though; we can discuss them later. I can also fix them for you if needed :)

@Rachmanino Rachmanino changed the title [Feature] Extend t.copy to support TMA multicast and SM-to-SM cluster copy [Feature] Extend T.copy to support TMA multicast and SM-to-SM cluster copy Mar 9, 2026
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (3)
docs/programming_guides/cluster_tma.md (1)

6-6: ⚠️ Potential issue | 🟡 Minor

Use the current cluster_dims launch argument in the examples.

These snippets still use the old cluster_size name, so readers copying them will get stale API usage.

Also applies to: 22-25, 77-77, 162-162, 260-260

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/programming_guides/cluster_tma.md` at line 6, The documentation examples
use the deprecated launch argument name `cluster_size`; update every occurrence
(including the instances at lines noted) to the current argument name
`cluster_dims` so examples match the current API; search for and replace
`cluster_size` with `cluster_dims` in code snippets and prose within this
document (e.g., in example invocations and parameter descriptions) to ensure
examples compile and reflect the active launch argument.
src/transform/inject_tma_barrier.cc (1)

104-112: ⚠️ Potential issue | 🟠 Major

Keep expect_tx byte accounting branch-local.

This still drops the else arm without proving both paths move the same byte count. A nested conditional can legitimately take a smaller or no-op alternative path, which makes the injected mbarrier_expect_tx wrong for that branch. Only collapse after proving equality; otherwise visit both arms separately.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/transform/inject_tma_barrier.cc` around lines 104 - 112, The current
VisitStmt_(const IfThenElseNode *op) collapses byte accounting by only visiting
the then branch which can mis-count mbarrier_expect_tx for the else path; change
VisitStmt_ for IfThenElseNode to handle branches locally: when
op->else_case.defined(), snapshot the mbarrier_expect_tx (and any related
byte-accounting state), call StmtExprVisitor::VisitStmt on op->then_case and
record its expect_tx delta, restore the snapshot, call
StmtExprVisitor::VisitStmt on op->else_case and record its expect_tx delta, and
only collapse the two deltas if they are provably equal (otherwise keep
branch-local accounting). Ensure you reference and update the same
mbarrier_expect_tx variable used elsewhere so both branches are visited
independently and state is restored between visits.
src/op/copy.cc (1)

1482-1515: ⚠️ Potential issue | 🔴 Critical

The bulk cluster fast path only preserves bytewise memcpy semantics.

compute_linear_offset plus size_bytes = total_elements * src->dtype.bytes() flatten this copy into a raw byte span. That is only correct when the src/dst slices are contiguous and share the same physical layout and dtype; strided subregions, swizzled-vs-linear shared buffers, or dtype-changing T.copy calls will silently copy the wrong bytes here. Fall back to the SIMT path unless those properties are proven first.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 1482 - 1515, The bulk-copy fast path flattens
ranges into a raw byte span (compute_linear_offset + size_bytes = total_elements
* src->dtype.bytes()) which is only correct for contiguous, same-layout,
same-dtype transfers; update the logic before building dst_ptr/src_ptr and
calling tma_store_cluster() to verify that src and dst are element-contiguous,
have identical dtype, and are not subject to swizzle or layout transformations
(i.e., the exact conditions required for a bytewise memcpy). If any of those
checks fail, do not construct the TMA bulk path (dst.access_ptr/src.access_ptr +
tma_store_cluster call); instead fall back to the existing SIMT copy path so
element-wise semantics are preserved. Ensure the guards reference
compute_linear_offset, total_elements/size_bytes, dst.access_ptr, src.access_ptr
and tma_store_cluster so the correct branch is chosen at runtime/compile-time.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/programming_guides/cluster_tma.md`:
- Around line 41-46: The fenced ASCII diagram block starting with "Global memory
──TMA multicast──▶ shared memory (rank 0)" is unlabeled and triggers
markdownlint; update the opening fence to include a language label (use "text")
so the fence becomes ```text and the rest of the block remains unchanged,
ensuring the diagram stays the same but the markdown linter no longer flags it.

In `@src/op/copy.cc`:
- Around line 1863-1865: The cluster multicast mask from GetClusterMask() is
only applied inside LowerBulkCopy(), so non-descriptor or fallback paths (e.g.,
when copy_kind is kBulkLoad1D or other non-bulk paths) ignore it; update the
copy selection/dispatch logic to consult cluster_mask (and the computed
use_multicast) before deciding per-CTA vs multicast strategies so the mask is
honored across all copy paths—specifically propagate or check
cluster_mask/use_multicast where copy mode is selected (places that branch on
kBulkLoad1D, LowerBulkCopy(), or fallback handlers) and ensure any fallback to
per-CTA loads respects the mask by choosing the corresponding multicast-aware
code path or forcing the same source-coordinate logic as the descriptor path.
- Around line 1568-1581: Summary: The SIMT fallback flattens indices using
logical layout (op->indices and op->buffer->shape) but writes into a remapped
physical destination (target_dst_ / target_buffer.access_ptr(2)), causing wrong
element stores. Fix: replace the current linearization that uses op->indices and
op->buffer->shape in this block (where linearized_index is computed and pushed
to args) with a remapped/physical index consistent with target_dst_; use the
target_dst_ remapping API or compute linear index from the target buffer's
physical layout/strides (the same mapping used by ptx_cluster_store /
target_buffer.access_ptr(2)) so produced linearized_index matches the physical
destination. Ensure references: update code around linearized_index,
op->indices, op->buffer->shape, target_dst_, and ptx_cluster_store to use the
same mapping.

In `@src/transform/lower_hopper_intrin.cc`:
- Around line 140-145: The managed-barrier allocation currently only reserves
num_managed_barriers_ slots starting at ID 0 which can collide with
backend-reserved IDs (1 and 2); change the allocation and base so user-managed
barriers start at ID 3. Concretely, introduce a reserved_offset (3) and adjust
the allocation call (the CreateBarriers invocation currently assigned to
alloc_mbarrier) to request reserved_offset + num_managed_barriers_ slots, and
ensure any use of get_mbarrier(barrier_base + i) (and the barrier_base value set
elsewhere) is shifted by reserved_offset; also verify InjectTmaBarrier logic
that rewrites descriptor-based TMA loads to get_mbarrier(0) is updated to
account for the new base so no user-managed slot overlaps IDs 1 and 2.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3637adfa-cf51-472e-814c-a877a8d589c7

📥 Commits

Reviewing files that changed from the base of the PR and between 5d57c11 and cbc15d7.

📒 Files selected for processing (6)
  • docs/programming_guides/cluster_tma.md
  • src/op/copy.cc
  • src/tl_templates/cuda/cluster.h
  • src/transform/inject_tma_barrier.cc
  • src/transform/lower_hopper_intrin.cc
  • src/transform/lower_tile_op.cc
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tl_templates/cuda/cluster.h

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
src/op/copy.cc (3)

1479-1518: ⚠️ Potential issue | 🟠 Major

The bulk cluster fast path still assumes contiguous byte spans.

The compute_linear_offset lambda computes offset from range mins and total_elements = product(extents). This is only correct for contiguous regions. A 2D/ND subregion inside a larger shared buffer is generally strided, so tma_store_cluster may copy wrong elements.

Consider adding a contiguity check before taking this fast path and falling back to the element-wise path otherwise.


1568-1581: ⚠️ Potential issue | 🟠 Major

SIMT fallback linearizes using logical layout, not physical.

When target_dst_ is remapped (swizzled), ptx_cluster_store writes into the physical buffer span, but linearized_index is derived from op->indices and op->buffer->shape (logical layout). For swizzled destinations, this will address the wrong element.

The linearization should use the target buffer's physical layout/strides consistent with how target_buffer.access_ptr(2) addresses memory.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 1568 - 1581, The linearization currently builds
linearized_index from op->indices and op->buffer->shape (logical layout), which
is wrong when target_dst_ is remapped; change linearization to compute the index
using the target (physical) buffer layout/strides used by
target_buffer.access_ptr(2) so ptx_cluster_store writes the correct physical
location. Concretely: in the block that sets linearized_index (and uses
multiplier), replace op->buffer->shape and op->indices as sources with the
corresponding physical shape/stride information from the target buffer (the
remapped target_dst_ / target_buffer used for access_ptr), or compute the
linearized index by folding op->indices through the same stride vector/mapping
that target_buffer.access_ptr(2) uses; keep the same loop structure but multiply
by physical strides instead of logical shape multipliers so ptx_cluster_store
addresses correctly.
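The difference between the two linearizations is easy to see in isolation. A sketch (plain Python, hypothetical names) comparing shape-based folding with stride-based folding: with a padded or remapped physical row stride, the two produce different flat indices for the same logical coordinates.

```python
def linearize_logical(indices, shape):
    # Fold indices through the logical shape (what the fallback currently does).
    idx = 0
    for i, dim in zip(indices, shape):
        idx = idx * dim + i
    return idx

def linearize_physical(indices, strides):
    # Fold indices through the physical strides (what the store pointer uses).
    return sum(i * s for i, s in zip(indices, strides))
```

For a logical shape of (4, 8) backed by a padded physical row stride of 10 elements, element (2, 3) sits at flat index 19 logically but 23 physically, so shape-based linearization stores into the wrong slot.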

1863-1865: ⚠️ Potential issue | 🟠 Major

cluster_mask is only honored in the descriptor-based bulk-load path.

This flag is consumed only inside LowerBulkCopy(). If a copy falls back to kBulkLoad1D or any non-bulk path, the annotation is silently ignored and masked CTAs revert to per-CTA loads.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 1863 - 1865, The cluster_mask from
GetClusterMask() is only used inside LowerBulkCopy(), so when the copy falls
back to kBulkLoad1D or any non-bulk path the multicast annotation is ignored;
change the path-selection and lowering logic so the cluster_mask decision is
made before choosing a lowering path and is propagated into all load
implementations: compute use_multicast = (is_load && GetClusterMask() > 0) at
the top-level (where use_multicast is currently set), and pass the
cluster_mask/use_multicast flag into LowerBulkCopy() and the fallback loaders
(the kBulkLoad1D handler and any non-bulk lowering functions) so masked-CTA
behavior is honored across all paths. Ensure all loaders consult the passed
cluster_mask/use_multicast instead of assuming no multicast.
🧹 Nitpick comments (1)
src/tl_templates/cuda/cluster.h (1)

14-15: Consider adding [[noreturn]] attribute for compiler optimization.

The trap function correctly addresses the previous review concern about assert-only fallbacks. Since trap terminates execution and never returns, marking the function with [[noreturn]] (C++11) or __attribute__((noreturn)) would help the compiler eliminate dead code paths and make the intent explicit.

♻️ Optional improvement
-TL_DEVICE void cluster_unsupported_trap() { asm volatile("trap;"); }
+TL_DEVICE [[noreturn]] void cluster_unsupported_trap() { asm volatile("trap;"); }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/tl_templates/cuda/cluster.h` around lines 14 - 15, The function
cluster_unsupported_trap should be marked as non-returning to aid optimizations
and make intent explicit: update the declaration/definition of
cluster_unsupported_trap to include a noreturn annotation (e.g., C++11
[[noreturn]] or compiler-specific __attribute__((noreturn))) on the TL_DEVICE
void cluster_unsupported_trap() function and keep the asm volatile("trap;");
body unchanged so callers and the compiler know the function never returns.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/transform/lower_hopper_intrin.cc`:
- Around line 145-152: The allocation size for create_barriers() must cover
every hoisted mbarrier referenced by both create_list_of_mbarrier() and direct
ptx_init_barrier_thread_count() calls; track the maximum required slots in a
shared counter (e.g., num_required_barriers_ or extend num_managed_barriers_)
whenever create_list_of_mbarrier() assigns a new slice or when
ptx_init_barrier_thread_count() references get_mbarrier(id), and use that max
value (including kReservedBarriers) when emitting the create_barriers() call so
the allocation always covers the highest referenced barrier index.
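The slot accounting described above can be sketched as a tiny allocator. Names are hypothetical, and the reserved count of 3 follows the reserved-ID discussion earlier in this review: record every barrier ID referenced by either list allocation or direct init calls, then size the single allocation to cover the maximum.

```python
class BarrierSlots:
    def __init__(self, reserved=3):
        # IDs 0..reserved-1 are assumed backend-reserved; user slots start after.
        self.reserved = reserved
        self.max_id = reserved - 1

    def alloc_list(self, count):
        """Reserve `count` consecutive slots (create_list_of_mbarrier)."""
        base = self.max_id + 1
        self.max_id = base + count - 1
        return base

    def note_init(self, barrier_id):
        """Record a direct get_mbarrier(id) reference from a user init call."""
        self.max_id = max(self.max_id, barrier_id)

    def total_slots(self):
        """Slot count to pass to the single create_barriers() call."""
        return self.max_id + 1
```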


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9ce26ceb-5e92-4f03-84c8-22fd00d08fd1

📥 Commits

Reviewing files that changed from the base of the PR and between 5d57c11 and f2cb665.

📒 Files selected for processing (6)
  • docs/programming_guides/cluster_tma.md
  • src/op/copy.cc
  • src/tl_templates/cuda/cluster.h
  • src/transform/inject_tma_barrier.cc
  • src/transform/lower_hopper_intrin.cc
  • src/transform/lower_tile_op.cc
🚧 Files skipped from review as they are similar to previous changes (1)
  • docs/programming_guides/cluster_tma.md

Collaborator

@Rachmanino Rachmanino left a comment


Left some minor comments. Besides, I think we should discuss the frontend API design with @LeiWang1999: whether to append new params to T.copy (I'm a bit worried that it already has many params) or to create a new interface (something like T.copy_cluster).

@He-Jingkai
Author

whether to append new params to T.copy (I'm a bit worried about that there seems to have many params now) or to create a new interface (sth like T.copy_cluster)

I agree. Creating a more specific interface like T.copy_cluster seems like a cleaner way.

@LeiWang1999, do you have a preference?

@He-Jingkai
Author

He-Jingkai commented Mar 17, 2026

@Rachmanino We have gone with T.copy_cluster, keeping T.copy simple and focused. The review comments are fixed.

@He-Jingkai He-Jingkai changed the title [Feature] Extend T.copy to support TMA multicast and SM-to-SM cluster copy [Feature] Add T.copy_cluster to support TMA multicast and SM-to-SM cluster copy Mar 17, 2026
@He-Jingkai He-Jingkai requested a review from Rachmanino March 17, 2026 09:33
@Rachmanino
Collaborator

Please remove the unused mbarrier_arrive in builtin.cc. Another small question: I see nothing like ptx_arrive_barrier in the codebase changes. Is it not used at all in this PR?

@He-Jingkai
Author

He-Jingkai commented Mar 17, 2026

@Rachmanino

Good catch, and thanks for the careful review. To clarify: this PR does not use ptx_arrive_barrier, but it does introduce use of ptx_arrive_cluster_barrier.

T.copy_cluster is lowered through three distinct paths in LowerClusterCopy (copy.cc):

  • TMA multicast (cluster_mask set): a single tma_load_multicast instruction broadcasts one global tile to multiple CTAs; the mbarrier is managed entirely by the TMA hardware pipeline via inject_tma_barrier.

  • Bulk async SM-to-SM fast path (contiguous regions): a single elected thread issues cp.async.bulk.shared::cluster, with mbarrier.arrive.expect_tx.shared::cluster embedded inside tl::tma_store_cluster. The barrier is again maintained by the TMA hardware.

  • SIMT element-wise fallback (non-contiguous src/dst regions such as non-full-span multi-dimensional tiles): all threads cooperatively write via cooperative_groups::map_shared_rank. When a remote_barrier was supplied but the fast path could not be taken, we auto-generate a thread-0-guarded ptx_arrive_cluster_barrier call (copy.cc, lines 1704-1720) to notify the destination CTA's mbarrier.

So ptx_arrive_cluster_barrier appears specifically in path 3 as the software-level barrier signal that substitutes for what the TMA hardware would have done automatically in paths 1 and 2.
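Path 1 above hinges on the cluster_mask convention. A small sketch in plain Python (the bit-per-rank encoding here is an assumption for illustration, not necessarily the exact TileLang encoding): bit i set means CTA rank i in the cluster receives the multicast tile.

```python
def make_cluster_mask(ranks):
    """Build a multicast mask with one bit per receiving CTA rank."""
    mask = 0
    for r in ranks:
        mask |= 1 << r
    return mask

def receiving_ranks(mask, cluster_size):
    """Decode a mask back into the list of receiving CTA ranks."""
    return [r for r in range(cluster_size) if (mask >> r) & 1]
```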

@He-Jingkai
Author

@Rachmanino Let me know if there are further changes needed ~

Previously, a non-contiguous copy region (e.g. a 2-D slice where the
inner dim does not span the full buffer width) with a remote_barrier
fell back to element-wise SIMT stores, losing TMA throughput.
This commit introduces a "multi-TMA" middle path in LowerClusterCopy:
the ND copy region is recursively decomposed into individual contiguous
rows, each emitting one tma_store_cluster call.  The mbarrier
arrive_count is updated to N (the number of rows) so the destination
CTA's wait() still works correctly without any API change.
Key changes:
- src/op/copy.cc: add static MakeTMARows() that recursively splits an
  ND region into contiguous rows; static extents are unrolled at
  compile time, symbolic extents produce TIR For loops.
- src/op/operator.h: add UpdateBarrierArriveCallback (Var → PrimExpr)
  to LowerArgs so LowerClusterCopy can propagate the new arrive_count.
- src/transform/lower_tile_op.cc: collect arrive_count overrides in
  barrier_arrive_updates_ and apply them to the barrier_init block
  annotation after visiting the block body, before LowerSharedBarrier
  consumes it.
- testing/python/cuda/test_tma_dsmem.py: replace the old SIMT-barrier
  test with test_store_cluster_multi_tma_barrier (2-D, M rows) and add
  test_store_cluster_3d_multi_tma (3-D, D×M rows) to cover the
  two-level recursive decomposition.
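The row decomposition this commit describes can be sketched in Python (a hypothetical mirror of MakeTMARows, assuming element strides with a unit innermost stride): recurse over the outer dims, emitting one (flat element offset, row bytes) pair per innermost row, so each pair can back one tma_store_cluster call and the mbarrier arrive_count equals the number of rows.

```python
def make_rows(mins, extents, strides, elem_bytes):
    """Split an ND region (per-dim mins/extents over a strided buffer)
    into contiguous rows; returns [(element_offset, row_bytes), ...]."""
    rows = []

    def rec(dim, offset):
        if dim == len(mins) - 1:
            # Innermost dim: one contiguous row of `extents[dim]` elements.
            rows.append((offset + mins[dim] * strides[dim],
                         extents[dim] * elem_bytes))
            return
        for i in range(extents[dim]):  # unroll the outer extent
            rec(dim + 1, offset + (mins[dim] + i) * strides[dim])

    rec(0, 0)
    return rows
```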
@He-Jingkai
Author

SIMT element-wise fallback (non-contiguous src/dst regions such as non-full-span multi-dimensional tiles): all threads cooperatively write via cooperative_groups::map_shared_rank. When a remote_barrierwas supplied but the fast path could not be taken, we auto-generate a thread-0-guarded ptx_arrive_cluster_barrier call (copy.cc line 1704-1720) to notify the destination CTA's mbarrier.

I have refactored SM-to-SM copy to use TMA and hardware-managed mbarriers for non-full-span multi-dimensional tiles; SIMT fallback is now reserved only for reshape copies.

@Rachmanino
Collaborator

I have no further concerns as long as CI passes, cc @LeiWang1999. Besides, just a friendly note: some TileLang APIs have changed in recent PRs (e.g. it is now required to use T.tma_copy to manually bind a TMA load to its barrier).

@He-Jingkai
Author

He-Jingkai commented Mar 18, 2026

@Rachmanino Thanks. (´▽`) It seems that the CI is broken.

@He-Jingkai
Author

@Rachmanino

We've merged the recent changes. Could you please help us to rerun the CI?

@Rachmanino
Collaborator

@Rachmanino

We've merged the recent changes. Could you please help us to rerun the CI?

Sure, ask me whenever in need!
