[Feature] Add T.copy_cluster to support TMA multicast and SM-to-SM cluster copy #1908
He-Jingkai wants to merge 62 commits into tile-ai:main.
Conversation
Add multicast copy (tma_load_multicast) and shared-memory-to-shared-memory cluster copy (tma_store_cluster / ptx_cluster_store) to the T.copy operator.

Multicast copy:
- A single CTA issues tma_load_multicast to broadcast a tile to multiple CTAs in the cluster simultaneously; the other masked CTAs receive passively.
- A new `cluster_mask` parameter on T.copy() controls which CTAs participate.

SM-to-SM cluster copy:
- Fast path via cp.async.bulk.shared::cluster with mbarrier synchronisation (new `remote_barrier` parameter on T.copy()).
- Slow path via cooperative_groups::map_shared_rank for an element-by-element fallback (new `dst_block` parameter on T.copy()).

Supporting changes:
- New builtins: tma_load_multicast, tma_store_cluster, ptx_cluster_store, cluster_sync, get_cluster_id, get_cluster_block_rank, get_cluster_block_nums, mbarrier_arrive (cluster scope).
- Codegen (codegen_cuda.cc) and device templates (barrier.h, copy_sm90.h) for all new intrinsics; unified mbarrier API via tl::mbarrier_* free functions.
- inject_tma_barrier: handle tma_load_multicast; distinguish thread guards from cluster-rank conditions for correct expect_tx injection.
- lower_hopper_intrin: migrate barrier allocation to builtin::create_barriers; hoist user ptx_init_barrier_thread_count calls alongside compiler barriers.
- warp_specialized_rewriter: support tma_load_multicast as a producer op; UserBarrierInitExtractor to separate user barrier inits from the body; ragged (dynamic-extent) pipeline prefix counter; fix nested IfThenElse state corruption in the shuffle-elect optimisation.
- multi_version_buffer_rewriter: ragged pipeline support with a runtime prefix counter for correct ping-pong buffer versioning.
- thread_storage_sync: reserve fixed barrier IDs (kProducer/kConsumer) for the 256-thread CTA warp-specialised split to prevent deadlocks.
- pipeline_planning: defensive handling of non-BufferLoad mbarrier args.
- Shared memory alignment bumped to 128 bytes for TMA.
Co-authored-by: Guangda Sun <2012661711@qq.com>
Add docs/programming_guides/cluster_tma.md covering the two new T.copy extensions introduced in the t_copy_extend feature branch:
- TMA multicast (cluster_mask): explains how a single TMA transaction broadcasts one global tile to multiple CTAs in a cluster simultaneously, with API usage, a per-rank behaviour table, and a complete code example.
- SM-to-SM cluster copy (dst_block / remote_barrier): documents the fast path (cp.async.bulk.shared::cluster + mbarrier, single-thread async DMA) and the slow path (map_shared_rank element-by-element SIMT fallback), including the synchronisation contract for source and destination CTAs.

Also covers the cluster helper builtins (T.get_cluster_block_rank, T.cluster_sync, T.mbarrier_init/arrive/wait_parity, etc.) and a split-K sketch combining both features end-to-end.
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Reviews paused. This branch is under active development, so CodeRabbit has automatically paused its review to avoid an influx of comments on new commits.
📝 Walkthrough: Adds SM90+ cluster data-movement features: TMA multicast and SM-to-SM cluster copy across the TileLang frontend, Tile op, lowering, CUDA codegen, TL templates, transform passes, warp rewriter, tests, and documentation.
Sequence diagram:

```mermaid
sequenceDiagram
    participant P as Producer CTA (rank 0)
    participant TL as TileLang Kernel / CodeGen
    participant MB as mbarrier (handle)
    participant C as Consumer CTA (rank 1)
    Note right of TL: CodeGen emits tl::tma_store_cluster / tl::tma_load_multicast
    P->>TL: load global -> shared (s_src) via TMA
    TL->>MB: tl::mbarrier_expect_tx(bar, size)
    TL->>TL: tl::tma_store_cluster(dst_shared, s_src, dst_cta=1, size, bar)
    Note over P,C: cp.async.bulk cluster transfer (SM-to-SM) across CTAs
    C->>MB: tl::mbarrier_wait(bar)
    C->>TL: read from shared (s_dst) and write out
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: 2 passed, 1 failed (1 warning).
Actionable comments posted: 20
Caution: some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/thread_storage_sync.cc (1)
Lines 459-490: ⚠️ Potential issue | 🟠 Major — Allocate dynamic barrier IDs independently from the fixed producer/consumer IDs.
Once Line 470 or Line 477 inserts the fixed IDs, `barrier_id_map_.size()` no longer matches the next free numeric barrier slot. If the consumer half is seen first, Line 483 hands out barrier 6 instead of 5, and those gaps make the 16-barrier hardware limit easier to hit. Track a separate next-dynamic ID, and guard against emitting anything above 15.

Suggested fix:

```diff
-      size_t barrier_id =
-          barrier_id_map_.size() +
-          static_cast<size_t>(ReservedNamedBarriers::kFirstUsedBarrier);
+      size_t barrier_id = next_dynamic_barrier_id_++;
+      ICHECK_LT(barrier_id, 16)
+          << "ThreadPartialSyncRewriter exhausted Hopper named barriers";
```

New member:

```cpp
size_t next_dynamic_barrier_id_{
    static_cast<size_t>(ReservedNamedBarriers::kFirstUsedBarrier)};
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/thread_storage_sync.cc` around lines 459 - 490, The code currently derives new dynamic barrier IDs from barrier_id_map_.size(), causing gaps once fixed producer/consumer IDs (ReservedNamedBarriers::kConsumer/kProducer) are inserted; introduce a class member next_dynamic_barrier_id_ initialized to static_cast<size_t>(ReservedNamedBarriers::kFirstUsedBarrier) and use it instead of barrier_id_map_.size() when allocating a new barrier in the allocation path (the place that assigns barrier_id right now and updates barrier_id_map_ and thread_count_map_); increment next_dynamic_barrier_id_ after assigning, and add a guard that prevents returning/assigning any barrier_id above the hardware max (e.g., 15) so you emit an error or handle allocation failure when next_dynamic_barrier_id_ would exceed the limit.
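The allocation bug described above is easy to model outside of TVM. The sketch below is a hypothetical Python stand-in for the C++ rewriter (the class, key names, and the assumed value of `ReservedNamedBarriers::kFirstUsedBarrier` are illustrative, not the real implementation); it shows why a dedicated counter avoids the gap that `barrier_id_map_.size()` introduces once a fixed ID has been inserted.

```python
# Hypothetical model of the suggested fix: allocate dynamic barrier IDs from a
# dedicated counter instead of deriving them from the map size, so inserting
# the fixed kProducer/kConsumer IDs cannot create gaps.
K_FIRST_USED_BARRIER = 3   # assumed value of ReservedNamedBarriers::kFirstUsedBarrier
HW_MAX_BARRIERS = 16       # Hopper named-barrier limit

class BarrierAllocator:
    def __init__(self):
        self.barrier_id_map = {}                  # sync key -> barrier id
        self.next_dynamic_id = K_FIRST_USED_BARRIER

    def get_or_alloc(self, key, fixed_id=None):
        if key in self.barrier_id_map:
            return self.barrier_id_map[key]
        if fixed_id is not None:                  # kProducer / kConsumer path
            self.barrier_id_map[key] = fixed_id
            return fixed_id
        bid = self.next_dynamic_id                # independent counter: no gaps
        assert bid < HW_MAX_BARRIERS, "exhausted Hopper named barriers"
        self.next_dynamic_id += 1
        self.barrier_id_map[key] = bid
        return bid

alloc = BarrierAllocator()
alloc.get_or_alloc("consumer_sync", fixed_id=2)   # fixed ID seen first
# With map-size-based allocation this would return 4 (a gap); counter keeps 3.
assert alloc.get_or_alloc("loop_a") == 3
assert alloc.get_or_alloc("loop_b") == 4
```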
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/programming_guides/cluster_tma.md`:
- Around line 103-105: The sentence is misleading about non-issuer behavior;
update the text referencing
cp.async.bulk.tensor.Nd.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster
to state that masked non-issuer CTAs receive the multicast passively (do not
perform an active load) and only CTAs that fall outside the cluster_mask perform
a standard cp.async.bulk.tensor active load; ensure both the multicast path
(passive receipt) and the fallback active load for CTAs outside the mask are
explicitly described and contrasted.
- Around line 6-10: Update all example snippets and code blocks that use the
now-removed public launch argument name cluster_size to the new name
cluster_dims; search for the symbol "cluster_size" in the document and replace
it with "cluster_dims" in every code/example instance (including the multiple
occurrences called out in the review) so pasted examples match the current API
surface.
In `@src/op/copy.cc`:
- Around line 934-936: The Lower() path currently returns LowerClusterCopy(T,
analyzer) whenever dst_block.defined(), which bypasses target capability checks
and may emit tma_store_cluster/ptx_cluster_store on targets that don't support
cluster copy; update the dst_block branch in Lower() to first check the same
Hopper/cluster capability predicate used for the intrinsic lowering (the
capability check used elsewhere for cluster intrinsics) and only call
LowerClusterCopy(T, analyzer) when that capability is present, otherwise return
an early failure/diagnostic explaining cluster copy is unsupported on the
current target; reference dst_block, Lower(), LowerClusterCopy, and the
tma_store_cluster/ptx_cluster_store intrinsics when making the change.
- Around line 1476-1515: The code assumes flat contiguity when building the
tma_store_cluster call; instead implement a contiguity check (e.g., add an
is_contiguous(Buffer&, Array<Range>&) helper that uses compute_linear_offset to
compute min_offset and max_offset (using ranges[i]->min and ranges[i]->min +
ranges[i]->extent - 1) and verifies (max_offset - min_offset + 1 ==
total_elements)). Use that check before taking the fast path that builds
dst_ptr/src_ptr and calls tma_store_cluster; if either src or dst is not
contiguous, fall back to the existing element-wise cluster-store path (same code
path used when not in this fast TMA branch). Make sure to reference
compute_linear_offset, dst_offset, src_offset, total_elements,
tma_store_cluster, and the in_tma_context_/HandleAccessPtrAndOffset behavior
when locating where to insert the check.
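The contiguity test the prompt asks for can be stated compactly: the sliced region is bytewise-contiguous exactly when the linear span between its first and last element equals its element count. The Python sketch below mirrors the prompt's names (`compute_linear_offset`, `ranges`) but is an illustrative model, not the copy.cc code.

```python
# Sketch of the is_contiguous check: compare the linear span of the slice
# (last offset - first offset + 1) with its total element count.
def compute_linear_offset(indices, shape):
    # Row-major flattening of a multi-dimensional index.
    off = 0
    for idx, dim in zip(indices, shape):
        off = off * dim + idx
    return off

def is_contiguous(shape, ranges):
    # ranges: list of (min, extent) pairs, one per dimension
    mins = [lo for lo, _ in ranges]
    maxs = [lo + ext - 1 for lo, ext in ranges]
    total = 1
    for _, ext in ranges:
        total *= ext
    span = compute_linear_offset(maxs, shape) - compute_linear_offset(mins, shape) + 1
    return span == total

# A full innermost row of a 4x8 buffer is one contiguous byte span.
assert is_contiguous([4, 8], [(1, 1), (0, 8)])
# A 2x4 sub-tile of the same buffer is strided: the fast path must not fire.
assert not is_contiguous([4, 8], [(0, 2), (0, 4)])
```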
In `@src/tl_templates/cuda/cluster.h`:
- Around line 19-21: Replace all assert(false && "CLUSTER_ENABLED is not
defined") fallbacks and any code paths that assert then return zero with a hard
trap: use asm volatile("trap") (as used in tcgen_05_ld.h) so unsupported cluster
code cannot continue with bogus IDs. Specifically, find occurrences of the
CLUSTER_ENABLED guard and every function/macro in this header that currently
does an assert(false) (and the ones that follow that with return 0) and change
them to invoke asm volatile("trap") before any return; do this for every
assert-based fallback so the code fails hard on unsupported GPUs rather than
becoming a no-op in release builds.
In `@src/transform/inject_tma_barrier.cc`:
- Around line 163-177: The current check sets is_thread_eq true for any EQNode
that mentions thread_var_ (via UsesVar), which is too permissive; instead,
detect only equality predicates that uniquely select a single thread index
(e.g., threadIdx.x == CONST or threadIdx.x +/− CONST == CONST). Replace the
UsesVar-based logic in the op->condition EQNode branch with a structural
analyzer that accepts only EQNode patterns where one side is the thread_var_
(optionally wrapped by a single Add/Sub with an IntImm) and the other side is a
constant IntImm (or equivalent cast to IntImm); reject patterns involving
modulus, bitwise-and, other arithmetic, or non-constant sides. Keep the rest of
the flow (is_thread_eq usage and the guard around expect_tx/mbarrier_expect_tx)
unchanged so injection only happens when the predicate guarantees exactly one
thread value.
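The structural check described above can be prototyped on a toy expression tree. The mini-AST below (`Var`, `Const`, `Add`, `Mod`, `Eq`) is a hypothetical stand-in for TVM's `EQNode`/`IntImm` nodes; the point is the shape of the acceptance rule, which admits only predicates that pin down a single thread value.

```python
# Accept only `threadIdx.x == CONST` (optionally through one Add with a
# constant); reject modulus and other forms that can match many threads.
from dataclasses import dataclass

@dataclass
class Var: name: str
@dataclass
class Const: value: int
@dataclass
class Add: a: object; b: object
@dataclass
class Mod: a: object; b: object
@dataclass
class Eq: lhs: object; rhs: object

def selects_single_thread(expr, thread_var):
    if not isinstance(expr, Eq):
        return False
    def is_thread_side(e):
        if isinstance(e, Var):
            return e.name == thread_var
        if isinstance(e, Add):  # one Add/Sub wrapper with a constant is fine
            return (isinstance(e.a, Var) and e.a.name == thread_var
                    and isinstance(e.b, Const))
        return False            # Mod, And, etc. may select many threads
    lhs, rhs = expr.lhs, expr.rhs
    return ((is_thread_side(lhs) and isinstance(rhs, Const)) or
            (is_thread_side(rhs) and isinstance(lhs, Const)))

assert selects_single_thread(Eq(Var("tx"), Const(0)), "tx")
# `tx % 32 == 0` matches one thread per warp, so it must be rejected.
assert not selects_single_thread(Eq(Mod(Var("tx"), Const(32)), Const(0)), "tx")
```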
- Around line 103-112: The current VisitStmt_(const IfThenElseNode *op)
unconditionally skips visiting the else_case and assumes both arms transfer
equal bytes for mbarrier_expect_tx; instead, compute the transmitted-byte count
for both branches (e.g., by invoking the same byte-accounting visitor on
op->then_case and op->else_case) and compare the results, collapsing to a single
visit only if they are equal; otherwise keep the accounting branch-local by
visiting both branches separately so mbarrier_expect_tx is correct for each
path. Ensure you reference VisitStmt_(const IfThenElseNode *op),
StmtExprVisitor::VisitStmt/VisitStmt_, and the mbarrier_expect_tx accounting
logic when making the change.
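The branch-local accounting rule can be illustrated with a tiny statement model. Below, nested lists stand in for `SeqStmt`, tuples for `IfThenElse`, and leaf integers for the byte counts of TMA loads; none of this is the real visitor, just a sketch of the "collapse only when both arms provably match" policy.

```python
# Count transmitted bytes per branch; only collapse an IfThenElse when both
# arms transfer the same amount, otherwise accounting must stay branch-local.
def count_bytes(stmt):
    if isinstance(stmt, int):            # a TMA load transferring `stmt` bytes
        return stmt
    if isinstance(stmt, list):           # SeqStmt: sum of children
        return sum(count_bytes(s) for s in stmt)
    if isinstance(stmt, tuple):          # IfThenElse: (then_case, else_case)
        then_case, else_case = stmt
        t = count_bytes(then_case)
        e = count_bytes(else_case) if else_case is not None else 0
        if t == e:
            return t                     # safe to collapse: both arms match
        raise ValueError(
            f"branches transfer different byte counts ({t} vs {e}); "
            "expect_tx must stay branch-local")
    raise TypeError(stmt)

# Equal arms collapse: 1024 + 512 = 1536 expected bytes.
assert count_bytes([1024, ([512], [512])]) == 1536
# A no-op else arm means the two paths differ and must not be collapsed.
try:
    count_bytes(([512], None))
    raise AssertionError("expected branch mismatch to be flagged")
except ValueError:
    pass
```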
In `@src/transform/lower_hopper_intrin.cc`:
- Around line 206-215: The code currently reuses get_mbarrier(i) starting at
zero causing aliasing across multiple create_list_of_mbarrier() calls; fix by
assigning a base offset from the current num_managed_barriers_ (e.g. int base =
num_managed_barriers_), then use get_mbarrier(base + i) when constructing each
mbarrier and push to init_mbarrier_calls_, and finally increment
num_managed_barriers_ by num_barriers (or set it after the loop) so each list
gets its own unique handle slice; update references in this block around
create_list_of_mbarrier(), num_managed_barriers_, get_mbarrier(),
init_mbarrier_calls_, and the local mbarrier variable accordingly.
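The aliasing fix above amounts to giving each barrier list its own handle slice via a running base offset. The sketch below is a hypothetical Python model of that bookkeeping (the method and counter names mirror the prompt but the real code is C++ in lower_hopper_intrin.cc).

```python
# Each create_list_of_mbarrier() call gets a fresh slice of handles, offset
# by the running num_managed_barriers count, instead of always starting at 0.
class BarrierLowering:
    def __init__(self):
        self.num_managed_barriers = 0

    def create_list_of_mbarrier(self, num_barriers):
        base = self.num_managed_barriers           # base offset for this list
        handles = [f"get_mbarrier({base + i})" for i in range(num_barriers)]
        self.num_managed_barriers += num_barriers  # advance for the next list
        return handles

lowering = BarrierLowering()
first = lowering.create_list_of_mbarrier(2)
second = lowering.create_list_of_mbarrier(3)
# Without the base offset, `second` would alias handles 0-2 of `first`.
assert first == ["get_mbarrier(0)", "get_mbarrier(1)"]
assert second == ["get_mbarrier(2)", "get_mbarrier(3)", "get_mbarrier(4)"]
```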
In `@src/transform/lower_tile_op.cc`:
- Around line 663-671: The current branch sets in_tma_context_ for
tma_store_cluster(), which suppresses the normal shared-layout address remapping
and causes tvm_access_ptr/address_of operands to target wrong byte ranges;
instead, do not flip in_tma_context_ for the tma_store_cluster() path — call
Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op)) without setting
in_tma_context_ so the regular layout-offset rewrite runs for
tvm_access_ptr/address_of, but still ensure has_tma_ is not set for this op
(preserve the original intent of not triggering warp specialization).
In `@src/transform/pipeline_planning.cc`:
- Around line 372-395: The current branch in VisitExpr_ for mbarrier_wait_parity
drops dependency modeling when args[0] is not a BufferLoadNode by calling
StmtExprVisitor::VisitExpr_(op), which loses read/write edges; update the logic
handling mbarrier_wait_parity so that handle-based barriers (e.g.
tl.get_mbarrier(id)) are normalized into the same maps
(chain_builder_.mbar_to_buffer_reads_ / mbar_to_buffer_writes_) or, if the
concrete Buffer cannot be resolved, conservatively attach the wait to the set of
pending async buffers (the same collections used for BufferLoadNode) instead of
falling back to StmtExprVisitor::VisitExpr_. Locate this in the
mbarrier_wait_parity handling block in pipeline_planning.cc and modify the else
path to look up/normalize the handle (or add conservative edges) so reads_ and
writes_ are populated consistently.
In `@src/transform/warp_specialized_rewriter.cc`:
- Around line 1204-1209: The current code wraps the kept pipelined loop into
SeqStmt({result, update}) which hides the For from downstream post-processing
(so num_stages/pipeline annotations and GroupOpRewriter never run) and also
advances ragged_prefix_buf_ even when FilterByRole dropped the loop; instead,
append the ragged-prefix update only when the loop is actually preserved: avoid
creating a SeqStmt that nests the For—keep result as the loop Stmt and attach
the BufferStore update in the same AST level or conditional branch after
loop-preservation checks (use the existing LoadRaggedPrefix(),
ragged_prefix_buf_, and the computed new_prefix and update Stmt) so the For
remains visible to the post-processing and only advance the prefix when the loop
wasn't filtered out.
- Around line 161-164: The stage-local mbarrier rewrite in MbarrierRewriter
currently retargets only tma_load() and tma_load_im2col(), leaving
tma_load_multicast() signaling the original barrier and causing a mismatch with
the consumer; update the MbarrierRewriter rewrite logic to include
tma_load_multicast() alongside tma_load() and tma_load_im2col() so multicast
producers are retargeted to the per-stage barrier as well, ensuring consistent
signaling and waiting between producer and consumer.
- Around line 1477-1503: The extractor currently hoists entire IfThenElseNodes
(VisitStmt_) even when they have no-else or live inside an
Allocate/T.alloc_shared scope, which moves inits out of scope and drops else
branches; change VisitStmt_ in UserBarrierInitExtractor to only extract when the
if has no else branch (check op->else_case is null/empty) and then push the
inner init statement(s) (op->then_case or unwrapped SeqStmt) into init_stmts
instead of the whole IfThenElseNode; also ensure IsOnlyInit still recognizes
nested SeqStmt of Evaluate(CallNode) as before so callers can reinsert the init
inside the allocating scope rather than at block root.
In `@src/transform/warp_specialized_rewriter.h`:
- Around line 42-46: The current check uses the single-bit
detector.has_cluster_copy_ to exempt all mbarrier+TMA cases, which wrongly
preserves warp-specialisation when a kernel mixes regular tma_load* paths with
an unrelated tma_store_cluster; update the detector and the condition to
distinguish cluster-copy-specific TMA paths from regular TMA-load paths (e.g.,
add a flag like has_regular_tma_load_ or has_cluster_copy_path_ that tracks
whether the mbarrier is tied to the cluster-copy path), then change the
condition using detector.has_tma_op_ && detector.has_mbarrier_op_ &&
!detector.has_cluster_copy_ to instead require that no regular TMA-load path
exists (e.g., detector.has_tma_op_ && detector.has_mbarrier_op_ &&
!detector.has_regular_tma_load_) or scope the exemption to the specific
barrier/copy path; apply the same change to the other occurrence referenced
(lines ~81-83).
In `@testing/python/cuda/test_tma_dsmem.py`:
- Around line 65-92: Convert this script into a real pytest test by replacing
main() with a test function (e.g., test_tma_store_cluster_copy) that calls
make_store_cluster_kernel(N) and runs the kernel; use pytest.skip when the
device capability is < 9 (instead of printing/returning), and replace the
print-based pass/fail with an assert that np.allclose(result, expected) (or
assert with a helpful message showing max diff). Ensure the test imports pytest
and references make_store_cluster_kernel and the produced result/expected arrays
so CI will fail on mismatches.
In `@testing/python/cuda/test_tma_multicast_demo.py`:
- Around line 51-70: Guard the test_tma_multicast_demo test to run only on CUDA
devices with Hopper (SM >= 9.0) hardware: at the start of
test_tma_multicast_demo check torch.cuda.is_available() and
torch.cuda.get_device_capability(0) (or
torch.cuda.get_device_properties(0).major) to ensure SM >= 9, and call
pytest.skip(...) (or return) when the checks fail so
make_tma_multicast_demo_kernel, tilelang.compile and the mod(A) launch are not
executed on non-CUDA or pre-SM90 runners.
In `@tilelang/engine/phase.py`:
- Around line 295-300: shared_align_bytes is currently set to 128 only when
allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target) is true, but
kernels that get TMA-lowered via allow_tma_lower() (e.g. Hopper TMA) should also
receive 128-byte shared alignment; update the condition that computes
shared_align_bytes to check allow_tma_lower(pass_ctx=pass_ctx, target=target)
(or OR it with the existing allow_tma_and_warp_specialized check) so
MergeSharedMemoryAllocations(enable_aggressive_merge=...,
align_bytes=shared_align_bytes) receives 128 for any TMA-lowered kernel.
In `@tilelang/language/builtin.py`:
- Around line 427-443: The wrapper mbarrier_init currently accepts int/PrimExpr
and calls _get_mbarrier, but _get_mbarrier isn't implemented so ID-based inputs
fail at runtime; remove support for integer/PrimExpr IDs for now by deleting the
branch that handles isinstance(mbarrier, (tir.PrimExpr, int)) and its call to
_get_mbarrier, and update the TypeError message in mbarrier_init to only accept
tir.Call, tir.Buffer, or tir.BufferLoad (so callers must provide a real handle
or buffer until _get_mbarrier is implemented); keep the existing tir.BufferLoad
and tir.Call handling and the final tir.call_intrin(invocation) unchanged.
In `@tilelang/language/copy_op.py`:
- Around line 79-85: Update the docstring for the parameters (around dst_block
and cluster_mask) in copy_op.py to correct multicast semantics: state that when
cluster_mask is set the CTA whose rank equals the lowest set bit issues
tma_load_multicast and all other CTAs included in the mask receive the data
passively from that multicast (they must not issue redundant tma_load calls or
additional synchronization), and clarify that cluster_mask=None disables
multicast and causes regular per-CTA TMA loads; reference the cluster_mask
parameter and the description of the multicast leader (lowest set bit) so
readers can locate and edit the docstring.
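The leader-election rule in that docstring fix (lowest set bit of `cluster_mask` issues the multicast, other masked CTAs receive passively) reduces to a one-line bit trick. The sketch below models per-rank roles on the host side; it is illustrative only and does not touch the device API.

```python
# The CTA whose cluster rank equals the lowest set bit of cluster_mask issues
# the single tma_load_multicast; other masked CTAs receive passively; CTAs
# outside the mask fall back to a regular per-CTA TMA load.
def multicast_leader(cluster_mask: int) -> int:
    assert cluster_mask != 0, "empty mask disables multicast"
    return (cluster_mask & -cluster_mask).bit_length() - 1  # lowest set bit

def role(rank: int, cluster_mask: int) -> str:
    if not (cluster_mask >> rank) & 1:
        return "regular per-CTA load"      # outside the mask
    if rank == multicast_leader(cluster_mask):
        return "issues tma_load_multicast"
    return "passive receiver"

mask = 0b0110                               # ranks 1 and 2 participate
assert role(0, mask) == "regular per-CTA load"
assert role(1, mask) == "issues tma_load_multicast"
assert role(2, mask) == "passive receiver"
```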
- Around line 137-138: The code currently overwrites ann["dst_block"] whenever
dst_block is not None, violating the stated precedence that annotations beat
individual kwargs; update the assignment in the function/method that builds ann
(the block handling code where ann, annotations, dst_block are in scope) to only
set ann["dst_block"] when dst_block is not None AND "dst_block" is not already
present in ann (i.e., check 'if dst_block is not None and "dst_block" not in
ann: ann["dst_block"] = dst_block'), matching the existing precedence handling
used for cluster_mask and remote_barrier.
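The precedence rule requested above ("annotations beat individual kwargs") is a two-condition guard. This hypothetical sketch mirrors the `ann` dict-building pattern described in the prompt; the function name and signature are illustrative, not the copy_op.py API.

```python
# An explicit entry in the annotations dict wins over the keyword argument,
# matching the handling already used for cluster_mask and remote_barrier.
def build_annotations(annotations=None, dst_block=None,
                      cluster_mask=None, remote_barrier=None):
    ann = dict(annotations or {})
    # Only fill from kwargs when the annotation is absent.
    if dst_block is not None and "dst_block" not in ann:
        ann["dst_block"] = dst_block
    if cluster_mask is not None and "cluster_mask" not in ann:
        ann["cluster_mask"] = cluster_mask
    if remote_barrier is not None and "remote_barrier" not in ann:
        ann["remote_barrier"] = remote_barrier
    return ann

# Annotation beats the kwarg:
assert build_annotations({"dst_block": 1}, dst_block=2)["dst_block"] == 1
# Kwarg fills the gap when no annotation exists:
assert build_annotations({}, dst_block=2)["dst_block"] == 2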
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ac4c6e39-f962-4c65-ab08-2abf507a83a7
📒 Files selected for processing (23)
docs/programming_guides/cluster_tma.md, src/op/builtin.cc, src/op/builtin.h, src/op/copy.cc, src/op/copy.h, src/target/codegen_cuda.cc, src/tl_templates/cuda/barrier.h, src/tl_templates/cuda/cluster.h, src/tl_templates/cuda/copy_sm90.h, src/transform/common/thread_sync_types.h, src/transform/inject_tma_barrier.cc, src/transform/lower_hopper_intrin.cc, src/transform/lower_tile_op.cc, src/transform/multi_version_buffer_rewriter.cc, src/transform/pipeline_planning.cc, src/transform/thread_storage_sync.cc, src/transform/warp_specialized_rewriter.cc, src/transform/warp_specialized_rewriter.h, testing/python/cuda/test_tma_dsmem.py, testing/python/cuda/test_tma_multicast_demo.py, tilelang/engine/phase.py, tilelang/language/builtin.py, tilelang/language/copy_op.py
Awesome contributions! cc @chengyupku @Rachmanino
@Rachmanino It seems that this PR and #1874 have introduced some APIs with the same functionality. We will unify them into the #1874 API.
Thanks for your awesome contribution! I left some comments, though; we can discuss them later. I can also fix them for you if needed :)
…instead of assert(false)+dummy return.
Actionable comments posted: 4
♻️ Duplicate comments (3)
docs/programming_guides/cluster_tma.md (1)
Line 6: ⚠️ Potential issue | 🟡 Minor — Use the current `cluster_dims` launch argument in the examples. These snippets still use the old `cluster_size` name, so readers copying them will get stale API usage. Also applies to lines 22-25, 77, 162, and 260.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/programming_guides/cluster_tma.md` at line 6, The documentation examples use the deprecated launch argument name `cluster_size`; update every occurrence (including the instances at lines noted) to the current argument name `cluster_dims` so examples match the current API; search for and replace `cluster_size` with `cluster_dims` in code snippets and prose within this document (e.g., in example invocations and parameter descriptions) to ensure examples compile and reflect the active launch argument.

src/transform/inject_tma_barrier.cc (1)
Lines 104-112: ⚠️ Potential issue | 🟠 Major — Keep `expect_tx` byte accounting branch-local.
This still drops the `else` arm without proving both paths move the same byte count. A nested conditional can legitimately take a smaller or no-op alternative path, which makes the injected `mbarrier_expect_tx` wrong for that branch. Only collapse after proving equality; otherwise visit both arms separately.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/inject_tma_barrier.cc` around lines 104 - 112, The current VisitStmt_(const IfThenElseNode *op) collapses byte accounting by only visiting the then branch which can mis-count mbarrier_expect_tx for the else path; change VisitStmt_ for IfThenElseNode to handle branches locally: when op->else_case.defined(), snapshot the mbarrier_expect_tx (and any related byte-accounting state), call StmtExprVisitor::VisitStmt on op->then_case and record its expect_tx delta, restore the snapshot, call StmtExprVisitor::VisitStmt on op->else_case and record its expect_tx delta, and only collapse the two deltas if they are provably equal (otherwise keep branch-local accounting). Ensure you reference and update the same mbarrier_expect_tx variable used elsewhere so both branches are visited independently and state is restored between visits.

src/op/copy.cc (1)
Lines 1482-1515: ⚠️ Potential issue | 🔴 Critical — The bulk cluster fast path only preserves bytewise memcpy semantics.
`compute_linear_offset` plus `size_bytes = total_elements * src->dtype.bytes()` flatten this copy into a raw byte span. That is only correct when the src/dst slices are contiguous and share the same physical layout and dtype; strided subregions, swizzled-vs-linear shared buffers, or dtype-changing `T.copy` calls will silently copy the wrong bytes here. Fall back to the SIMT path unless those properties are proven first.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/copy.cc` around lines 1482 - 1515, The bulk-copy fast path flattens ranges into a raw byte span (compute_linear_offset + size_bytes = total_elements * src->dtype.bytes()) which is only correct for contiguous, same-layout, same-dtype transfers; update the logic before building dst_ptr/src_ptr and calling tma_store_cluster() to verify that src and dst are element-contiguous, have identical dtype, and are not subject to swizzle or layout transformations (i.e., the exact conditions required for a bytewise memcpy). If any of those checks fail, do not construct the TMA bulk path (dst.access_ptr/src.access_ptr + tma_store_cluster call); instead fall back to the existing SIMT copy path so element-wise semantics are preserved. Ensure the guards reference compute_linear_offset, total_elements/size_bytes, dst.access_ptr, src.access_ptr and tma_store_cluster so the correct branch is chosen at runtime/compile-time.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/programming_guides/cluster_tma.md`:
- Around line 41-46: The fenced ASCII diagram block starting with "Global memory
──TMA multicast──▶ shared memory (rank 0)" is unlabeled and triggers
markdownlint; update the opening fence to include a language label (use "text")
so the fence becomes ```text and the rest of the block remains unchanged,
ensuring the diagram stays the same but the markdown linter no longer flags it.
In `@src/op/copy.cc`:
- Around line 1863-1865: The cluster multicast mask from GetClusterMask() is
only applied inside LowerBulkCopy(), so non-descriptor or fallback paths (e.g.,
when copy_kind is kBulkLoad1D or other non-bulk paths) ignore it; update the
copy selection/dispatch logic to consult cluster_mask (and the computed
use_multicast) before deciding per-CTA vs multicast strategies so the mask is
honored across all copy paths—specifically propagate or check
cluster_mask/use_multicast where copy mode is selected (places that branch on
kBulkLoad1D, LowerBulkCopy(), or fallback handlers) and ensure any fallback to
per-CTA loads respects the mask by choosing the corresponding multicast-aware
code path or forcing the same source-coordinate logic as the descriptor path.
- Around line 1568-1581: Summary: The SIMT fallback flattens indices using
logical layout (op->indices and op->buffer->shape) but writes into a remapped
physical destination (target_dst_ / target_buffer.access_ptr(2)), causing wrong
element stores. Fix: replace the current linearization that uses op->indices and
op->buffer->shape in this block (where linearized_index is computed and pushed
to args) with a remapped/physical index consistent with target_dst_; use the
target_dst_ remapping API or compute linear index from the target buffer's
physical layout/strides (the same mapping used by ptx_cluster_store /
target_buffer.access_ptr(2)) so produced linearized_index matches the physical
destination. Ensure references: update code around linearized_index,
op->indices, op->buffer->shape, target_dst_, and ptx_cluster_store to use the
same mapping.
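The logical-vs-physical mismatch flagged above is easiest to see with explicit strides. The sketch below is an illustrative model (the stride values are hypothetical): linearizing with the logical shape agrees with the physical address only when the target layout is plain row-major, and diverges as soon as the buffer is remapped, e.g. with row padding.

```python
# Linearize an index two ways: via the logical shape (what the buggy code
# does) and via the target buffer's physical strides (what access_ptr uses).
def linearize_logical(indices, shape):
    off = 0
    for idx, dim in zip(indices, shape):
        off = off * dim + idx
    return off

def linearize_physical(indices, strides):
    return sum(i * s for i, s in zip(indices, strides))

shape = [4, 8]
row_major = [8, 1]       # strides matching the logical layout
padded = [10, 1]         # remapped layout: each row padded by 2 elements
idx = [2, 3]

# Row-major target: both linearizations agree, so the bug is latent.
assert linearize_physical(idx, row_major) == linearize_logical(idx, shape)
# Remapped target: logical gives 19, physical gives 23 -> wrong element stored.
assert linearize_physical(idx, padded) != linearize_logical(idx, shape)
```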
In `@src/transform/lower_hopper_intrin.cc`:
- Around line 140-145: The managed-barrier allocation currently only reserves
num_managed_barriers_ slots starting at ID 0 which can collide with
backend-reserved IDs (1 and 2); change the allocation and base so user-managed
barriers start at ID 3. Concretely, introduce a reserved_offset (3) and adjust
the allocation call (the CreateBarriers invocation currently assigned to
alloc_mbarrier) to request reserved_offset + num_managed_barriers_ slots, and
ensure any use of get_mbarrier(barrier_base + i) (and the barrier_base value set
elsewhere) is shifted by reserved_offset; also verify InjectTmaBarrier logic
that rewrites descriptor-based TMA loads to get_mbarrier(0) is updated to
account for the new base so no user-managed slot overlaps IDs 1 and 2.
---
Duplicate comments:
In `@docs/programming_guides/cluster_tma.md`:
- Line 6: The documentation examples use the deprecated launch argument name
`cluster_size`; update every occurrence (including the instances at lines noted)
to the current argument name `cluster_dims` so examples match the current API;
search for and replace `cluster_size` with `cluster_dims` in code snippets and
prose within this document (e.g., in example invocations and parameter
descriptions) to ensure examples compile and reflect the active launch argument.
In `@src/op/copy.cc`:
- Around line 1482-1515: The bulk-copy fast path flattens ranges into a raw byte
span (compute_linear_offset + size_bytes = total_elements * src->dtype.bytes())
which is only correct for contiguous, same-layout, same-dtype transfers; update
the logic before building dst_ptr/src_ptr and calling tma_store_cluster() to
verify that src and dst are element-contiguous, have identical dtype, and are
not subject to swizzle or layout transformations (i.e., the exact conditions
required for a bytewise memcpy). If any of those checks fail, do not construct
the TMA bulk path (dst.access_ptr/src.access_ptr + tma_store_cluster call);
instead fall back to the existing SIMT copy path so element-wise semantics are
preserved. Ensure the guards reference compute_linear_offset,
total_elements/size_bytes, dst.access_ptr, src.access_ptr and tma_store_cluster
so the correct branch is chosen at runtime/compile-time.
In `@src/transform/inject_tma_barrier.cc`:
- Around line 104-112: The current VisitStmt_(const IfThenElseNode *op)
collapses byte accounting by only visiting the then branch which can mis-count
mbarrier_expect_tx for the else path; change VisitStmt_ for IfThenElseNode to
handle branches locally: when op->else_case.defined(), snapshot the
mbarrier_expect_tx (and any related byte-accounting state), call
StmtExprVisitor::VisitStmt on op->then_case and record its expect_tx delta,
restore the snapshot, call StmtExprVisitor::VisitStmt on op->else_case and
record its expect_tx delta, and only collapse the two deltas if they are
provably equal (otherwise keep branch-local accounting). Ensure you reference
and update the same mbarrier_expect_tx variable used elsewhere so both branches
are visited independently and state is restored between visits.
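The snapshot/restore scheme can be sketched without TVM types. The toy visitor below accumulates expect_tx bytes over a statement tree; at an if/else it measures each branch's delta against a saved snapshot and only collapses them when the deltas provably agree (all names are illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <memory>

// Toy statement tree: a leaf contributes bytes to expect_tx; an if-node has
// a then branch and an optional else branch.
struct Stmt {
  int64_t bytes = 0;
  std::unique_ptr<Stmt> then_case;  // set iff this is an if-node
  std::unique_ptr<Stmt> else_case;
};

struct ExpectTxVisitor {
  int64_t expect_tx = 0;
  bool branch_dependent = false;  // branches disagree: keep local accounting

  void Visit(const Stmt* op) {
    if (!op) return;
    if (!op->then_case) {  // leaf
      expect_tx += op->bytes;
      return;
    }
    // Snapshot, visit each branch independently, compare the deltas.
    const int64_t saved = expect_tx;
    Visit(op->then_case.get());
    const int64_t then_delta = expect_tx - saved;
    expect_tx = saved;  // restore state before visiting the else branch
    Visit(op->else_case.get());
    const int64_t else_delta = expect_tx - saved;
    if (op->else_case && then_delta != else_delta) {
      branch_dependent = true;  // cannot collapse the two branches
      expect_tx = saved;
      return;
    }
    expect_tx = saved + then_delta;  // provably equal (or no else): collapse
  }
};
```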
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3637adfa-cf51-472e-814c-a877a8d589c7
📒 Files selected for processing (6)
- docs/programming_guides/cluster_tma.md
- src/op/copy.cc
- src/tl_templates/cuda/cluster.h
- src/transform/inject_tma_barrier.cc
- src/transform/lower_hopper_intrin.cc
- src/transform/lower_tile_op.cc
🚧 Files skipped from review as they are similar to previous changes (1)
- src/tl_templates/cuda/cluster.h
Actionable comments posted: 1
♻️ Duplicate comments (3)
src/op/copy.cc (3)
1479-1518: ⚠️ Potential issue | 🟠 Major: The bulk cluster fast path still assumes contiguous byte spans.
The `compute_linear_offset` lambda computes the offset from range mins, with `total_elements = product(extents)`. This is only correct for contiguous regions. A 2D/ND subregion inside a larger shared buffer is generally strided, so `tma_store_cluster` may copy the wrong elements. Consider adding a contiguity check before taking this fast path, and falling back to the element-wise path otherwise.
1568-1581: ⚠️ Potential issue | 🟠 Major: The SIMT fallback linearizes using the logical layout, not the physical one.
When `target_dst_` is remapped (swizzled), `ptx_cluster_store` writes into the physical buffer span, but `linearized_index` is derived from `op->indices` and `op->buffer->shape` (the logical layout). For swizzled destinations, this addresses the wrong element. The linearization should use the target buffer's physical layout/strides, consistent with how `target_buffer.access_ptr(2)` addresses memory.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/copy.cc` around lines 1568 - 1581, The linearization currently builds linearized_index from op->indices and op->buffer->shape (logical layout), which is wrong when target_dst_ is remapped; change linearization to compute the index using the target (physical) buffer layout/strides used by target_buffer.access_ptr(2) so ptx_cluster_store writes the correct physical location. Concretely: in the block that sets linearized_index (and uses multiplier), replace op->buffer->shape and op->indices as sources with the corresponding physical shape/stride information from the target buffer (the remapped target_dst_ / target_buffer used for access_ptr), or compute the linearized index by folding op->indices through the same stride vector/mapping that target_buffer.access_ptr(2) uses; keep the same loop structure but multiply by physical strides instead of logical shape multipliers so ptx_cluster_store addresses correctly.
1863-1865: ⚠️ Potential issue | 🟠 Major: `cluster_mask` is only honored in the descriptor-based bulk-load path.
This flag is consumed only inside `LowerBulkCopy()`. If a copy falls back to `kBulkLoad1D` or any non-bulk path, the annotation is silently ignored and masked CTAs revert to per-CTA loads.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/copy.cc` around lines 1863 - 1865, The cluster_mask from GetClusterMask() is only used inside LowerBulkCopy(), so when the copy falls back to kBulkLoad1D or any non-bulk path the multicast annotation is ignored; change the path-selection and lowering logic so the cluster_mask decision is made before choosing a lowering path and is propagated into all load implementations: compute use_multicast = (is_load && GetClusterMask() > 0) at the top-level (where use_multicast is currently set), and pass the cluster_mask/use_multicast flag into LowerBulkCopy() and the fallback loaders (the kBulkLoad1D handler and any non-bulk lowering functions) so masked-CTA behavior is honored across all paths. Ensure all loaders consult the passed cluster_mask/use_multicast instead of assuming no multicast.
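Abstractly, the fix is to decide multicast once, before a lowering path is chosen, and to thread the mask through every loader so no fallback silently drops it. A toy dispatcher showing that shape (all names are illustrative, not the actual copy.cc API):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative model of the propagation fix: the dispatcher computes the
// effective mask once and passes it to every loader, so the kBulkLoad1D and
// SIMT fallbacks honor it just like the descriptor-based ND bulk path.
enum class CopyPath { kBulkLoadND, kBulkLoad1D, kSimt };

struct Lowered {
  CopyPath path;
  uint16_t mask_used;  // 0 means the loader emitted plain per-CTA loads
};

Lowered LowerWithPath(CopyPath path, uint16_t cluster_mask) {
  // Each loader consults the mask it was given; none assumes "no multicast".
  return Lowered{path, cluster_mask};
}

Lowered LowerCopy(bool nd_descriptor_ok, bool can_bulk_1d, bool is_load,
                  uint16_t cluster_mask) {
  const uint16_t mask = is_load ? cluster_mask : uint16_t{0};
  if (nd_descriptor_ok) return LowerWithPath(CopyPath::kBulkLoadND, mask);
  if (can_bulk_1d) return LowerWithPath(CopyPath::kBulkLoad1D, mask);
  return LowerWithPath(CopyPath::kSimt, mask);
}
```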
🧹 Nitpick comments (1)
src/tl_templates/cuda/cluster.h (1)
14-15: Consider adding the `[[noreturn]]` attribute for compiler optimization.
The trap function correctly addresses the previous review concern about assert-only fallbacks. Since `trap` terminates execution and never returns, marking the function with `[[noreturn]]` (C++11) or `__attribute__((noreturn))` would help the compiler eliminate dead code paths and make the intent explicit.
♻️ Optional improvement

```diff
-TL_DEVICE void cluster_unsupported_trap() { asm volatile("trap;"); }
+TL_DEVICE [[noreturn]] void cluster_unsupported_trap() { asm volatile("trap;"); }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/cluster.h` around lines 14 - 15, The function cluster_unsupported_trap should be marked as non-returning to aid optimizations and make intent explicit: update the declaration/definition of cluster_unsupported_trap to include a noreturn annotation (e.g., C++11 [[noreturn]] or compiler-specific __attribute__((noreturn))) on the TL_DEVICE void cluster_unsupported_trap() function and keep the asm volatile("trap;"); body unchanged so callers and the compiler know the function never returns.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/transform/lower_hopper_intrin.cc`:
- Around line 145-152: The allocation size for create_barriers() must cover
every hoisted mbarrier referenced by both create_list_of_mbarrier() and direct
ptx_init_barrier_thread_count() calls; track the maximum required slots in a
shared counter (e.g., num_required_barriers_ or extend num_managed_barriers_)
whenever create_list_of_mbarrier() assigns a new slice or when
ptx_init_barrier_thread_count() references get_mbarrier(id), and use that max
value (including kReservedBarriers) when emitting the create_barriers() call so
the allocation always covers the highest referenced barrier index.
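The bookkeeping this comment describes can be isolated into a small tracker: every slot assignment or direct reference raises a running maximum, and the final create_barriers() emission reads it. A hypothetical standalone sketch (names do not match the actual pass):

```cpp
#include <cassert>
#include <cstddef>

// Sketch of the allocation-size bookkeeping: every site that hands out or
// references an mbarrier slot reports it here, and the create_barriers()
// call is emitted with the running maximum, so the allocation always covers
// the highest referenced barrier index.
class BarrierSlotTracker {
 public:
  explicit BarrierSlotTracker(size_t reserved) : required_(reserved) {}

  // Called when create_list_of_mbarrier() assigns a contiguous slice.
  void NoteSlice(size_t begin, size_t count) { Grow(begin + count); }

  // Called when ptx_init_barrier_thread_count() references get_mbarrier(id).
  void NoteDirectReference(size_t id) { Grow(id + 1); }

  // Slot count to emit in create_barriers().
  size_t RequiredSlots() const { return required_; }

 private:
  void Grow(size_t n) {
    if (n > required_) required_ = n;
  }
  size_t required_;
};
```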
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9ce26ceb-5e92-4f03-84c8-22fd00d08fd1
📒 Files selected for processing (6)
- docs/programming_guides/cluster_tma.md
- src/op/copy.cc
- src/tl_templates/cuda/cluster.h
- src/transform/inject_tma_barrier.cc
- src/transform/lower_hopper_intrin.cc
- src/transform/lower_tile_op.cc
🚧 Files skipped from review as they are similar to previous changes (1)
- docs/programming_guides/cluster_tma.md
Left some minor comments. Besides, I think we should discuss with @LeiWang1999 about the frontend API design: whether to append new params to T.copy (I'm a bit worried that there seem to be too many params now) or to create a new interface (something like T.copy_cluster)
I agree. Creating a more specific interface like
@LeiWang1999, do you have a preference?
…rive thread counts
@Rachmanino We have gone with
Please remove the unused
Good catch, and thanks for the careful review. To clarify: this PR does not use
So
@Rachmanino Let me know if there are further changes needed ~
Previously, a non-contiguous copy region (e.g. a 2-D slice where the inner dim does not span the full buffer width) with a remote_barrier fell back to element-wise SIMT stores, losing TMA throughput. This commit introduces a "multi-TMA" middle path in LowerClusterCopy: the ND copy region is recursively decomposed into individual contiguous rows, each emitting one tma_store_cluster call. The mbarrier arrive_count is updated to N (the number of rows) so the destination CTA's wait() still works correctly without any API change.

Key changes:
- src/op/copy.cc: add a static MakeTMARows() that recursively splits an ND region into contiguous rows; static extents are unrolled at compile time, symbolic extents produce TIR For loops.
- src/op/operator.h: add UpdateBarrierArriveCallback (Var → PrimExpr) to LowerArgs so LowerClusterCopy can propagate the new arrive_count.
- src/transform/lower_tile_op.cc: collect arrive_count overrides in barrier_arrive_updates_ and apply them to the barrier_init block annotation after visiting the block body, before LowerSharedBarrier consumes it.
- testing/python/cuda/test_tma_dsmem.py: replace the old SIMT-barrier test with test_store_cluster_multi_tma_barrier (2-D, M rows) and add test_store_cluster_3d_multi_tma (3-D, D×M rows) to cover the two-level recursive decomposition.
I have refactored the SM-to-SM copy to use TMA and hardware-managed mbarriers for non-full-span multi-dimensional tiles; the SIMT fallback is now reserved only for reshape copies.
I have no further concerns as long as CI passes, cc @LeiWang1999. Besides, just a friendly note: some TileLang APIs have changed in recent PRs (e.g. it is required to use
@Rachmanino Thanks. (´▽`) It seems that the CI is broken.
We've merged the recent changes. Could you please help us rerun the CI?
Sure, ask me whenever needed!
Add multicast copy (tma_load_multicast) and shared-memory-to-shared-memory cluster copy (tma_store_cluster) to the t.copy operator.
Multicast copy:
- A new `cluster_mask` parameter on `T.copy_cluster()` controls which CTAs participate.

SM-to-SM cluster copy:
- Fast path via cp.async.bulk.shared::cluster with mbarrier synchronisation (new `remote_barrier` parameter on `T.copy_cluster()`).
- Slow path via cooperative_groups::map_shared_rank for element-by-element fallback (new `dst_block` parameter on `T.copy_cluster()`).

Supporting changes:
We will upstream some use cases soon. Stay Tuned.
Co-authored-by: Guangda Sun 2012661711@qq.com