diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md new file mode 100644 index 000000000..5e91ec205 --- /dev/null +++ b/docs/programming_guides/cluster_tma.md @@ -0,0 +1,303 @@ +# Cluster TMA: Multicast and SM-to-SM Copy + +This page describes two advanced data-movement features that are available on +NVIDIA Hopper (SM90) and later: **TMA multicast** and **SM-to-SM cluster +copy**. Both features are exposed through extensions to the existing `T.copy` +operator and require a kernel launched with thread block cluster, i.e., with `cluster_dims != (1, 1, 1)`. + +Requirements: +- CUDA Compute Capability ≥ 9.0 (Hopper / Blackwell / RTX 5090) + +--- + +## Background: Thread Block Clusters + +A *thread block cluster* is a group of CTAs that share a common virtual address +space for their shared-memory regions and can communicate without going through +global memory. Within a cluster, each CTA has a *block rank* (0-indexed +position inside the cluster), and all CTAs can observe each other's shared +memory via the `shared::cluster` address space. + +```python +with T.Kernel(grid_x, grid_y, threads=128, cluster_dims=(4, 1, 1)) as (bx, by): + rank = T.block_rank_in_cluster() # 0..3 within this cluster + nctas = T.get_cluster_block_nums() # total CTAs in this cluster + T.cluster_sync() # barrier across all CTAs in cluster +``` + +--- + +## Feature 1 — TMA Multicast (`cluster_mask`) + +### What it does + +Normally each CTA issues its own TMA load, fetching a tile from global memory +into its private shared memory. With multicast, **a single TMA transaction +broadcasts one global tile to every participating CTA simultaneously**, saving +repeated DRAM traffic when multiple CTAs in a cluster need the same data (e.g., +the same K-panel in a split-K GEMM). 
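The participation rule (described under "API" below) can be modelled in a few lines of plain Python. This is illustrative only; `multicast_roles` is not part of the TileLang API, it just mirrors how the lowering classifies each CTA rank for a given `cluster_mask`:

```python
def multicast_roles(cluster_mask: int, cluster_size: int = 4) -> dict:
    """Classify each CTA rank in a cluster for a given multicast mask.

    Mirrors the rule used by the lowering: the lowest set bit elects the
    issuing CTA, other set bits receive passively, unset bits fall back
    to an ordinary per-CTA TMA load.
    """
    # Lowest set bit of the mask elects the issuer (None if mask == 0).
    issuer = (cluster_mask & -cluster_mask).bit_length() - 1 if cluster_mask else None
    roles = {}
    for rank in range(cluster_size):
        if cluster_mask & (1 << rank):
            roles[rank] = "issues multicast" if rank == issuer else "receives passively"
        else:
            roles[rank] = "regular TMA load"
    return roles

roles = multicast_roles(0b0011)
# rank 0 issues, rank 1 receives the same tile, ranks 2-3 load their own tiles
```

One TMA transaction therefore serves every set bit of the mask, which is where the DRAM-traffic saving comes from.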
```text
Global memory ──TMA multicast──▶ shared memory (rank 0)
                             └─▶ shared memory (rank 1)   (same tile, no extra DRAM read)
                  TMA load ──▶ shared memory (rank 2)   (independent tile)
                  TMA load ──▶ shared memory (rank 3)   (independent tile)
```

### API

```python
T.copy_cluster(src_global, dst_shared, cluster_mask=<mask>)
```

`cluster_mask` is a bitmask where each set bit identifies a CTA rank that
participates in the multicast. The CTA whose rank equals the lowest set bit
in the mask issues `cp.async.bulk.tensor … multicast::cluster`; every other
CTA in the mask receives the data passively (no instruction issued). CTAs
outside the mask perform a regular TMA load for their own tile.

### Example

```python
import tilelang
import tilelang.language as T

def make_tma_multicast_kernel(M, N, block_M, block_N, cluster_mask):
    @T.prim_func
    def kernel(
        A: T.Tensor((M, N), "float16"),
        B: T.Tensor((M, N), "float16"),
    ):
        # 4 CTAs per cluster; ranks 0 and 1 share the same tile via multicast.
        with T.Kernel(
            T.ceildiv(N, block_N),
            T.ceildiv(M, block_M),
            threads=128,
            cluster_dims=(4, 1, 1)
        ) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), "float16")

            # cluster_mask=0b0011: ranks 0 and 1 participate.
            # Rank 0 issues tma_load_multicast; rank 1 receives passively.
            # Ranks 2 and 3 each issue a regular tma_load.
+ T.copy_cluster(A[by * block_M, bx * block_N], A_shared, + cluster_mask=cluster_mask) + + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return kernel +``` + +Running the kernel above with `cluster_mask = 0b0011`: + +| Rank | Action | `B` slice receives | +|------|--------|--------------------| +| 0 | issues multicast load | A tile at rank-0 address | +| 1 | passively receives | **same** A tile as rank 0 | +| 2 | regular TMA load | A tile at rank-2 address | +| 3 | regular TMA load | A tile at rank-3 address | + +### Notes + +- The compiler lowers `cluster_mask != 0` to + `cp.async.bulk.tensor.Nd.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster` + for the issuing CTA; CTAs in the mask but not elected as issuer receive + passively, and only CTAs outside the mask issue a standard + `cp.async.bulk.tensor`. +- Software-pipelining (`T.Pipelined`) is fully supported; the warp-specialized + rewriter recognises `tma_load_multicast` as a producer operation. +- `cluster_mask` is a compile-time constant; dynamic masks are not supported. + +--- + +## Feature 2 — SM-to-SM Cluster Copy (`dst_block`) + +### What it does + +SM-to-SM copy lets one CTA **push data directly from its own shared memory +into another CTA's shared memory** within the same cluster, without a round +trip through global memory. 
This is useful for patterns such as:

- Partial result exchange (e.g., split-K partial sums across SM boundaries)
- Producer–consumer pipelines where the producer fills a neighbor's buffer
- All-to-all collective communication within a cluster

Two sub-variants are provided, depending on whether an mbarrier is supplied:

| Variant | Parameters | Mechanism | Threads used |
|---------|-----------|-----------|--------------|
| **Fast path** | `dst_block` + `remote_barrier` | `cp.async.bulk.shared::cluster` | 1 (async DMA) |
| **Slow path** | `dst_block` only | `map_shared_rank` + scalar stores | all (SIMT loop) |

### Fast path — bulk async copy with mbarrier

```python
T.copy_cluster(src_shared, dst_shared, dst_block=<rank>, remote_barrier=<barrier>)
```

A single elected thread issues one `cp.async.bulk.shared::cluster` instruction.
The hardware DMA engine transfers the entire tile asynchronously and signals
the destination CTA's mbarrier on completion. The destination CTA waits with
`T.mbarrier_wait_parity`.

Steps:
1. Both CTAs allocate the **same** shared memory layout so their mbarriers live
   at the same offset.
2. Every CTA initialises its own barrier for 1 arrival.
3. The source CTA (`pid == 0` below) calls `T.copy_cluster(..., dst_block=1, remote_barrier=...)`.
4. The destination CTA (`pid == 1`) waits on its local barrier copy.

```python
import tilelang
import tilelang.language as T

@tilelang.jit(verbose=True, execution_backend="cython")
def make_cluster_copy_kernel(N: int):
    @T.prim_func
    def kernel(
        A: T.Tensor((N,), "float32"),
        B: T.Tensor((N,), "float32"),
    ):
        with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid:
            s_src = T.alloc_shared((N,), "float32")
            s_dst = T.alloc_shared((N,), "float32")
            s_barrier = T.alloc_cluster_barrier([1])

            T.fill(s_src, 0.0)
            T.fill(s_dst, 0.0)

            T.cluster_sync()

            if pid == 0:
                # Load A into local shared memory.
+ for i in T.Parallel(N): + s_src[i] = A[i] + + # Async-push s_src → s_dst in CTA 1, signal CTA 1's barrier. + T.copy_cluster(s_src, s_dst, dst_block=1, + remote_barrier=s_barrier[0]) + + if pid == 1: + # Wait until CTA 0 finishes writing. + T.mbarrier_wait_parity(s_barrier[0], 0) + + for i in T.Parallel(N): + B[i] = s_dst[i] + + return kernel +``` + +Generated producer code (single-thread guard, one PTX instruction): + +```cuda +if (((int)threadIdx.x) == 0) { + tl::tma_store_cluster(&s_dst[0], &s_src[0], 1, + (uint32_t)(N * 4), s_barrier[0]); +} +``` + +### Slow path — element-by-element SIMT fallback + +Omit `remote_barrier` to use the slow path: + +```python +T.copy_cluster(s_src, s_dst, dst_block=1) +``` + +This lowers to a SIMT parallel loop where every thread writes one (or a few) +elements into the remote CTA's shared memory via +`cooperative_groups::map_shared_rank`. Because `map_shared_rank` returns a +scalar pointer, vectorised writes are not possible. Use this path only when an +mbarrier is unavailable or when the tile is too small to justify barrier +overhead. + +### Synchronisation contract + +| | Fast path | Slow path | +|-|-----------|-----------| +| Source CTA | no wait needed; copy is async | effectively sync after the loop | +| Destination CTA | `T.mbarrier_wait_parity(barrier, parity)` | external `T.cluster_sync()` or equivalent | + +### Notes + +- Both paths require `src` and `dst` to be in `shared` or `shared.dyn` scope. +- The mbarrier must be allocated with `T.alloc_cluster_barrier([arrive_count])`. +- `T.cluster_sync()` after allocation but before the copy is required to ensure + all CTAs have reached the barrier-init barrier before any data is pushed. +- `dst_block` may be a compile-time integer or a runtime `tir.PrimExpr`. 
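To make the synchronisation contract concrete, here is a small plain-Python model of the arrive/wait-parity semantics the fast path relies on. This is a sketch of the counting behaviour only, not TileLang or CUDA code: the barrier completes a phase after `arrive_count` arrivals, and a waiter holding parity `p` unblocks once the phase has flipped past `p`.

```python
class MBarrierModel:
    """Plain-Python model of an mbarrier's arrive/wait-parity contract.

    Illustrative only: the real mbarrier is a hardware object in shared
    memory. This mirrors only the counting/phase semantics.
    """

    def __init__(self, arrive_count: int):
        self.arrive_count = arrive_count
        self.pending = arrive_count   # arrivals left in the current phase
        self.phase = 0                # flips 0 -> 1 -> 0 ... on each completion

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1                   # phase completes
            self.pending = self.arrive_count  # re-arm for the next phase

    def wait_parity(self, parity: int) -> bool:
        # A waiter holding `parity` is blocked while the barrier is still in
        # that phase; it proceeds once the phase has flipped past it.
        return self.phase != parity

bar = MBarrierModel(arrive_count=1)
assert not bar.wait_parity(0)   # destination CTA is still blocked
bar.arrive()                    # bulk-copy completion signals once
assert bar.wait_parity(0)       # wait with parity 0 now returns
```

With `arrive_count > 1` (as in the multi-row decomposition), the waiter only unblocks after every producer transfer has signalled, which is why the arrive count must match the number of issued copies.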
---

## Cluster Helper Builtins

| Builtin | Return | Description |
|---------|--------|-------------|
| `T.block_rank_in_cluster()` | `int32` | Block rank (0-indexed) within the cluster |
| `T.get_cluster_block_nums()` | `int32` | Total number of CTAs in the cluster |
| `T.cluster_sync()` | — | Barrier synchronisation across all cluster CTAs |
| `T.alloc_cluster_barrier([count])` | `Buffer` | Allocate and initialise an mbarrier for `count` arrivals |
| `T.mbarrier_arrive(bar)` | — | Signal one arrival on an mbarrier |
| `T.mbarrier_wait_parity(bar, parity)` | — | Wait until `bar` flips past phase `parity` |

---

## Putting It Together: Split-K Sketch

A common pattern combining both features: multicast a shared K-panel to the
CTAs that consume it (saving DRAM bandwidth), then reduce partial sums with
SM-to-SM copy (saving global-memory round trips). Multicast is only valid when
every rank in the mask requests the *same* global tile, so the grid's x
dimension is folded: each 4-CTA cluster covers two N-tiles × two K-halves, and
the two ranks that share a K-half receive the A panel via a single multicast.

```python
@T.prim_func
def split_k_gemm(A, B, C):
    # grid_x = 2 * ceildiv(N, BN): each N-tile pair spans four consecutive bx.
    with T.Kernel(grid_x, grid_y, threads=256, cluster_dims=(4, 1, 1)) as (bx, by):
        rank = T.block_rank_in_cluster()
        cx = rank % 2        # which N-tile of the pair this CTA owns
        ks = rank // 2       # which K-half this CTA reduces
        tile_x = (bx // 4) * 2 + cx
        A_s = T.alloc_shared((BM, BK), "float16")
        B_s = T.alloc_shared((BK, BN), "float16")
        C_f = T.alloc_fragment((BM, BN), "float32")
        C_s = T.alloc_shared((BM, BN), "float32")
        C_peer = T.alloc_shared((BM, BN), "float32")
        barrier = T.alloc_cluster_barrier([1])   # one producer per destination
        T.clear(C_f)

        # Phase 1: ranks 0 and 1 (ks == 0) request the same A panel, so rank 0
        # issues one multicast that also fills rank 1's A_s (mask 0b0011).
        # Ranks 2 and 3 (ks == 1) are outside the mask and regular-load the
        # other K-half's panel. B differs per rank and is loaded normally.
        for ko in T.Pipelined(T.ceildiv(K, BK * 2), num_stages=3):
            k_off = (ks + ko * 2) * BK
            T.copy_cluster(A[by * BM, k_off], A_s, cluster_mask=0b0011)
            T.copy(B[k_off, tile_x * BN], B_s)
            T.gemm(A_s, B_s, C_f)

        # Phase 2: each ks == 1 rank pushes its partial sum to the ks == 0
        # rank that owns the same N-tile (rank 2 -> 0, rank 3 -> 1). The push
        # lands in a dedicated staging buffer (C_peer), so it never races with
        # the destination's own partial in C_s, and each destination barrier
        # sees exactly one arrival.
        T.copy(C_f, C_s)
        T.cluster_sync()

        if ks == 1:
            T.copy_cluster(C_s, C_peer, dst_block=cx, remote_barrier=barrier[0])

        if ks == 0:
            T.mbarrier_wait_parity(barrier[0], 0)   # peer partial has landed
            for i, j in T.Parallel(BM, BN):
                C_s[i, j] += C_peer[i, j]
            T.copy(C_s, C[by * BM, tile_x * BN])
```

---

## See Also

- `testing/python/cuda/test_tma_multicast_demo.py` — multicast validation
- `testing/python/cuda/test_tma_dsmem.py` — SM-to-SM copy validation
- Programming Guides → Instructions — complete `T.copy` parameter reference
- Programming Guides → Control Flow — `T.Pipelined` and warp-specialized pipelines
diff --git a/src/op/builtin.cc b/src/op/builtin.cc
index 4e467135e..18575f692 100644
--- a/src/op/builtin.cc
+++ b/src/op/builtin.cc
@@ -156,6 +156,11 @@ TIR_DEFINE_TL_BUILTIN(tma_load_im2col)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

+TIR_DEFINE_TL_BUILTIN(tma_load_multicast)
+    .set_num_inputs(-1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kOpaque));

@@ -278,6 +283,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

+TIR_DEFINE_TL_BUILTIN(get_cluster_block_nums)
+    .set_num_inputs(0)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kPure));
+
 TIR_DEFINE_TL_BUILTIN(get_lane_idx)
     .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
@@ -531,5 +541,15 @@ TIR_DEFINE_TL_BUILTIN(stg128).set_num_inputs(-1).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kOpaque));

 TIR_DEFINE_TL_BUILTIN(stg256).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_cluster_store) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tma_store_cluster) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + } // namespace tl } // namespace tvm diff --git a/src/op/builtin.h b/src/op/builtin.h index 9e4b0c4eb..552018c18 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -263,6 +263,18 @@ TVM_DLL const Op &tma_load(); */ TVM_DLL const Op &tma_load_im2col(); +/*! + * \brief tvm intrinsics for multicasting data from global tensor descriptor to + * shared memory of multiple CTAs in a cluster simultaneously + * + * tma_load_multicast(descriptor, mbarrier, smem_data, multicast_mask, + * coord_0, coord_1, ..., eviction_policy) + * + * Only the CTA with the minimum rank in the multicast_mask initiates the + * transfer; other CTAs in the mask receive data via the multicast mechanism. + */ +TVM_DLL const Op &tma_load_multicast(); + /*! * \brief tvm intrinsics for storing data from shared memory to global tensor * descriptor @@ -476,6 +488,14 @@ TVM_DLL const Op &warpgroup_wait(); */ TVM_DLL const Op &warpgroup_fence_operand(); +/*! + * \brief Return the number of blocks in the cluster. + * + * get_cluster_block_nums() + * + */ +TVM_DLL const Op &get_cluster_block_nums(); + /*! * \brief Return the canonical lane index for the calling thread. * @@ -926,6 +946,29 @@ TVM_DLL const Op &stg128(); */ TVM_DLL const Op &stg256(); +/*! + * \brief tilelang intrinsic for cluster store. + * + * This op is used to represent a cluster store operation in tilelang. + */ +TVM_DLL const Op &ptx_cluster_store(); + +/*! + * \brief tilelang intrinsic for bulk SM-to-SM async cluster store. + * + * Uses cp.async.bulk.shared::cluster to bulk-copy a contiguous region of + * shared memory to another CTA in the same cluster, signalling the + * destination CTA's mbarrier on completion. 
+ * + * Args: [dst_ptr, src_ptr, dst_cta, size_bytes, bar_ref] + * - dst_ptr (handle): address_of(dst_buf[dst_offset]) – destination pointer + * - src_ptr (handle): address_of(src_buf[src_offset]) – source pointer + * - dst_cta (int32): destination CTA rank in the cluster + * - size_bytes(uint32): number of bytes to transfer + * - bar_ref (uint64): mbarrier element (passed by reference to the callee) + */ +TVM_DLL const Op &tma_store_cluster(); + } // namespace tl } // namespace tvm diff --git a/src/op/copy.cc b/src/op/copy.cc index 6f115a0bf..43533b663 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include namespace tvm { @@ -178,7 +179,40 @@ Copy::Copy(Array args, Map annotations) { std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); // Copy annotations from the Call node + // then override with positional + // args when provided for backward compatibility. 
node->annotations = annotations; + if (auto dst_block = node->annotations.Get("dst_block")) { + if (auto int_imm = dst_block->as()) { + if (int_imm->value != -1) { + node->dst_block = Integer(int_imm->value); + } + } else { + node->dst_block = Downcast(dst_block.value()); + } + } + if (args.size() >= 3) { + auto coalesced_width = Downcast(args[2]); + if (coalesced_width->value > 0) { + node->annotations.Set(attr::kCoalescedWidth, coalesced_width); + } + } + if (args.size() >= 4) { + node->annotations.Set("disable_tma", Downcast(args[3])); + } + if (args.size() >= 5) { + node->annotations.Set("eviction_policy", args[4]); + } + if (args.size() >= 6) { + auto dst_block = args[5]; + if (auto int_imm = dst_block.as()) { + if (int_imm->value != -1) { + node->dst_block = dst_block; + } + } else { + node->dst_block = dst_block; + } + } data_ = std::move(node); } @@ -918,6 +952,27 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto copy_inst = GetCopyInst(target, disable_tma_lower || GetDisableTMA(), T.layout_map, analyzer, /*buffer_oob=*/false, /*in_pipeline=*/T.in_pipeline); + // cluster_mask is only honored in the descriptor-based TMA path (kBulkLoad). + // Any other instruction type silently drops the multicast semantics and + // causes masked CTAs to fall back to per-CTA loads, potentially changing + // results when different ranks require different source coordinates. + { + int64_t cluster_mask = GetClusterMask(); + ICHECK(cluster_mask == 0 || copy_inst == CopyInst::kBulkLoad) + << "cluster_mask=0x" << std::hex << cluster_mask + << " requires descriptor-based TMA (kBulkLoad), but this copy was " + "routed to copy_inst=" + << static_cast(copy_inst) + << ". Ensure the copy meets TMA bulk-load constraints. 
src=" + << src->name << " (scope=" << src.scope() << "), dst=" << dst->name + << " (scope=" << dst.scope() << ")."; + } + if (dst_block.defined()) { + ICHECK(TargetHasBulkCopy(target)) + << "T.copy with dst_block requires cluster-copy support (CUDA SM90+). " + << "Got target=" << target; + return LowerClusterCopy(T, analyzer); + } if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { auto tmem_copy = LowerTmemCopy(T, analyzer); ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy"; @@ -944,6 +999,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return LowerNormalCopy(T, analyzer); } else { LOG(FATAL) << "Unsupported copy inst " << static_cast(copy_inst); + return Stmt(); } } @@ -1443,6 +1499,427 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, return body; } +// --------------------------------------------------------------------------- +// Multi-TMA decomposition helper +// --------------------------------------------------------------------------- +// Recursively splits an N-D copy region into individual contiguous "rows" and +// emits one tma_store_cluster call per row. +// +// Returns (stmts, arrive_count_expr) where: +// • stmts – flat list of TIR statements to emit (may contain For +// loops when an extent is symbolic/dynamic) +// • arrive_count – PrimExpr for the total number of tma_store_cluster +// calls; used as the mbarrier arrive_count so the +// destination CTA's wait completes only after all rows +// have been transferred. +// +// Algorithm: +// Base case – both regions are already contiguous → emit one call, count=1. +// Static dim – extent is a compile-time IntImm → unroll and recurse. +// Dynamic dim – extent is symbolic → generate a TIR For loop; the loop +// body is the recursion result; count = extent * body_count. +// +// Preconditions: +// • src_ranges and dst_ranges have the same rank and matching per-dim +// extents (checked by the caller). 
+static std::pair, PrimExpr> +MakeTMARows(const Buffer &src, const Array &src_ranges, + const Buffer &dst, const Array &dst_ranges, + PrimExpr dst_block, PrimExpr barrier_load, + arith::Analyzer *analyzer) { + int n = static_cast(src_ranges.size()); + + // Check whether a buffer/range pair is contiguous in row-major storage. + auto is_contig = [&](const Buffer &buf, const Array &ranges) -> bool { + int pivot = -1; + for (int i = 0; i < n; ++i) { + if (!analyzer->CanProveEqual(ranges[i]->extent, 1)) { + pivot = i; + break; + } + } + if (pivot == -1) + return true; + for (int i = pivot + 1; i < n; ++i) { + if (!analyzer->CanProveEqual(ranges[i]->min, 0) || + !analyzer->CanProveEqual(ranges[i]->extent, buf->shape[i])) + return false; + } + return true; + }; + + // Linear element offset for the starting element of a region. + auto linear_off = [](const Buffer &buf, + const Array &ranges) -> PrimExpr { + int r = static_cast(ranges.size()); + PrimExpr off = 0, stride = 1; + for (int i = r - 1; i >= 0; --i) { + off = off + ranges[i]->min * stride; + if (i > 0) + stride = stride * buf->shape[i]; + } + return off; + }; + + // Base case: both regions are contiguous → one tma_store_cluster. + if (is_contig(src, src_ranges) && is_contig(dst, dst_ranges)) { + PrimExpr total_elems = 1; + for (const auto &r : src_ranges) + total_elems = total_elems * r->extent; + PrimExpr size_bytes = + cast(DataType::UInt(32), total_elems * src->dtype.bytes()); + PrimExpr src_ptr = src.access_ptr(1, DataType::Handle(), 1, + linear_off(src, src_ranges), total_elems); + PrimExpr dst_ptr = dst.access_ptr(2, DataType::Handle(), 1, + linear_off(dst, dst_ranges), total_elems); + Stmt call = + Evaluate(Call(DataType::Handle(), tma_store_cluster(), + {dst_ptr, src_ptr, dst_block, size_bytes, barrier_load})); + return {{call}, IntImm(DataType::Int(32), 1)}; + } + + // Recursive case: find the outermost non-trivial dimension. 
+ int split_dim = -1; + for (int d = 0; d < n; ++d) { + if (!analyzer->CanProveEqual(src_ranges[d]->extent, 1)) { + split_dim = d; + break; + } + } + ICHECK(split_dim >= 0) + << "MakeTMARows: all dimensions are trivial yet region is not " + "contiguous – this should not happen"; + + PrimExpr extent = src_ranges[split_dim]->extent; + const auto *ext_imm = extent.as(); + + if (ext_imm) { + // ── Static extent: unroll at compile time ────────────────────────────── + Array all_stmts; + PrimExpr total = IntImm(DataType::Int(32), 0); + for (int64_t k = 0; k < ext_imm->value; ++k) { + Array new_src = src_ranges; + Array new_dst = dst_ranges; + PrimExpr kexpr = IntImm(DataType::Int(32), k); + new_src.Set(split_dim, + Range::FromMinExtent(src_ranges[split_dim]->min + kexpr, 1)); + new_dst.Set(split_dim, + Range::FromMinExtent(dst_ranges[split_dim]->min + kexpr, 1)); + auto [stmts, cnt] = MakeTMARows(src, new_src, dst, new_dst, dst_block, + barrier_load, analyzer); + for (const auto &s : stmts) + all_stmts.push_back(s); + total = total + cnt; + } + return {all_stmts, total}; + } else { + // ── Dynamic extent: emit a TIR For loop ──────────────────────────────── + // Build the loop body: fix split_dim to (range_min + loop_var). + Var k("k_tma_row", DataType::Int(32)); + Array body_src = src_ranges; + Array body_dst = dst_ranges; + body_src.Set(split_dim, + Range::FromMinExtent(src_ranges[split_dim]->min + k, 1)); + body_dst.Set(split_dim, + Range::FromMinExtent(dst_ranges[split_dim]->min + k, 1)); + auto [body_stmts, body_cnt] = MakeTMARows( + src, body_src, dst, body_dst, dst_block, barrier_load, analyzer); + Stmt body = body_stmts.size() == 1 ? 
body_stmts[0] + : static_cast(SeqStmt(body_stmts)); + Stmt for_loop = + For(k, IntImm(DataType::Int(32), 0), extent, ForKind::kSerial, body); + // arrive_count = extent * body_count (every iteration contributes the same) + PrimExpr total = extent * body_cnt; + return {{for_loop}, total}; + } +} + +Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, + arith::Analyzer *analyzer) const { + // Check if the target supports cluster copy + // Currently only support shared memory to shared memory copy + ICHECK(src.scope() == "shared" || src.scope() == "shared.dyn"); + ICHECK(dst.scope() == "shared" || dst.scope() == "shared.dyn"); + + // --------------------------------------------------------------------------- + // Fast path: bulk async copy via tl::tma_store_cluster + // Used when the caller supplies a shared-memory mbarrier via the "barrier" + // annotation. A single elected thread issues the cp.async.bulk instruction; + // the destination CTA waits on its local copy of the same mbarrier. + // --------------------------------------------------------------------------- + if (auto barrier_opt = GetBarrier()) { + // Bulk cluster copy issues a single flat byte-span transfer. This is only + // correct when both src/dst regions are contiguous in row-major storage. + auto is_contiguous_region = [&](const Buffer &buf, + const Array &ranges) -> bool { + ICHECK_EQ(buf->shape.size(), ranges.size()) + << "Buffer/range rank mismatch for " << buf->name; + + int pivot = -1; + for (int i = 0; i < static_cast(ranges.size()); ++i) { + if (!analyzer->CanProveEqual(ranges[i]->extent, 1)) { + pivot = i; + break; + } + } + // Scalar region is contiguous. + if (pivot == -1) { + return true; + } + + // Outer dimensions must be fixed (extent == 1). + for (int i = 0; i < pivot; ++i) { + if (!analyzer->CanProveEqual(ranges[i]->extent, 1)) { + return false; + } + } + + // Inner dimensions must be full-span [0, shape[i]) to avoid strides. 
+ for (int i = pivot + 1; i < static_cast(ranges.size()); ++i) { + if (!analyzer->CanProveEqual(ranges[i]->min, 0) || + !analyzer->CanProveEqual(ranges[i]->extent, buf->shape[i])) { + return false; + } + } + return true; + }; + + bool src_contiguous = is_contiguous_region(src, src_range); + bool dst_contiguous = is_contiguous_region(dst, dst_range); + + PrimExpr src_elements = 1; + for (auto r : src_range) + src_elements = src_elements * r->extent; + PrimExpr dst_elements = 1; + for (auto r : dst_range) + dst_elements = dst_elements * r->extent; + bool element_match = analyzer->CanProveEqual(src_elements, dst_elements); + + if (src_contiguous && dst_contiguous && element_match) { + // ----------------------------------------------------------------------- + // Fast path: single tma_store_cluster for the whole contiguous region. + // ----------------------------------------------------------------------- + PrimExpr barrier_load = barrier_opt.value(); + + // Compute linear offsets from the copy ranges (one offset per buffer). + auto compute_linear_offset = [](const Buffer &buf, + const Array &ranges) -> PrimExpr { + PrimExpr offset = 0; + PrimExpr stride = 1; + for (int i = static_cast(ranges.size()) - 1; i >= 0; --i) { + offset = offset + ranges[i]->min * stride; + if (i > 0) + stride = stride * buf->shape[i]; + } + return offset; + }; + + PrimExpr dst_offset = compute_linear_offset(dst, dst_range); + PrimExpr src_offset = compute_linear_offset(src, src_range); + + // Total number of elements to transfer. + PrimExpr total_elements = 1; + for (auto r : src_range) + total_elements = total_elements * r->extent; + PrimExpr size_bytes = + cast(DataType::UInt(32), total_elements * src->dtype.bytes()); + + // Build tvm_access_ptr arguments. These are processed by LowerTileOp's + // HandleAccessPtrAndOffset which, for TMA ops (in_tma_context_=true), + // keeps the raw linear offset without applying any swizzle + // transformation. 
+ PrimExpr dst_ptr = + dst.access_ptr(2, DataType::Handle(), 1, dst_offset, total_elements); + PrimExpr src_ptr = + src.access_ptr(1, DataType::Handle(), 1, src_offset, total_elements); + + Stmt bulk_copy = Evaluate(Call( + DataType::Handle(), tma_store_cluster(), + {dst_ptr, src_ptr, dst_block.value(), size_bytes, barrier_load})); + + // Single-thread guard: only thread_bounds->min issues the instruction. + return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), bulk_copy); + } + + // ------------------------------------------------------------------------- + // Multi-TMA path: non-contiguous region decomposed into N contiguous rows. + // Requires matching per-dim extents between src and dst (element_match + // alone is insufficient – we need the same iteration space on both sides). + // Extents may be compile-time constants (→ unrolled calls) or symbolic + // (→ TIR For loops); in both cases MakeTMARows computes the arrive_count + // expression used to initialise the mbarrier. + // ------------------------------------------------------------------------- + bool same_shape = (src_range.size() == dst_range.size()); + for (size_t d = 0; d < src_range.size() && same_shape; ++d) { + if (!analyzer->CanProveEqual(src_range[d]->extent, + dst_range[d]->extent)) { + same_shape = false; + } + } + + if (element_match && same_shape) { + PrimExpr barrier_load = barrier_opt.value(); + const auto *barrier_buf_load = barrier_load.as(); + ICHECK(barrier_buf_load) + << "LowerClusterCopy: expected BufferLoad for barrier annotation"; + Var barrier_data_var = barrier_buf_load->buffer->data; + + auto [tma_stmts, n_rows] = + MakeTMARows(src, src_range, dst, dst_range, dst_block.value(), + barrier_load, analyzer); + + // Inform LowerTileOpPass so it can update the barrier's arrive_count in + // the barrier_init block annotation before LowerSharedBarrier runs. + if (T.UpdateBarrierArrive) { + T.UpdateBarrierArrive(barrier_data_var, n_rows); + } + + Stmt seq = (tma_stmts.size() == 1) + ? 
tma_stmts[0] + : static_cast(SeqStmt(tma_stmts)); + // Single-thread guard: only thread_bounds->min issues TMA instructions. + return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), seq); + } + + LOG(WARNING) + << "Falling back to element-wise cluster copy: bulk cluster " + "fast/multi-TMA paths require matching element counts and same " + "per-dim extents between src and dst. src=" + << src->name << ", dst=" << dst->name; + } + + // --------------------------------------------------------------------------- + // Slow path: element-by-element SIMT copy via ptx_cluster_store + // Used when no barrier is provided (backward-compatible behaviour). + // --------------------------------------------------------------------------- + + // Generate the loop nest for the copy + auto simt_loop = MakeSIMTLoop(analyzer); + auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); + + // Partition across threads but force scalar (vectorize_hint=1): + // ClusterCopyReplacer replaces BufferStore with ptx_cluster_store + // (cooperative_groups map_shared_rank + scalar write). Vectorized stores + // would produce vector-dtype values that cannot be written through + // map_shared_rank's scalar pointer return type. 
+ std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, + InferLevel::kFree}; + auto par_op = ParallelOp(fused_loop); + for (auto level : levels) { + par_op->InferLayout({T.target, + T.thread_bounds, + T.layout_map, + analyzer, + false, + T.buffer_remap, + {}}, + level); + } + auto loop_layout = par_op->GetLoopLayout(); + auto thread_loop = + PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout); + auto vectorized_thread_loop = + VectorizeLoop(thread_loop, T.layout_map, /*vectorize_hint=*/1); + + // Replace the buffer store with the cluster copy intrinsic + class ClusterCopyReplacer : public StmtExprMutator { + public: + ClusterCopyReplacer(const Buffer &dst, PrimExpr dst_block, + const Buffer &target_dst, Optional dst_layout) + : dst_(dst), dst_block_(dst_block), target_dst_(target_dst), + dst_layout_(dst_layout) {} + + Stmt VisitStmt_(const BufferStoreNode *op) final { + if (op->buffer.same_as(dst_)) { + Array args; + args.push_back(target_dst_.access_ptr(2)); // The buffer var (handle) + args.push_back(op->value); // The value to store + args.push_back(dst_block_); // The destination block index + + // Compute the physical linear index in the target buffer. + // When dst is remapped to a layout-transformed shared buffer, we must + // forward logical indices through that layout before flattening. 
+ Array physical_indices = op->indices; + if (!target_dst_.same_as(dst_) && dst_layout_.defined()) { + physical_indices = dst_layout_.value()->Forward(op->indices); + } + + PrimExpr linearized_index = physical_indices[0]; + if (physical_indices.size() > 1) { + PrimExpr multiplier = 1; + linearized_index = 0; + for (int i = physical_indices.size() - 1; i >= 0; --i) { + linearized_index = + linearized_index + physical_indices[i] * multiplier; + if (i > 0) { + multiplier = multiplier * target_dst_->shape[i]; + } + } + } + + args.push_back(linearized_index); + + Buffer target_buffer = target_dst_; + if (target_dst_.same_as(dst_)) { + target_buffer = op->buffer; + } + + // Guard against out-of-bounds remote stores when the thread count + // exceeds the copy extent (e.g. 128 threads for a 64-element region). + PrimExpr total_elems = 1; + for (const PrimExpr &s : target_buffer->shape) { + total_elems = total_elems * s; + } + + Stmt remote_store = Evaluate( + Call(DataType::Handle(), ptx_cluster_store(), + {target_buffer.access_ptr(2), args[1], args[2], args[3]})); + + return IfThenElse(args[3] < total_elems, remote_store, Stmt()); + } + return StmtExprMutator::VisitStmt_(op); + } + + private: + const Buffer &dst_; + PrimExpr dst_block_; + const Buffer &target_dst_; + Optional dst_layout_; + }; + + Buffer target_dst = dst; + if (T.buffer_remap.count(dst)) { + target_dst = T.buffer_remap[dst]; + } + + Optional dst_layout = std::nullopt; + if (T.layout_map.count(dst)) { + dst_layout = T.layout_map[dst]; + } + + Stmt simt_copy = ClusterCopyReplacer(dst, dst_block.value(), target_dst, + dst_layout)(vectorized_thread_loop); + + // When a remote_barrier is supplied but the fast path (tma_store_cluster) is + // unavailable (e.g. non-contiguous layout), the SIMT stores are not tracked + // by any hardware completion mechanism. 
Auto-generate: + // __syncthreads(); + // if (threadIdx.x == 0) { s_barrier[0].arrive(u); } + // so the destination CTA can still wait on the barrier as usual, without + // requiring the caller to insert these statements manually. + if (auto barrier_opt = GetBarrier()) { + Stmt sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm("shared")})); + Stmt arrive = + Evaluate(Call(DataType::Handle(), ptx_arrive_cluster_barrier(), + {barrier_opt.value(), dst_block.value()})); + Stmt guarded_arrive = + IfThenElse(EQ(T.thread_var, T.thread_bounds->min), arrive); + return SeqStmt({simt_copy, sync, guarded_arrive}); + } + return simt_copy; +} + // Lowers copy to a bulk TMA (Tensor Memory Accelerator) transfer. // Falls back to LowerNormalCopy if preconditions are not satisfied. Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, @@ -1687,6 +2164,10 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, Call create_descriptor = Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); + // Check for cluster multicast mask annotation (0 means no multicast) + int64_t cluster_mask = GetClusterMask(); + bool use_multicast = is_load && (cluster_mask > 0); + // For TMA loads, allocate mbarrier(s) for synchronous semantics. // Determine the mbarrier handle for TMA loads. // T.tma_copy(): requires user-provided barrier @@ -1727,6 +2208,24 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, for (auto e : desc.smem_box) total_elements *= e; + // Helper lambda to build multicast args by inserting the cluster_mask + // value between smem_ptr (args[2]) and global coordinates. 
+  // Regular tma_load args layout:   [desc, mbar, smem_ptr, coord0..., eviction]
+  // Multicast tma_load args layout: [desc, mbar, smem_ptr, mask, coord0...,
+  //                                  eviction]
+  auto build_multicast_args = [&](const Array<PrimExpr> &regular_args) {
+    Array<PrimExpr> mc_args;
+    mc_args.reserve(regular_args.size() + 1);
+    mc_args.push_back(regular_args[0]); // descriptor
+    mc_args.push_back(regular_args[1]); // mbarrier placeholder
+    mc_args.push_back(regular_args[2]); // smem_ptr
+    mc_args.push_back(
+        IntImm(DataType::Int(32), cluster_mask)); // multicast mask
+    for (size_t i = 3; i < regular_args.size(); i++)
+      mc_args.push_back(regular_args[i]); // coords + eviction_policy
+    return mc_args;
+  };
+
   if ((*inner_box_dim) != instruction_dim) {
     Var loop_var("i");
     int loop_extent = (*inner_box_dim) / instruction_dim;
@@ -1748,6 +2247,33 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
     }
     tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
                    Evaluate(Call(DataType::Handle(), op, args, ann_loop)));
+
+    if (use_multicast) {
+      // Build multicast args using the same loop_var (safe since branches are
+      // mutually exclusive in the outer IfThenElse)
+      Array<PrimExpr> mc_args = build_multicast_args(args);
+      Stmt multicast_copy = For(
+          loop_var, 0, loop_extent, ForKind::kUnrolled,
+          Evaluate(Call(DataType::Handle(), tma_load_multicast(), mc_args)));
+
+      // 3-way split based on cluster mask:
+      //   block_rank == min_rank_in_mask -> tma_load_multicast (then branch)
+      //   block_rank NOT in mask         -> regular tma_load   (elif branch)
+      //   block_rank in mask, not min    -> do nothing         (no else)
+      int min_cta_rank = static_cast<int>(
+          __builtin_ctzll(static_cast<uint64_t>(cluster_mask)));
+      PrimExpr block_rank =
+          Call(DataType::Int(32), block_rank_in_cluster(), {});
+      PrimExpr mask_imm = IntImm(DataType::Int(32), cluster_mask);
+      // (mask >> block_rank) & 1 == 0 ⟺ block_rank is NOT set in the mask
+      PrimExpr not_in_mask = EQ(bitwise_and(right_shift(mask_imm, block_rank),
+                                            IntImm(DataType::Int(32), 1)),
+
                                IntImm(DataType::Int(32), 0));
+      Stmt regular_or_noop = IfThenElse(not_in_mask, tma_copy, std::nullopt);
+      tma_copy =
+          IfThenElse(EQ(block_rank, IntImm(DataType::Int(32), min_cta_rank)),
+                     multicast_copy, regular_or_noop);
+    }
   } else {
     PrimExpr shared_addr = shared_tensor.access_ptr(
         is_load ? 2 : 1, DataType::Handle(), 1, shared_offset, total_elements);
@@ -1765,6 +2291,29 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
       ann.Set("use_2cta", IntImm(DataType::Int(32), 1));
     }
     tma_copy = Evaluate(Call(DataType::Handle(), op, args, ann));
+
+    if (use_multicast) {
+      Array<PrimExpr> mc_args = build_multicast_args(args);
+      Stmt multicast_copy =
+          Evaluate(Call(DataType::Handle(), tma_load_multicast(), mc_args));
+
+      // 3-way split based on cluster mask:
+      //   block_rank == min_rank_in_mask -> tma_load_multicast (then branch)
+      //   block_rank NOT in mask         -> regular tma_load   (elif branch)
+      //   block_rank in mask, not min    -> do nothing         (no else)
+      int min_cta_rank = static_cast<int>(
+          __builtin_ctzll(static_cast<uint64_t>(cluster_mask)));
+      PrimExpr block_rank =
+          Call(DataType::Int(32), block_rank_in_cluster(), {});
+      PrimExpr mask_imm = IntImm(DataType::Int(32), cluster_mask);
+      PrimExpr not_in_mask = EQ(bitwise_and(right_shift(mask_imm, block_rank),
+                                            IntImm(DataType::Int(32), 1)),
+                                IntImm(DataType::Int(32), 0));
+      Stmt regular_or_noop = IfThenElse(not_in_mask, tma_copy, std::nullopt);
+      tma_copy =
+          IfThenElse(EQ(block_rank, IntImm(DataType::Int(32), min_cta_rank)),
+                     multicast_copy, regular_or_noop);
+    }
   }
 
   // Bulk TMA stores participate in the cp.async.bulk group mechanism, so we
@@ -1866,6 +2415,18 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
   ICHECK(copy_inst == CopyInst::kBulkLoad1D ||
          copy_inst == CopyInst::kBulkStore1D);
 
+  // 1D TMA uses cp.async.bulk which has no multicast variant; a descriptor-
+  // based bulk load (kBulkLoad) is required instead.
+  {
+    int64_t cluster_mask = GetClusterMask();
+    ICHECK(cluster_mask == 0)
+        << "cluster_mask=0x" << std::hex << cluster_mask
+        << " requires descriptor-based TMA (kBulkLoad); the 1D bulk-copy path "
+           "(kBulkLoad1D) does not support multicast. src="
+        << src->name << " (scope=" << src.scope() << "), dst=" << dst->name
+        << " (scope=" << dst.scope() << ").";
+  }
+
   // Add 1D TMA copy when the global and shared memory is contiguous
   // Check if shared_tensor->name is present in T.buffer_var_gemm
   // (Array<Var>) to avoid use 1D TMA copy for swizzled layout
@@ -2285,7 +2846,7 @@ void CopyNode::CollectFragmentLayouts(const PrimExpr &expr,
 //   eviction_policy
 // - Marked as opaque since it has side effects (memory writes)
 TIR_REGISTER_TL_TILE_OP(Copy, copy)
-    .set_num_inputs(5)
+    .set_num_inputs(6)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
diff --git a/src/op/copy.h b/src/op/copy.h
index 2f3a63070..bbab78b06 100644
--- a/src/op/copy.h
+++ b/src/op/copy.h
@@ -118,6 +118,8 @@ class CopyNode : public TileOperatorNode {
 public:
   Buffer src, dst;                   // Source and destination buffers
   Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
+  Optional<PrimExpr> dst_block; // Destination block index for cluster copy
+
   Map<String, ObjectRef> annotations; // Annotations for the copy operation
   // Supported annotation keys:
   //   - "coalesced_width": IntImm, width for coalesced memory access
@@ -139,6 +141,7 @@
         .def_ro("dst", &CopyNode::dst)
         .def_ro("src_range", &CopyNode::src_range)
         .def_ro("dst_range", &CopyNode::dst_range)
+        .def_ro("dst_block", &CopyNode::dst_block)
         .def_ro("annotations", &CopyNode::annotations);
   }
@@ -170,6 +173,28 @@
     return 0; // default: evict_normal
   }
 
+  // Returns the cluster multicast mask (0 means no multicast / regular TMA
+  // load)
+  int64_t GetClusterMask() const {
+    if (auto val = annotations.Get("cluster_mask")) {
+      if (auto int_val = val->as<IntImmNode>()) {
+        return int_val->value;
+      }
+    }
+    return 0;
+  }
+
+  // Returns the mbarrier BufferLoad (as PrimExpr) used for tma_store_cluster,
+  // or an empty Optional when the "barrier" annotation is absent.
+  Optional<PrimExpr> GetBarrier() const {
+    if (auto val = annotations.Get("barrier")) {
+      if (val->as<BufferLoadNode>()) {
+        return Downcast<PrimExpr>(val.value());
+      }
+    }
+    return Optional<PrimExpr>();
+  }
+
   bool GetIsAsyncCopy() const {
     if (auto val = annotations.Get("is_async_copy")) {
       if (auto int_val = val->as<IntImmNode>()) {
@@ -291,6 +316,7 @@
    */
   Stmt LowerTmemCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
 
+  Stmt LowerClusterCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
   /*!
    * \brief Generate lowering for normal copy.
    */
diff --git a/src/op/operator.h b/src/op/operator.h
index 4377b1085..8deb0edf2 100644
--- a/src/op/operator.h
+++ b/src/op/operator.h
@@ -24,6 +24,11 @@ using namespace tir;
 using AddWorkspaceCallback = std::function;
 using AllocMBarrierCallback = std::function;
 
+// Called by LowerClusterCopy when the multi-TMA path decomposes a
+// non-contiguous region into N separate tma_store_cluster calls. The pass
+// records the (barrier_data_var → N) mapping and updates the barrier_init
+// block annotation before LowerSharedBarrier consumes it.
+using UpdateBarrierArriveCallback = std::function<void(Var, PrimExpr)>;
 using LayoutMap = Map<Buffer, Layout>;
 using BufferMap = Map<Var, Buffer>;
@@ -53,6 +58,7 @@ struct LowerArgs {
   Var thread_var;
   AddWorkspaceCallback AddWorkspace;
   AllocMBarrierCallback AllocMBarrier;
+  UpdateBarrierArriveCallback UpdateBarrierArrive;
   LayoutMap layout_map;
   Map<Buffer, Buffer> buffer_remap;
   // Map from LetStmt variable to its bound expression, for resolving
diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index 31a4606b3..0715caebf 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -1771,6 +1771,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     this->stream << ss.str();
     this->stream << ");\n";
   };
+  auto print_mbarrier_obj = [&](PrimExpr barrier_id) {
+    std::ostringstream ss;
+    if (barrier_id.as<IntImmNode>()) {
+      ss << mbarrier_name_ << "[" << barrier_id << "]";
+    } else {
+      ss << this->PrintExpr(barrier_id);
+    }
+    return ss.str();
+  };
   if (op->op.same_as(builtin::ptx_cp_async())) {
     // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes,
     // args[3] = predicate (optional)
@@ -1939,6 +1948,35 @@
     ss << ");\n";
     this->PrintIndent();
     this->stream << ss.str();
+  } else if (op->op.same_as(tl::tma_load_multicast())) {
+    // args layout: [descriptor, mbarrier, smem_ptr, multicast_mask,
+    //               coord_0, ..., coord_{n-1}, eviction_policy]
+    std::ostringstream ss;
+    ICHECK_GE(op->args.size(), 5)
+        << "tma_load_multicast requires at least 5 args";
+    auto eviction_policy =
+        this->eviction_policy_names_
+            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
+    if (eviction_policy != "EVICT_NORMAL") {
+      ss << "tl::tma_load_multicast<tl::CacheHintSm90::" << eviction_policy
+         << ">(";
+    } else {
+      ss << "tl::tma_load_multicast(";
+    }
+    // descriptor
+    ss << this->PrintExpr(op->args[0]) << ", ";
+    // mbarrier
+    ss << this->PrintExpr(op->args[1]) << ", ";
+    ss << this->PrintExpr(op->args[2]) << ", ";
+    // multicast_mask cast to uint16_t
+    ss << "(uint16_t)(" << 
this->PrintExpr(op->args[3]) << ")"; + // global coordinates (args[4..N-2]) + for (size_t i = 4; i < op->args.size() - 1; i++) { + ss << ", " << this->PrintExpr(op->args[i]); + } + ss << ");\n"; + this->PrintIndent(); + this->stream << ss.str(); } else if (op->op.same_as(tl::tma_load_im2col())) { std::stringstream ss; auto eviction_policy = @@ -2204,6 +2242,72 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { replacer.register_rule("(C_ptr)", c_ref); replacer.register_rule("(C_offset)", c_bias); this->stream << replacer.rewrite(mma_call); + } else if (op->op.same_as(tl::tma_store_cluster())) { + // args: [dst_ptr, src_ptr, dst_cta, size_bytes, bar_ref] + // dst_ptr – address_of(dst_buf[offset]) (void*) + // src_ptr – address_of(src_buf[offset]) (void*) + // dst_cta – destination CTA rank (int) + // size_bytes– bytes to transfer (uint32_t) + // bar_ref – mbarrier element (uint64_t&) + ICHECK_EQ(op->args.size(), 5U) << "tma_store_cluster requires 5 args"; + this->PrintIndent(); + this->stream << "tl::tma_store_cluster("; + this->stream << this->PrintExpr(op->args[0]) << ", "; + this->stream << this->PrintExpr(op->args[1]) << ", "; + this->stream << "(int)(" << this->PrintExpr(op->args[2]) << "), "; + this->stream << "(uint32_t)(" << this->PrintExpr(op->args[3]) << "), "; + this->stream << this->PrintExpr(op->args[4]) << ");\n"; + + } else if (op->op.same_as(tl::ptx_cluster_store())) { + // arg 0: buffer var (handle) + // arg 1: value to store + // arg 2: dst block index + // arg 3: linearized index + + ICHECK_EQ(op->args.size(), 4U); + + std::string buffer_var = this->PrintExpr(op->args[0]); + std::string value = this->PrintExpr(op->args[1]); + std::string dst_block = this->PrintExpr(op->args[2]); + std::string index = this->PrintExpr(op->args[3]); + + // We need to cast the buffer var to the correct type if it's not already. + // But `buffer_var` here is likely just the name of the variable. 
+      // Emit a store through a pointer mapped into the destination CTA's
+      // shared memory:
+      //   {
+      //     namespace cg = cooperative_groups;
+      //     cg::cluster_group cluster = cg::this_cluster();
+      //     auto* dst_ptr = cluster.map_shared_rank(&buffer[index], dst_block);
+      //     *dst_ptr = value;
+      //   }
+      // map_shared_rank is typed (T* map_shared_rank(T* addr, unsigned rank)),
+      // so taking &buffer[index] preserves the element type and no explicit
+      // cast of the pointer or the stored value is needed.
+
+      this->need_cooperative_groups_ = true;
+      this->PrintIndent();
+      this->stream << "{\n";
+      int cluster_scope = this->BeginScope();
+      this->PrintIndent();
+      this->stream << "namespace cg = cooperative_groups;\n";
+      this->PrintIndent();
+      this->stream << "cg::cluster_group cluster = cg::this_cluster();\n";
+      this->PrintIndent();
+      this->stream << "auto* dst_ptr = cluster.map_shared_rank(&" << buffer_var
+                   << "[" << index << "], " << dst_block << ");\n";
+      this->PrintIndent();
+      this->stream << "*dst_ptr = " << value << ";\n";
+      this->EndScope(cluster_scope);
+      this->PrintIndent();
+      this->stream << "}\n";
+
   } else if (op->op.same_as(tl::ptx_mma_sm70())) {
     // arg 0: shape: mXnXkX
     // arg 1: A layout: row/col
@@ -3090,6 +3194,11 @@
       os << PrintExpr(op->args[i]);
     }
     os << ")";
+  } else if (op->op.same_as(tl::get_cluster_block_nums())) {
+    ICHECK_EQ(op->args.size(), 0)
+        << "tl.get_cluster_block_nums expects no arguments.";
+    need_cluster_h_ = true;
+    os << "([]{auto s=tl::cluster_shape();return (int)(s.x*s.y*s.z);}())";
   } else if (op->op.same_as(tl::tl_shuffle_elect())) {
     os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
   } else if (op->op.same_as(tl::initialize_wgmma_descriptor())) {
diff --git a/src/tl_templates/cuda/cluster.h b/src/tl_templates/cuda/cluster.h
index 0da4dc904..8b1398d0a 100644
--- a/src/tl_templates/cuda/cluster.h
+++ b/src/tl_templates/cuda/cluster.h
@@ -11,6 +11,8 @@
 namespace tl {
 
+TL_DEVICE void cluster_unsupported_trap() { asm volatile("trap;"); }
+
 TL_DEVICE void cluster_arrive_relaxed() {
 #if defined(TILELANG_CLUSTER_ENABLED)
   asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : :);
diff --git a/src/tl_templates/cuda/copy_sm90.h b/src/tl_templates/cuda/copy_sm90.h
index 3d5b3f414..dd89bfa30 100644
--- a/src/tl_templates/cuda/copy_sm90.h
+++ b/src/tl_templates/cuda/copy_sm90.h
@@ -38,6 +38,167 @@ TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr,
       :);
 }
 
+// Generic SM-to-SM async bulk copy via cp.async.bulk.shared::cluster
+template <typename BarrierType>
+TL_DEVICE void tma_store_cluster(void *dst, void *src, int dst_cta,
+                                 uint32_t size_bytes, BarrierType &bar) {
+  uint32_t mbarrier_ptr = static_cast<uint32_t>(
+      __cvta_generic_to_shared(reinterpret_cast<void *>(&bar)));
+  uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(src));
+  uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(dst));
+
+  uint32_t neighbor_addr_dst;
+  asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n"
+               : "=r"(neighbor_addr_dst)
+               : "r"(dst_ptr), "r"(dst_cta));
+
+  uint32_t neighbor_addr_mbarrier;
+  asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n"
+               : "=r"(neighbor_addr_mbarrier)
+               : "r"(mbarrier_ptr), "r"(dst_cta));
+
+  // Arrive at the remote barrier and announce the expected TX byte count.
+  // This satisfies one arrival (matching the mbarrier_init count) and tells
+  // the barrier how many bytes the subsequent cp.async.bulk will transfer.
+  asm volatile("mbarrier.arrive.expect_tx.shared::cluster.b64 _, [%0], %1;\n"
+               :
+               : "r"(neighbor_addr_mbarrier), "r"(size_bytes)
+               : "memory");
+
+  asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory");
+  asm volatile("cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_"
+               "tx::bytes [%0], [%1], %2, [%3];\n"
+               :
+               : "r"(neighbor_addr_dst), "r"(src_ptr), "r"(size_bytes),
+                 "r"(neighbor_addr_mbarrier)
+               : "memory");
+}
+
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void
+tma_load_multicast(const CUtensorMap &descriptor, BarrierType &smem_mbar,
+                   void const *const smem_ptr, uint16_t multicast_mask,
+                   int32_t const &crd0) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(&smem_mbar));
+  }
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::"
+      "bytes.multicast::cluster.L2::cache_hint"
+      " [%0], [%1, {%3}], [%2], %4, %5;"
+      :
+      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0),
+        "h"(multicast_mask), "l"(cache_hint)
+      : "memory");
+}
+
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void
+tma_load_multicast(const CUtensorMap &descriptor, BarrierType &smem_mbar,
+                   void const *const smem_ptr, uint16_t multicast_mask,
+                   int32_t const &crd0, int32_t const &crd1) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(&smem_mbar));
+  }
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::"
+      "bytes.multicast::cluster.L2::cache_hint"
+      " [%0], [%1, {%3, %4}], [%2], %5, %6;"
+      :
+      : "r"(smem_int_ptr), "l"(gmem_int_desc), 
"r"(smem_int_mbar), "r"(crd0),
+        "r"(crd1), "h"(multicast_mask), "l"(cache_hint)
+      : "memory");
+}
+
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void tma_load_multicast(const CUtensorMap &descriptor,
+                                  BarrierType &smem_mbar,
+                                  void const *const smem_ptr,
+                                  uint16_t multicast_mask, int32_t const &crd0,
+                                  int32_t const &crd1, int32_t const &crd2) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(&smem_mbar));
+  }
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::"
+      "bytes.multicast::cluster.L2::cache_hint"
+      " [%0], [%1, {%3, %4, %5}], [%2], %6, %7;"
+      :
+      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0),
+        "r"(crd1), "r"(crd2), "h"(multicast_mask), "l"(cache_hint)
+      : "memory");
+}
+
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void
+tma_load_multicast(const CUtensorMap &descriptor, BarrierType &smem_mbar,
+                   void const *const smem_ptr, uint16_t multicast_mask,
+                   int32_t const &crd0, int32_t const &crd1,
+                   int32_t const &crd2, int32_t const &crd3) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(&smem_mbar));
+  }
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::"
+      "bytes.multicast::cluster.L2::cache_hint"
+      " [%0], [%1, {%3, %4, %5, %6}], [%2], %7, %8;"
+      :
+      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0),
+        "r"(crd1), "r"(crd2), "r"(crd3), "h"(multicast_mask), "l"(cache_hint)
+      : "memory");
+}
+
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void tma_load_multicast(const CUtensorMap &descriptor,
+                                  BarrierType 
&smem_mbar,
+                                  void const *const smem_ptr,
+                                  uint16_t multicast_mask, int32_t const &crd0,
+                                  int32_t const &crd1, int32_t const &crd2,
+                                  int32_t const &crd3, int32_t const &crd4) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<void *>(&smem_mbar));
+  }
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::"
+      "bytes.multicast::cluster.L2::cache_hint"
+      " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8, %9;"
+      :
+      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0),
+        "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "h"(multicast_mask),
+        "l"(cache_hint)
+      : "memory");
+}
+
 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
           typename BarrierType = uint64_t>
 TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
diff --git a/src/transform/common/thread_sync_types.h b/src/transform/common/thread_sync_types.h
index bbcf4c2b4..be9099dcc 100644
--- a/src/transform/common/thread_sync_types.h
+++ b/src/transform/common/thread_sync_types.h
@@ -28,7 +28,13 @@
 enum class ReservedNamedBarriers : uint8_t {
   kSyncThreads = 0,
   kReduce_0 = 1,
   kReduce_1 = 2,
-  kFirstUsedBarrier = kReduce_1 + 1
+  // TileLang convention for 256-thread CTA split into two 128-thread groups.
+  //   Producer: threadIdx.x in [128, 255]
+  //   Consumer: threadIdx.x in [0, 127]
+  // These must be distinct to avoid mixing barrier states and deadlocks.
+ kProducer = kReduce_1 + 1, + kConsumer = kProducer + 1, + kFirstUsedBarrier = kConsumer + 1 }; } // namespace tl diff --git a/src/transform/inject_fence_proxy.cc b/src/transform/inject_fence_proxy.cc index 14f130e78..97d1b2b3d 100644 --- a/src/transform/inject_fence_proxy.cc +++ b/src/transform/inject_fence_proxy.cc @@ -94,8 +94,8 @@ bool IsAsyncIntrinsic(const CallNode *call) { // TileLang async intrinsics if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || - call->op.same_as(tma_store()) || call->op.same_as(ptx_wgmma_ss()) || - call->op.same_as(ptx_wgmma_rs()) || + call->op.same_as(tma_load_multicast()) || call->op.same_as(tma_store()) || + call->op.same_as(ptx_wgmma_ss()) || call->op.same_as(ptx_wgmma_rs()) || call->op.same_as(ptx_tcgen05_mma_ss()) || call->op.same_as(ptx_tcgen05_mma_ts())) { return true; diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index 56a7282ce..318e6f059 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -19,6 +19,11 @@ namespace tl { using namespace tir; #if (CUDA_MAJOR_VERSION >= 12) +// Barrier IDs 0–2 are reserved: 0 by InjectTmaBarrier (descriptor-based TMA +// loads) and 1–2 by the backend for internal synchronization (e.g. AllReduce). +// User-managed barriers must start after this reserved range. 
+static constexpr int kReservedBarriers = 3;
+
 class LowerHopperIntrin : public StmtExprMutator {
 public:
   static PrimFunc Substitute(PrimFunc &f, bool disable_shuffle_elect) {
@@ -177,6 +182,30 @@
     }
   }
 
+  Stmt VisitStmt_(const IfThenElseNode *op) final {
+    Stmt new_stmt = StmtExprMutator::VisitStmt_(op);
+    if (const auto *if_node = new_stmt.as<IfThenElseNode>()) {
+      if (IsNoOp(if_node->then_case) && (!if_node->else_case.defined() ||
+                                         IsNoOp(if_node->else_case.value()))) {
+        return Evaluate(0);
+      }
+    }
+    return new_stmt;
+  }
+
+  bool IsNoOp(const Stmt &stmt) {
+    if (const auto *eval = stmt.as<EvaluateNode>()) {
+      return is_const_int(eval->value, 0);
+    } else if (const auto *seq = stmt.as<SeqStmtNode>()) {
+      for (const auto &s : seq->seq) {
+        if (!IsNoOp(s))
+          return false;
+      }
+      return true;
+    }
+    return false;
+  }
+
 private:
   Array prefetch_calls_;
   std::unordered_map desc_map_;
diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc
index 06bec7d10..987a4c1ae 100644
--- a/src/transform/lower_tile_op.cc
+++ b/src/transform/lower_tile_op.cc
@@ -345,6 +345,36 @@
     }
     workspace_stack_.pop_back();
   }
+
+  // Apply any barrier arrive-count overrides registered by LowerClusterCopy
+  // during the multi-TMA decomposition path. We update the barrier_init
+  // annotation here (before LowerSharedBarrier consumes it) so that the
+  // mbarrier is initialised with arrive_count = N (number of TMA rows).
+  if (!barrier_arrive_updates_.empty() &&
+      block->annotations.count("barrier_init")) {
+    auto barrier_init_map = Downcast<Map<Var, Array<PrimExpr>>>(
+        block->annotations.Get("barrier_init").value());
+    bool updated = false;
+    for (auto it = barrier_arrive_updates_.begin();
+         it != barrier_arrive_updates_.end();) {
+      if (barrier_init_map.count(it->first)) {
+        auto old_counts = barrier_init_map.at(it->first);
+        Array<PrimExpr> new_counts;
+        for (size_t i = 0; i < old_counts.size(); i++) {
+          new_counts.push_back(it->second);
+        }
+        barrier_init_map.Set(it->first, new_counts);
+        updated = true;
+        it = barrier_arrive_updates_.erase(it);
+      } else {
+        ++it;
+      }
+    }
+    if (updated) {
+      block_ptr->annotations.Set("barrier_init", barrier_init_map);
+    }
+  }
+
   return block;
 }
@@ -719,6 +749,7 @@
   PrimExpr VisitExpr_(const tir::CallNode *op) final {
     if (op->op.same_as(tl::tma_load()) ||
         op->op.same_as(tl::tma_load_im2col()) ||
+        op->op.same_as(tl::tma_load_multicast()) ||
         op->op.same_as(tl::tma_store())) {
       // skip tma related calls, as they were transformed implicitly.
       has_tma_ = true;
@@ -727,6 +758,12 @@
       in_tma_context_ = false;
       return call;
     }
+    if (op->op.same_as(tl::tma_store_cluster())) {
+      // SM-to-SM bulk async copy does not use a tensor-map descriptor, so
+      // shared-memory swizzle must still be reflected in pointer/index
+      // remapping.
+      return Downcast<PrimExpr>(IRMutatorWithAnalyzer::VisitExpr_(op));
+    }
 
     if (is_ptx_) {
       return Downcast(op);
@@ -1079,6 +1116,11 @@
       return id;
     };
 
+    UpdateBarrierArriveCallback barrier_arrive_callback = [this](Var data_var,
+                                                                 PrimExpr n) {
+      barrier_arrive_updates_[data_var] = n;
+    };
+
     // Compute mbarrier expressions from the enclosing loop and pipeline info.
// pipeline_num_stages: number of pipeline stages (from T.Pipelined // annotation) mbar_stage_expr: ko % num_stages (cycles through multiple @@ -1100,8 +1142,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { auto lowered = tile_op->Lower( LowerArgs{target_, thread_bounds, thread_var_->var, callback, - mbarrier_callback, layout_map_, buffer_remap_, - let_var_to_expr, + mbarrier_callback, barrier_arrive_callback, layout_map_, + buffer_remap_, let_var_to_expr, /*in_pipeline=*/pipelined_depth_ > 0, mbar_phase_expr, pipeline_num_stages, mbar_stage_expr, &mbarrier_buffer_, cluster_size_}, @@ -1405,6 +1447,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { // parameters rather than in memory indices. bool in_tma_context_{false}; int pipelined_depth_{0}; + // Pending barrier arrive-count overrides from multi-TMA cluster-copy + // decomposition. Maps barrier buffer data Var → new arrive count N. + // Populated by LowerClusterCopy via UpdateBarrierArriveCallback and + // consumed (then cleared) in VisitStmt_(BlockNode) before LowerSharedBarrier + // processes the barrier_init annotation. 
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
+      barrier_arrive_updates_;
 };
 
 namespace transform {
diff --git a/src/transform/multi_version_buffer_rewriter.cc b/src/transform/multi_version_buffer_rewriter.cc
index a3a5cab4f..23849a61e 100644
--- a/src/transform/multi_version_buffer_rewriter.cc
+++ b/src/transform/multi_version_buffer_rewriter.cc
@@ -42,7 +42,8 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
   void VisitStmt_(const EvaluateNode *op) final {
     Role role = Role::kConsumer;
     if (auto call = op->value.as<CallNode>()) {
-      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
+      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) ||
+          call->op.same_as(tma_load_multicast())) {
         role = Role::kProducer;
         has_bulk_copy_ = true;
       }
@@ -135,7 +136,9 @@
       Var buffer_var = buffer->data;
       rewriter.buffer_data_to_buffer_.Set(buffer_var, buffer);
     }
-    f.CopyOnWrite()->body = rewriter(f->body);
+    Stmt rewritten = rewriter(f->body);
+    rewritten = rewriter.FinalizeRaggedPrefixAllocation(std::move(rewritten));
+    f.CopyOnWrite()->body = std::move(rewritten);
     return f;
   }
 
@@ -143,6 +146,36 @@
   explicit MultiVersionBufferRewriter(bool barrier_only = false)
       : barrier_only_(barrier_only) {}
 
+  void EnsureRaggedPrefixBuffer() {
+    if (ragged_prefix_buf_.defined()) {
+      return;
+    }
+    Array<PrimExpr> shape = {IntImm(DataType::Int(32), 1)};
+    ragged_prefix_buf_ =
+        decl_buffer(shape, DataType::Int(32), "tl_mvb_ragged_prefix", "local");
+  }
+
+  PrimExpr LoadRaggedPrefix() {
+    EnsureRaggedPrefixBuffer();
+    Array<PrimExpr> zero_indices = {0};
+    return BufferLoad(ragged_prefix_buf_, zero_indices);
+  }
+
+  Stmt FinalizeRaggedPrefixAllocation(Stmt body) {
+    if (!needs_ragged_prefix_ || inserted_ragged_prefix_) {
+      return body;
+    }
+    EnsureRaggedPrefixBuffer();
+    Array<PrimExpr> zero_indices = {0};
+    Stmt init = BufferStore(ragged_prefix_buf_, IntImm(DataType::Int(32), 0),
+                            zero_indices);
+    Stmt 
seq = SeqStmt({init, body});
+    seq = DeclBuffer(ragged_prefix_buf_, seq);
+    inserted_ragged_prefix_ = true;
+    return Allocate(ragged_prefix_buf_->data, ragged_prefix_buf_->dtype,
+                    ragged_prefix_buf_->shape, const_true(), seq);
+  }
+
   Array<Buffer> GetVersionedBuffers(const Array<Stmt> &seq_stmt,
                                     const Array<Buffer> &scoped_buffers) {
     Array<Stmt> pipeline_stmts;
@@ -339,6 +372,37 @@
     return stmt;
   }
 
+  Stmt VisitStmt_(const AttrStmtNode *op) final {
+    // Make sure any ragged prefix allocation lives inside the device thread
+    // scope (threadIdx.x). Allocating at PrimFunc scope can be lifted into an
+    // extra kernel parameter by downstream lowering.
+    stmt_stack_.push_back(op);
+    Stmt body = this->VisitStmt(op->body);
+    stmt_stack_.pop_back();
+
+    bool is_thread_extent = (op->attr_key == tir::attr::thread_extent);
+    bool is_threadidx_x = false;
+    if (is_thread_extent) {
+      if (const auto *iv = op->node.as<IterVarNode>()) {
+        is_threadidx_x = (iv->thread_tag == "threadIdx.x");
+      }
+    }
+
+    if (needs_ragged_prefix_ && is_threadidx_x && !inserted_ragged_prefix_) {
+      inserted_ragged_prefix_ = true;
+      EnsureRaggedPrefixBuffer();
+      Array<PrimExpr> zero_indices = {0};
+      Stmt init = BufferStore(ragged_prefix_buf_, IntImm(DataType::Int(32), 0),
+                              zero_indices);
+      Stmt seq = SeqStmt({init, body});
+      seq = DeclBuffer(ragged_prefix_buf_, seq);
+      body = Allocate(ragged_prefix_buf_->data, ragged_prefix_buf_->dtype,
+                      ragged_prefix_buf_->shape, const_true(), seq);
+    }
+
+    return AttrStmt(op->node, op->attr_key, op->value, body);
+  }
+
   Stmt VisitStmt_(const ForNode *op) final {
     stmt_stack_.push_back(op);
     loop_stack_.emplace_back(op->loop_var, op->extent);
@@ -353,6 +417,15 @@
     ICHECK(num_stages_anno->as<IntImmNode>());
     int num_stages =
         static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
 
+    // A single-stage software pipeline does not require multi-versioned
+    // buffers or ragged-prefix bookkeeping; keep the loop unchanged.
+    if (num_stages <= 1) {
+      auto for_node = StmtExprMutator::VisitStmt_(op);
+      loop_stack_.pop_back();
+      stmt_stack_.pop_back();
+      return for_node;
+    }
+
     Stmt pipeline_body_root{nullptr};
     if (const auto *realize = op->body.as<BlockRealizeNode>()) {
       const auto &block = realize->block;
@@ -474,6 +547,28 @@
         buffer_data_to_buffer_.Set(buffer_var, buffer);
       }
     }
+    PrimExpr version_index_before = version_index_;
+
+    // For ragged (dynamic extent) pipelined loops, rectangular linearization
+    // using loop extents is invalid. Use a runtime prefix counter so
+    // ping-pong buffer selection stays consistent across outer iterations.
+    bool is_dynamic_extent = !op->extent.as<IntImmNode>();
+    if (is_dynamic_extent) {
+      needs_ragged_prefix_ = true;
+      version_index_ =
+          FloorMod(LoadRaggedPrefix() + (op->loop_var - op->min), num_stages);
+      Stmt for_node = StmtExprMutator::VisitStmt_(op);
+      version_index_ = version_index_before;
+
+      Array<PrimExpr> zero_indices = {0};
+      Stmt update = BufferStore(ragged_prefix_buf_,
+                                LoadRaggedPrefix() + op->extent, zero_indices);
+
+      loop_stack_.pop_back();
+      stmt_stack_.pop_back();
+      return SeqStmt({for_node, update});
+    }
+
     PrimExpr linear_index = loop_stack_[0].first;
     for (size_t i = 1; i < loop_stack_.size(); ++i) {
       linear_index =
@@ -487,6 +582,7 @@
     pipeline_loop_var_ = op->loop_var;
     pipeline_loop_min_ = op->min;
     auto for_node = StmtExprMutator::VisitStmt_(op);
+    version_index_ = version_index_before;
     parity_cycle_ = PrimExpr(); // reset
     pipeline_loop_var_ = Var();
     pipeline_loop_min_ = PrimExpr();
@@ -612,6 +708,9 @@
   bool barrier_only_;
   PrimExpr version_index_;
+  Buffer ragged_prefix_buf_;
+  bool needs_ragged_prefix_ = false;
+  bool inserted_ragged_prefix_ = false;
   PrimExpr parity_cycle_; // (k / num_stages) % 2 for mbarrier parity rewriting
   Var pipeline_loop_var_; // loop variable of the pipelined loop
PrimExpr pipeline_loop_min_; // min value of the pipelined loop diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index 55a8b118d..16028ad55 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -408,10 +408,18 @@ class BufferRegionCollector : public StmtExprVisitor { buffer_reads->second.end()); } if (buffer_writes != chain_builder_.mbar_to_buffer_writes_.end()) { - writes_.insert( - writes_.end(), - chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).begin(), - chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).end()); + writes_.insert(writes_.end(), buffer_writes->second.begin(), + buffer_writes->second.end()); + } + } else { + // Handle-based mbarrier (e.g. get_mbarrier(id)): cannot resolve to a + // concrete Buffer. Conservatively attach all known async buffer + // dependencies so the wait is not treated as dependency-free. + for (const auto &[_, regions] : chain_builder_.mbar_to_buffer_reads_) { + reads_.insert(reads_.end(), regions.begin(), regions.end()); + } + for (const auto &[_, regions] : chain_builder_.mbar_to_buffer_writes_) { + writes_.insert(writes_.end(), regions.begin(), regions.end()); } } } else { diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index c06be1eb1..8c7291d94 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -457,11 +457,33 @@ class ThreadPartialSyncRewriter : public IRMutatorWithAnalyzer { if (barrier_id_map_.count(key)) { return {barrier_id_map_[key], thread_count_map_[key]}; } + size_t thread_count = extent_tx * extent_ty * extent_tz; + + // Special-case: enforce distinct, fixed barrier IDs for the common + // producer/consumer split (256-thread CTA -> two 128-thread groups). + // This avoids accidentally mixing the same barrier id between the two + // halves, which can deadlock when both halves use bar.sync. 
+    if (thread_count == 128 && key.ty_min == key.ty_max &&
+        key.tz_min == key.tz_max) {
+      if (key.tx_min == 0 && key.tx_max == 127) {
+        size_t barrier_id =
+            static_cast<size_t>(ReservedNamedBarriers::kConsumer);
+        barrier_id_map_[key] = barrier_id;
+        thread_count_map_[key] = thread_count;
+        return {barrier_id, thread_count};
+      }
+      if (key.tx_min == 128 && key.tx_max == 255) {
+        size_t barrier_id =
+            static_cast<size_t>(ReservedNamedBarriers::kProducer);
+        barrier_id_map_[key] = barrier_id;
+        thread_count_map_[key] = thread_count;
+        return {barrier_id, thread_count};
+      }
+    }
     size_t barrier_id = barrier_id_map_.size() +
                         static_cast<size_t>(ReservedNamedBarriers::kFirstUsedBarrier);
-    size_t thread_count = extent_tx * extent_ty * extent_tz;
     barrier_id_map_[key] = barrier_id;
     thread_count_map_[key] = thread_count;
@@ -1078,7 +1100,8 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
     if (auto opt = op->op.as<Op>()) {
       const Op &call_op = opt.value();
       return call_op.same_as(tl::tma_load()) ||
-             call_op.same_as(tl::tma_load_im2col());
+             call_op.same_as(tl::tma_load_im2col()) ||
+             call_op.same_as(tl::tma_load_multicast());
     }
     return false;
   }();
diff --git a/src/transform/warp_specialized_rewriter.h b/src/transform/warp_specialized_rewriter.h
index e9a9bbf5c..881d72201 100644
--- a/src/transform/warp_specialized_rewriter.h
+++ b/src/transform/warp_specialized_rewriter.h
@@ -31,7 +31,7 @@ using arith::IRVisitorWithAnalyzer;
 class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
 public:
-  // return true means this aws will be disabled
+  // return true means auto warp specialization will be disabled
   static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
     WarpSpecializedDetector detector;
     detector.VisitStmt(stmt);
@@ -40,6 +40,11 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
       "specialization is manually enabled";
       return true;
     }
+    // TMA loads with mbarrier ops indicate a user-managed pipeline that is
+    // incompatible with auto warp specialization.
The cluster-copy exemption
+    // only applies when mbarrier ops exist *without* any regular TMA-load
+    // path (i.e. barriers are solely for SM-to-SM cluster copy sync), which
+    // naturally falls through here because has_tma_op_ is false in that case.
     if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
       LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
                       "and mbarrier are both present";
@@ -69,6 +74,7 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
   void VisitExpr_(const CallNode *op) final {
     if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
+        op->op.same_as(tma_load_multicast()) ||
         op->op.same_as(set_max_nreg())) {
       has_tma_op_ = true;
     }
diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py
new file mode 100644
index 000000000..d2b6a423d
--- /dev/null
+++ b/testing/python/cuda/test_tma_dsmem.py
@@ -0,0 +1,341 @@
+"""
+Regression tests for SM-to-SM cluster copy (T.copy_cluster with dst_block).
+
+Three lowering paths are covered:
+
+  Fast path (test_tma_store_cluster):
+      T.copy_cluster(src, dst, dst_block=1, remote_barrier=bar)
+      → single tl::tma_store_cluster issued by one thread; mbarrier completion
+        is tracked by the TMA hardware via mbarrier.arrive.expect_tx.
+
+  SIMT fallback, no barrier (test_store_cluster_simt_no_barrier):
+      T.copy_cluster(src, dst, dst_block=1)   # no remote_barrier
+      → element-wise cooperative_groups::map_shared_rank stores by all threads;
+        caller uses T.cluster_sync() for ordering.
+
+  Multi-TMA path, with barrier (test_store_cluster_multi_tma_barrier):
+      T.copy_cluster(src2d[0:M, 0:N_tile], dst2d[0:M, 0:N_tile],
+                     dst_block=1, remote_barrier=bar)
+      where N_tile < N_full, so the inner-dim extent fails the contiguity check.
+      → the copy is decomposed into one tl::tma_store_cluster call per
+        contiguous row, and the mbarrier arrive_count is raised to the row
+        count, so the destination CTA can wait on the same mbarrier as in the
+        fast path. test_store_cluster_3d_multi_tma exercises the same path
+        with two levels of row decomposition.
+""" + +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +import numpy as np + + +# --------------------------------------------------------------------------- +# Fast path kernel +# --------------------------------------------------------------------------- + + +def make_store_cluster_kernel(N: int): + @T.prim_func + def kernel( + A: T.Tensor((N,), "float32"), + B: T.Tensor((N,), "float32"), + ): + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: + s_src = T.alloc_shared((N,), "float32") + s_dst = T.alloc_shared((N,), "float32") + s_barrier = T.alloc_cluster_barrier([1]) + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for i in T.Parallel(N): + s_src[i] = A[i] + T.copy_cluster(s_src, s_dst, dst_block=1, remote_barrier=s_barrier[0]) + + if pid == 1: + T.mbarrier_wait_parity(s_barrier[0], 0) + for i in T.Parallel(N): + B[i] = s_dst[i] + + return kernel + + +# --------------------------------------------------------------------------- +# SIMT fallback, no barrier +# --------------------------------------------------------------------------- + + +def make_store_cluster_simt_no_barrier_kernel(N: int): + """No remote_barrier → SIMT fallback always taken; cluster_sync() orders stores.""" + + @T.prim_func + def kernel( + A: T.Tensor((N,), "float32"), + B: T.Tensor((N,), "float32"), + ): + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: + s_src = T.alloc_shared((N,), "float32") + s_dst = T.alloc_shared((N,), "float32") + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for i in T.Parallel(N): + s_src[i] = A[i] + # No remote_barrier: LowerClusterCopy always takes the SIMT path. + # All threads write into block 1's s_dst via map_shared_rank. + T.copy_cluster(s_src, s_dst, dst_block=1) + + # Full cluster barrier: ensures all map_shared_rank stores from + # block 0 are visible in block 1's address space before block 1 + # reads s_dst. 
+            T.cluster_sync()
+
+            if pid == 1:
+                for i in T.Parallel(N):
+                    B[i] = s_dst[i]
+
+    return kernel
+
+
+# ---------------------------------------------------------------------------
+# Multi-TMA path, with barrier arrive_count raised to one arrive per row
+# ---------------------------------------------------------------------------
+
+
+def make_store_cluster_simt_barrier_kernel(M: int, N_full: int, N_tile: int):
+    """2-D slice copy whose non-contiguous region takes the multi-TMA path.
+
+    s_src / s_dst are allocated with inner dimension N_full, but only the
+    first N_tile columns are copied.  Because N_tile < N_full the
+    is_contiguous_region() check fails: the inner-dim extent of the copy
+    region (N_tile) does not equal the buffer shape (N_full).
+
+    LowerClusterCopy therefore decomposes the copy into M per-row
+    tl::tma_store_cluster calls and raises the mbarrier arrive_count to M,
+    so block 1 waits on the same mbarrier as in the fast-path API.
+    """
+
+    @T.prim_func
+    def kernel(
+        A: T.Tensor((M, N_tile), "float32"),
+        B: T.Tensor((M, N_tile), "float32"),
+    ):
+        with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid:
+            # Deliberately wider buffer: N_full > N_tile so the slice
+            # [0:M, 0:N_tile] is non-contiguous in row-major storage.
+            s_src = T.alloc_shared((M, N_full), "float32")
+            s_dst = T.alloc_shared((M, N_full), "float32")
+            s_barrier = T.alloc_cluster_barrier([1])
+
+            T.fill(s_src, 0.0)
+            T.fill(s_dst, 0.0)
+            T.cluster_sync()
+
+            if pid == 0:
+                for i, j in T.Parallel(M, N_tile):
+                    s_src[i, j] = A[i, j]
+
+                # [0:M, 0:N_tile] inner-dim extent N_tile != N_full
+                # → contiguity check fails → per-row multi-TMA decomposition;
+                # the barrier arrive_count becomes M (one arrive per row TMA).
+                T.copy_cluster(
+                    s_src[0:M, 0:N_tile],
+                    s_dst[0:M, 0:N_tile],
+                    dst_block=1,
+                    remote_barrier=s_barrier[0],
+                )
+
+            if pid == 1:
+                # Block 1 waits until all M row transfers have arrived.
+                T.mbarrier_wait_parity(s_barrier[0], 0)
+                for i, j in T.Parallel(M, N_tile):
+                    B[i, j] = s_dst[i, j]
+
+    return kernel
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_tma_store_cluster():
+    """Fast path: T.copy_cluster emits tl::tma_store_cluster."""
+    N = 128
+    prim_func = make_store_cluster_kernel(N)
+    mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython")
+
+    src = mod.get_kernel_source()
+    assert "tl::tma_store_cluster" in src, (
+        "Expected tl::tma_store_cluster in generated kernel source; "
+        "T.copy_cluster(dst_block=..., remote_barrier=...) 
may have regressed " + f"to the SIMT fallback.\nKernel source:\n{src}" + ) + + A = torch.arange(N, dtype=torch.float32, device="cuda") + B = mod(A) + np.testing.assert_allclose( + B.cpu().numpy(), + A.cpu().numpy(), + rtol=0, + atol=0, + err_msg="tma_store_cluster copy produced wrong result", + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_store_cluster_simt_no_barrier(): + """SIMT fallback (no remote_barrier): map_shared_rank + cluster_sync ordering.""" + N = 128 + prim_func = make_store_cluster_simt_no_barrier_kernel(N) + mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") + + src = mod.get_kernel_source() + assert "map_shared_rank" in src, f"Expected map_shared_rank in generated source for no-barrier SIMT fallback.\nKernel source:\n{src}" + assert "tl::tma_store_cluster" not in src, f"No-barrier path must NOT emit tl::tma_store_cluster.\nKernel source:\n{src}" + + A = torch.arange(N, dtype=torch.float32, device="cuda") + B = mod(A) + np.testing.assert_allclose( + B.cpu().numpy(), + A.cpu().numpy(), + rtol=0, + atol=0, + err_msg="SIMT no-barrier cluster copy produced wrong result", + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_store_cluster_multi_tma_barrier(): + """Multi-TMA path: non-contiguous 2-D slice decomposed into N row TMA calls. + + The compiler decomposes a non-full-span 2-D slice (shape [M, N_tile] inside + a buffer of shape [M, N_full]) into M individual tma_store_cluster calls – + one per contiguous row. The mbarrier arrive_count is updated to M so the + destination CTA's wait(0) completes only after all rows are transferred. 
+ """ + M, N_full, N_tile = 4, 64, 32 # M rows, each N_tile elements + + prim_func = make_store_cluster_simt_barrier_kernel(M, N_full, N_tile) + mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") + + src = mod.get_kernel_source() + # Multi-TMA path must emit M separate tma_store_cluster calls, not SIMT stores. + assert "tl::tma_store_cluster" in src, f"Expected tl::tma_store_cluster for multi-TMA row decomposition.\nKernel source:\n{src}" + assert "map_shared_rank" not in src, f"Multi-TMA path must NOT fall back to map_shared_rank.\nKernel source:\n{src}" + # The barrier must be initialised with arrive_count == M (one per TMA call). + assert f"s_barrier[0].init({M})" in src, f"Expected barrier arrive_count={M} for {M}-row decomposition.\nKernel source:\n{src}" + # Exactly M tma_store_cluster calls should appear in the source. + assert src.count("tl::tma_store_cluster") == M, ( + f"Expected exactly {M} tma_store_cluster calls, got {src.count('tl::tma_store_cluster')}.\nKernel source:\n{src}" + ) + + A = torch.arange(M * N_tile, dtype=torch.float32, device="cuda").reshape(M, N_tile) + B = mod(A) + np.testing.assert_allclose( + B.cpu().numpy(), + A.cpu().numpy(), + rtol=0, + atol=0, + err_msg="Multi-TMA row-decomposed cluster copy produced wrong result", + ) + + +# --------------------------------------------------------------------------- +# 3-D multi-TMA: two levels of recursive decomposition +# --------------------------------------------------------------------------- + + +def make_store_cluster_3d_multi_tma_kernel(D: int, M: int, N_full: int, N_tile: int): + """3-D slice copy decomposed into D*M tma_store_cluster calls. + + Buffer shape is [D, M, N_full]; the copy region is [0:D, 0:M, 0:N_tile]. + Because N_tile < N_full the innermost dim is not full-span. + MakeTMARows recurses twice (once on dim 0, once on dim 1) producing D*M + contiguous-row TMA calls and sets barrier arrive_count = D*M. 
+ """ + + @T.prim_func + def kernel( + A: T.Tensor((D, M, N_tile), "float32"), + B: T.Tensor((D, M, N_tile), "float32"), + ): + with T.Kernel(2, threads=D * M * N_tile, cluster_dims=(2, 1, 1)) as pid: + s_src = T.alloc_shared((D, M, N_full), "float32") + s_dst = T.alloc_shared((D, M, N_full), "float32") + s_barrier = T.alloc_cluster_barrier([1]) + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for d, i, j in T.Parallel(D, M, N_tile): + s_src[d, i, j] = A[d, i, j] + + T.copy_cluster( + s_src[0:D, 0:M, 0:N_tile], + s_dst[0:D, 0:M, 0:N_tile], + dst_block=1, + remote_barrier=s_barrier[0], + ) + + if pid == 1: + T.mbarrier_wait_parity(s_barrier[0], 0) + for d, i, j in T.Parallel(D, M, N_tile): + B[d, i, j] = s_dst[d, i, j] + + return kernel + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_store_cluster_3d_multi_tma(): + """3-D multi-TMA: recursive decomposition produces D*M separate TMA calls. + + With D=2 and M=4, the two-level recursion over dims 0 and 1 yields 8 + contiguous-row tma_store_cluster calls and initialises the barrier with + arrive_count=8. 
+ """ + D, M, N_full, N_tile = 2, 4, 32, 16 # D*M*N_tile == 128 == thread count + + prim_func = make_store_cluster_3d_multi_tma_kernel(D, M, N_full, N_tile) + mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") + + src = mod.get_kernel_source() + n_expected = D * M + assert "tl::tma_store_cluster" in src, f"Expected tl::tma_store_cluster for 3-D multi-TMA.\nKernel source:\n{src}" + assert f"s_barrier[0].init({n_expected})" in src, ( + f"Expected barrier arrive_count={n_expected} for {D}x{M} decomposition.\nKernel source:\n{src}" + ) + assert src.count("tl::tma_store_cluster") == n_expected, ( + f"Expected exactly {n_expected} tma_store_cluster calls, got {src.count('tl::tma_store_cluster')}.\nKernel source:\n{src}" + ) + + A = torch.arange(D * M * N_tile, dtype=torch.float32, device="cuda").reshape(D, M, N_tile) + B = mod(A) + np.testing.assert_allclose( + B.cpu().numpy(), + A.cpu().numpy(), + rtol=0, + atol=0, + err_msg="3-D multi-TMA cluster copy produced wrong result", + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/cuda/test_tma_multicast_demo.py b/testing/python/cuda/test_tma_multicast_demo.py new file mode 100644 index 000000000..8127533ab --- /dev/null +++ b/testing/python/cuda/test_tma_multicast_demo.py @@ -0,0 +1,106 @@ +""" +TMA multicast validation demo. + +Verification logic: +- cluster_size=4, cluster_mask=0b0011 (bits 0 and 1 set, i.e. 
CTA ranks 0 and 1 are in the mask) +- CTA rank 0: issues tma_load_multicast, broadcasting its A tile to both rank 0 and rank 1 +- CTA rank 1: does not issue a load; passively receives the multicast data (same tile as rank 0) +- CTA ranks 2, 3: not in the mask, each performs a regular tma_load for its own tile + +Therefore within the same cluster: +- B at rank 0's region = A at rank 0's region +- B at rank 1's region = A at rank 0's region (multicast result, identical to rank 0) +- B at ranks 2, 3 regions = A at ranks 2, 3 respective regions + +The test verifies multicast by checking that rank 1's B region equals rank 0's A region. +""" + +import torch +import tilelang +import tilelang.language as T +import tilelang.testing + + +def make_tma_multicast_demo_kernel(M, N, block_M, block_N, cluster_mask): + """ + Build the TMA multicast demo kernel. + + cluster_mask: multicast bitmask. A set bit means the corresponding CTA + participates in multicast (receives the tile from the min-rank CTA). + e.g. 0b0011 means ranks 0 and 1 are in the mask; rank 0 issues + the multicast, rank 1 passively receives. 
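The rank roles spelled out above follow directly from the bitmask: the leader is the lowest set bit, other set bits receive passively, and unset bits fall back to a regular load. A small sketch of that classification (the helper name is illustrative, not part of the tilelang API):

```python
def multicast_role(rank: int, mask: int) -> str:
    """Classify a CTA rank under a TMA multicast bitmask (illustrative helper)."""
    if not (mask >> rank) & 1:
        return "regular_tma_load"             # outside the mask: loads its own tile
    leader = (mask & -mask).bit_length() - 1  # lowest set bit issues the load
    return "tma_load_multicast" if rank == leader else "passive_receive"

# cluster_mask=0b0011 with a 4-CTA cluster, as exercised by this test:
roles = [multicast_role(r, 0b0011) for r in range(4)]
# rank 0 issues the multicast, rank 1 receives, ranks 2/3 load independently
```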
+ """ + + @T.prim_func + def kernel( + A: T.Tensor((M, N), "float16"), + B: T.Tensor((M, N), "float16"), + ): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=128, + cluster_dims=(4, 1, 1), + ) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), "float16") + T.copy_cluster(A[by * block_M, bx * block_N], A_shared, cluster_mask=cluster_mask) + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return kernel + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_tma_multicast_demo(): + """Verify TMA multicast: rank 1's B region should equal rank 0's A region within the same cluster.""" + M, N = 1024, 1024 + block_M, block_N = 128, 64 + # mask=0b0011: rank 0 multicasts, rank 1 receives, ranks 2/3 each do regular tma_load + cluster_mask = 0b0011 + + kernel = make_tma_multicast_demo_kernel(M, N, block_M, block_N, cluster_mask) + mod = tilelang.compile( + kernel, + out_idx=[1], + verbose=True, + execution_backend="cython", + ) + + print("--- TMA Multicast Demo Kernel Source ---") + print(mod.get_kernel_source()) + + A = torch.randn(M, N, device="cuda", dtype=torch.float16) + B = mod(A) + + # Within a cluster: the first 4 blocks in the grid are (0,0),(1,0),(2,0),(3,0) -> by=0, bx=0,1,2,3 + # rank 0 -> bx=0: A[0:block_M, 0:block_N] -> B[0:block_M, 0:block_N] + # rank 1 -> bx=1: multicast receives A[0:block_M, 0:block_N] -> B[0:block_M, block_N:2*block_N] + # rank 2 -> bx=2: A[0:block_M, 2*block_N:3*block_N] -> B[0:block_M, 2*block_N:3*block_N] + # rank 3 -> bx=3: A[0:block_M, 3*block_N:4*block_N] -> B[0:block_M, 3*block_N:4*block_N] + + # Multicast check: rank 1's B region should equal rank 0's A region + B_rank1 = B[0:block_M, block_N : 2 * block_N] + A_rank0 = A[0:block_M, 0:block_N] + torch.testing.assert_close(B_rank1, A_rank0, rtol=1e-2, atol=1e-2) + print("PASS: Multicast verified (B[rank1_region] == A[rank0_region])") + + # rank 0 itself: B should equal A + 
torch.testing.assert_close(B[0:block_M, 0:block_N], A[0:block_M, 0:block_N], rtol=1e-2, atol=1e-2) + # ranks 2, 3: each B region equals its own A region + torch.testing.assert_close( + B[0:block_M, 2 * block_N : 3 * block_N], + A[0:block_M, 2 * block_N : 3 * block_N], + rtol=1e-2, + atol=1e-2, + ) + torch.testing.assert_close( + B[0:block_M, 3 * block_N : 4 * block_N], + A[0:block_M, 3 * block_N : 4 * block_N], + rtol=1e-2, + atol=1e-2, + ) + print("PASS: TMA multicast demo passed") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 6980f0faf..b419d869d 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -1,11 +1,40 @@ from __future__ import annotations from tvm import tir, IRModule from tvm.target import Target +import tvm import tilelang from tilelang.transform import PassContext from tilelang.contrib.nvcc import have_tma, is_hopper, have_pdl +def _mod_uses_tma_barriers(mod: IRModule) -> bool: + """Return True if any PrimFunc in *mod* contains a ``create_barriers`` call. + + ``LowerHopperIntrin`` emits ``create_barriers`` into every device function + that uses TMA/mbarrier and nowhere else. Checking for that call after + lowering therefore tells us whether a given kernel actually needs the + 128-byte shared-memory alignment that TMA descriptors require, without + forcing the alignment on every kernel that merely runs on a Hopper GPU. 
+ """ + found = [False] + + def _visit(node): + if found[0]: + return + if isinstance(node, tir.Call): + op = getattr(node, "op", None) + if op is not None and getattr(op, "name", "").endswith("create_barriers"): + found[0] = True + + for func in mod.functions.values(): + if found[0]: + break + if isinstance(func, tir.PrimFunc): + tvm.tir.stmt_functor.post_order_visit(func.body, _visit) + + return found[0] + + def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: # avoid circular import from tilelang.jit.adapter.utils import is_cuda_target diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 7c02cbe6f..b2392e147 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -54,7 +54,7 @@ alloc_global, # noqa: F401 ) from tvm.script.parser.tir import allocate as allocate # noqa: F401 -from .copy_op import copy, async_copy, tma_copy, c2d_im2col # noqa: F401 +from .copy_op import copy, async_copy, tma_copy, c2d_im2col, copy_cluster # noqa: F401 from tilelang.tileop.base import GemmWarpPolicy # noqa: F401 from .gemm_op import gemm, gemm_v1, gemm_v2, wgmma_gemm, tcgen05_gemm # noqa: F401 from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401 diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 607537324..b09d7bbfb 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -605,6 +605,15 @@ def get_warp_group_idx( return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args) +def cluster_block_nums() -> PrimExpr: + """Return the number of blocks in the cluster. + + Lowers to `tl.get_cluster_block_nums` and emits + `cooperative_groups::this_cluster().num_blocks()` in CUDA codegen. + """ + return tir.call_intrin("int32", tir.op.Op.get("tl.get_cluster_block_nums")) + + def shuffle_elect(thread_extent: int) -> PrimExpr: """Elect exactly one lane within a logical thread group. 
diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py index 24da86e01..0196405f5 100644 --- a/tilelang/language/copy_op.py +++ b/tilelang/language/copy_op.py @@ -120,6 +120,77 @@ def copy( return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, annotations=ann if ann else None) +def copy_cluster( + src: BufferLikeType, + dst: BufferLikeType, + *, + dst_block: int | tir.PrimExpr | None = None, + cluster_mask: int | None = None, + remote_barrier: tir.BufferLoad | None = None, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, + coalesced_width: int | None = None, + loop_layout: Any | None = None, +) -> tir.PrimExpr | tir.Stmt: + """Cluster-aware copy between shared memory regions or from global to shared memory. + + This is the entry point for two Hopper cluster features that require + ``cluster_dims`` to be set on the enclosing ``T.Kernel``: + + **TMA multicast** (global → shared, ``cluster_mask`` set): + The hardware delivers a single TMA load to the shared memory of every + masked CTA in the cluster simultaneously. At runtime the kernel splits + into three cases based on each CTA's rank: + + * **Leader** (rank == lowest set bit): issues ``tma_load_multicast``. + * **Masked peer** (other set bits): does nothing — receives passively. + * **Unmasked CTA**: issues a regular ``tma_load`` independently. + + **SM-to-SM copy** (shared → shared, ``dst_block`` set): + Copies from the local CTA's shared memory to a remote CTA's shared + memory within the same cluster. When ``remote_barrier`` is also + provided, a single bulk-async ``tl::tma_store_cluster`` instruction + is emitted; otherwise an element-by-element SIMT loop is used. + + Args: + src: Source memory region. + dst: Destination memory region. + dst_block: Destination CTA rank in the cluster for SM-to-SM copy. + cluster_mask: Bitmask of CTAs that participate in TMA multicast. 
+        remote_barrier: Shared-memory mbarrier for asynchronous SM-to-SM copy
+            completion signalling. The destination CTA should wait on its
+            local copy of this barrier.
+        eviction_policy: Cache eviction hint passed to the TMA instruction.
+            Only relevant for the TMA multicast path (``cluster_mask`` set).
+        coalesced_width: Vectorization width (in elements) for the SIMT loop
+            used on the SM-to-SM fallback path (``dst_block`` set, no fast
+            bulk-async route available).
+        loop_layout: Parallel loop layout hint (Fragment) for the SIMT loop on
+            the SM-to-SM fallback path. Incompatible with the TMA multicast
+            path (``cluster_mask`` set).
+
+    Returns:
+        tir.Call: A handle to the copy operation.
+    """
+    src, dst = _normalize_copy_regions(src, dst)
+
+    ann: dict = {}
+    if dst_block is not None:
+        ann["dst_block"] = dst_block
+    if cluster_mask is not None:
+        ann["cluster_mask"] = cluster_mask
+    if remote_barrier is not None:
+        ann["barrier"] = remote_barrier
+    if eviction_policy is not None:
+        eviction_policy_map = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}
+        ann["eviction_policy"] = eviction_policy_map[eviction_policy]
+    if coalesced_width is not None:
+        ann["coalesced_width"] = coalesced_width
+    if loop_layout is not None:
+        ann["parallel_loop_layout"] = loop_layout
+
+    return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst,
                            annotations=ann if ann else None)
+
+
 def async_copy(
     src: BufferLikeType,
     dst: BufferLikeType,
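One note on the SM-to-SM path documented above: the number of `tl::tma_store_cluster` calls (and hence the mbarrier arrive_count) that the tests in this diff assert on follows a simple row-major contiguity rule. It can be modeled independently of the compiler pass (a sketch; `count_tma_rows` is an illustrative helper, not the real decomposition code):

```python
from math import prod

def count_tma_rows(region, shape):
    """Contiguous row copies needed for a row-major sub-region `region` of a
    buffer with per-dimension extents `shape` (illustrative model)."""
    # Trailing full-span dims merge into the enclosing dim's contiguous run.
    k = len(shape) - 1
    while k > 0 and region[k] == shape[k]:
        k -= 1
    # Every index combination over the dims before the sliced dim contributes
    # one contiguous row; a fully-spanning region needs a single bulk copy.
    return prod(region[:k])

# 2-D test: [0:4, 0:32] slice of a (4, 64) buffer -> 4 row TMAs, arrive_count 4
# 3-D test: [0:2, 0:4, 0:16] slice of a (2, 4, 32) buffer -> 8 row TMAs
```

Under this model the `s_barrier[0].init(M)` and `init(D*M)` assertions in the test files correspond to `count_tma_rows` of the copied region.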