From 80e409bd613066c215bc9fc584c9097e3d773f59 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Sat, 7 Mar 2026 04:07:54 +0800 Subject: [PATCH 01/51] [Feature] Extend t.copy to support TMA multicast and SM-to-SM cluster copy Add multicast copy (tma_load_multicast) and shared-memory-to-shared-memory cluster copy (tma_store_cluster / ptx_cluster_store) to the t.copy operator. Multicast copy: - A single CTA issues tma_load_multicast to broadcast a tile to multiple CTAs in the cluster simultaneously; other masked CTAs receive passively. - New `cluster_mask` parameter on T.copy() controls which CTAs participate. SM-to-SM cluster copy: - Fast path via cp.async.bulk.shared::cluster with mbarrier synchronisation (new `remote_barrier` parameter on T.copy()). - Slow path via cooperative_groups::map_shared_rank for element-by-element fallback (new `dst_block` parameter on T.copy()). Supporting changes: - New builtins: tma_load_multicast, tma_store_cluster, ptx_cluster_store, cluster_sync, get_cluster_id, get_cluster_block_rank, get_cluster_block_nums, mbarrier_arrive (cluster-scope). - Codegen (codegen_cuda.cc) and device templates (barrier.h, copy_sm90.h) for all new intrinsics; unified mbarrier API via tl::mbarrier_* free functions. - inject_tma_barrier: handle tma_load_multicast; distinguish thread guards from cluster-rank conditions for correct expect_tx injection. - lower_hopper_intrin: migrate barrier allocation to builtin::create_barriers; hoist user ptx_init_barrier_thread_count calls alongside compiler barriers. - warp_specialized_rewriter: support tma_load_multicast as producer op; UserBarrierInitExtractor to separate user barrier inits from body; ragged (dynamic-extent) pipeline prefix counter; fix nested IfThenElse state corruption in shuffle-elect optimisation. - multi_version_buffer_rewriter: ragged pipeline support with runtime prefix counter for correct ping-pong buffer versioning. 
- thread_storage_sync: reserve fixed barrier IDs (kProducer/kConsumer) for 256-thread CTA warp-specialised split to prevent deadlocks. - pipeline_planning: defensive handling of non-BufferLoad mbarrier args. - Shared memory alignment bumped to 128 bytes for TMA. Co-authored-by: Guangda Sun <2012661711@qq.com> --- src/op/builtin.cc | 40 +++ src/op/builtin.h | 73 +++++ src/op/copy.cc | 265 +++++++++++++++++- src/op/copy.h | 26 ++ src/target/codegen_cuda.cc | 152 +++++++++- src/tl_templates/cuda/barrier.h | 30 ++ src/tl_templates/cuda/copy_sm90.h | 159 +++++++++++ src/transform/common/thread_sync_types.h | 8 +- src/transform/inject_tma_barrier.cc | 51 +++- src/transform/lower_hopper_intrin.cc | 48 +++- src/transform/lower_tile_op.cc | 11 + .../multi_version_buffer_rewriter.cc | 72 +++++ src/transform/pipeline_planning.cc | 38 ++- src/transform/thread_storage_sync.cc | 27 +- src/transform/warp_specialized_rewriter.cc | 181 ++++++++++-- src/transform/warp_specialized_rewriter.h | 14 +- testing/python/cuda/test_tma_dsmem.py | 92 ++++++ .../python/cuda/test_tma_multicast_demo.py | 103 +++++++ tilelang/engine/phase.py | 7 +- tilelang/language/builtin.py | 57 ++++ tilelang/language/copy_op.py | 26 ++ 21 files changed, 1404 insertions(+), 76 deletions(-) create mode 100644 testing/python/cuda/test_tma_dsmem.py create mode 100644 testing/python/cuda/test_tma_multicast_demo.py diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 21f826537..a2cc29d18 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -155,6 +155,11 @@ TIR_DEFINE_TL_BUILTIN(get_mbarrier) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_TL_BUILTIN(mbarrier_arrive) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(tma_load).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -163,6 +168,11 @@ TIR_DEFINE_TL_BUILTIN(tma_load_im2col) .set_attr("TCallEffectKind", 
Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(tma_load_multicast) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -280,6 +290,21 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(get_cluster_id) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(get_cluster_block_rank) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(get_cluster_block_nums) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + TIR_DEFINE_TL_BUILTIN(get_lane_idx) .set_num_inputs(-1) .set_attr("TCallEffectKind", @@ -508,5 +533,20 @@ TIR_DEFINE_TL_BUILTIN(stg128).set_num_inputs(-1).set_attr( TIR_DEFINE_TL_BUILTIN(stg256).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_cluster_store) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(tma_store_cluster) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(cluster_sync) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + } // namespace tl } // namespace tvm diff --git a/src/op/builtin.h b/src/op/builtin.h index 8672da598..2f1589a27 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -234,6 +234,14 @@ TVM_DLL const Op &create_list_of_mbarrier(); */ TVM_DLL const Op &get_mbarrier(); +/*! + * \brief Arrive at mbarrier with remote cta support + * + * mbarrier_arrive(mbarrier, cta_id, pred) + * + */ +TVM_DLL const Op &mbarrier_arrive(); + /*! 
* \brief tvm intrinsics for loading data from global tensor descriptor to * shared memory @@ -253,6 +261,18 @@ TVM_DLL const Op &tma_load(); */ TVM_DLL const Op &tma_load_im2col(); +/*! + * \brief tvm intrinsics for multicasting data from global tensor descriptor to + * shared memory of multiple CTAs in a cluster simultaneously + * + * tma_load_multicast(descriptor, mbarrier, smem_data, multicast_mask, + * coord_0, coord_1, ..., eviction_policy) + * + * Only the CTA with the minimum rank in the multicast_mask initiates the + * transfer; other CTAs in the mask receive data via the multicast mechanism. + */ +TVM_DLL const Op &tma_load_multicast(); + /*! * \brief tvm intrinsics for storing data from shared memory to global tensor * descriptor @@ -458,6 +478,30 @@ TVM_DLL const Op &warpgroup_wait(); */ TVM_DLL const Op &warpgroup_fence_operand(); +/*! + * \brief Return the cluster id (rank) of the current block within a cluster. + * + * get_cluster_id() + * + */ +TVM_DLL const Op &get_cluster_id(); + +/*! + * \brief Return the block rank within the current cluster. + * + * get_cluster_block_rank() + * + */ +TVM_DLL const Op &get_cluster_block_rank(); + +/*! + * \brief Return the number of blocks in the cluster. + * + * get_cluster_block_nums() + * + */ +TVM_DLL const Op &get_cluster_block_nums(); + /*! * \brief Return the canonical lane index for the calling thread. * @@ -868,6 +912,35 @@ TVM_DLL const Op &stg128(); */ TVM_DLL const Op &stg256(); +/*! + * \brief tilelang intrinsic for cluster store. + * + * This op is used to represent a cluster store operation in tilelang. + */ +TVM_DLL const Op &ptx_cluster_store(); + +/*! + * \brief tilelang intrinsic for bulk SM-to-SM async cluster store. + * + * Uses cp.async.bulk.shared::cluster to bulk-copy a contiguous region of + * shared memory to another CTA in the same cluster, signalling the + * destination CTA's mbarrier on completion. 
+ * + * Args: [dst_ptr, src_ptr, dst_cta, size_bytes, bar_ref] + * - dst_ptr (handle): address_of(dst_buf[dst_offset]) – destination pointer + * - src_ptr (handle): address_of(src_buf[src_offset]) – source pointer + * - dst_cta (int32): destination CTA rank in the cluster + * - size_bytes(uint32): number of bytes to transfer + * - bar_ref (uint64): mbarrier element (passed by reference to the callee) + */ +TVM_DLL const Op &tma_store_cluster(); + +/*! + * \brief tilelang intrinsic for cluster sync. + * + * This op is used to represent a cluster sync operation in tilelang. + */ +TVM_DLL const Op &cluster_sync(); } // namespace tl } // namespace tvm diff --git a/src/op/copy.cc b/src/op/copy.cc index e9aaa1547..a88944609 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include namespace tvm { @@ -43,7 +44,40 @@ Copy::Copy(Array args, Map annotations) { std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); // Copy annotations from the Call node + // then override with positional + // args when provided for backward compatibility. 
node->annotations = annotations; + if (auto dst_block = node->annotations.Get("dst_block")) { + if (auto int_imm = dst_block->as()) { + if (int_imm->value != -1) { + node->dst_block = Integer(int_imm->value); + } + } else { + node->dst_block = Downcast(dst_block.value()); + } + } + if (args.size() >= 3) { + auto coalesced_width = Downcast(args[2]); + if (coalesced_width->value > 0) { + node->annotations.Set(attr::kCoalescedWidth, coalesced_width); + } + } + if (args.size() >= 4) { + node->annotations.Set("disable_tma", Downcast(args[3])); + } + if (args.size() >= 5) { + node->annotations.Set("eviction_policy", args[4]); + } + if (args.size() >= 6) { + auto dst_block = args[5]; + if (auto int_imm = dst_block.as()) { + if (int_imm->value != -1) { + node->dst_block = dst_block; + } + } else { + node->dst_block = dst_block; + } + } data_ = std::move(node); } @@ -683,6 +717,9 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); auto copy_inst = GetCopyInst(target, disable_tma_lower || GetDisableTMA(), T.layout_map, analyzer); + if (dst_block.defined()) { + return LowerClusterCopy(T, analyzer); + } if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { auto tmem_copy = LowerTmemCopy(T, analyzer); ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy"; @@ -705,6 +742,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return LowerNormalCopy(T, analyzer); } else { LOG(FATAL) << "Unsupported copy inst " << static_cast(copy_inst); + return Stmt(); } } @@ -1130,6 +1168,159 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, return body; } +Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, + arith::Analyzer *analyzer) const { + // Check if the target supports cluster copy + // Currently only support shared memory to shared memory copy + ICHECK(src.scope() == "shared" || src.scope() == "shared.dyn"); + ICHECK(dst.scope() == 
"shared" || dst.scope() == "shared.dyn"); + + // --------------------------------------------------------------------------- + // Fast path: bulk async copy via tl::tma_store_cluster + // Used when the caller supplies a shared-memory mbarrier via the "barrier" + // annotation. A single elected thread issues the cp.async.bulk instruction; + // the destination CTA waits on its local copy of the same mbarrier. + // --------------------------------------------------------------------------- + if (auto barrier_opt = GetBarrier()) { + PrimExpr barrier_load = barrier_opt.value(); + + // Compute linear offsets from the copy ranges (one offset per buffer). + auto compute_linear_offset = [](const Buffer &buf, + const Array &ranges) -> PrimExpr { + PrimExpr offset = 0; + PrimExpr stride = 1; + for (int i = static_cast(ranges.size()) - 1; i >= 0; --i) { + offset = offset + ranges[i]->min * stride; + if (i > 0) + stride = stride * buf->shape[i]; + } + return offset; + }; + + PrimExpr dst_offset = compute_linear_offset(dst, dst_range); + PrimExpr src_offset = compute_linear_offset(src, src_range); + + // Total number of elements to transfer. + PrimExpr total_elements = 1; + for (auto r : src_range) + total_elements = total_elements * r->extent; + PrimExpr size_bytes = + cast(DataType::UInt(32), total_elements * src->dtype.bytes()); + + // Build tvm_access_ptr arguments. These are processed by LowerTileOp's + // HandleAccessPtrAndOffset which, for TMA ops (in_tma_context_=true), + // keeps the raw linear offset without applying any swizzle transformation. + PrimExpr dst_ptr = + dst.access_ptr(2, DataType::Handle(), 1, dst_offset, total_elements); + PrimExpr src_ptr = + src.access_ptr(1, DataType::Handle(), 1, src_offset, total_elements); + + Stmt bulk_copy = Evaluate( + Call(DataType::Handle(), tma_store_cluster(), + {dst_ptr, src_ptr, dst_block.value(), size_bytes, barrier_load})); + + // Single-thread guard: only thread_bounds->min issues the instruction. 
+ return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), bulk_copy); + } + + // --------------------------------------------------------------------------- + // Slow path: element-by-element SIMT copy via ptx_cluster_store + // Used when no barrier is provided (backward-compatible behaviour). + // --------------------------------------------------------------------------- + + // Generate the loop nest for the copy + auto simt_loop = MakeSIMTLoop(analyzer); + auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); + + // Partition across threads but force scalar (vectorize_hint=1): + // ClusterCopyReplacer replaces BufferStore with ptx_cluster_store + // (cooperative_groups map_shared_rank + scalar write). Vectorized stores + // would produce vector-dtype values that cannot be written through + // map_shared_rank's scalar pointer return type. + std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, + InferLevel::kFree}; + auto par_op = ParallelOp(fused_loop); + for (auto level : levels) { + par_op->InferLayout({T.target, + T.thread_bounds, + T.layout_map, + analyzer, + false, + T.buffer_remap, + {}}, + level); + } + auto loop_layout = par_op->GetLoopLayout(); + auto thread_loop = + PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout); + auto vectorized_thread_loop = + VectorizeLoop(thread_loop, T.layout_map, /*vectorize_hint=*/1); + + // Replace the buffer store with the cluster copy intrinsic + class ClusterCopyReplacer : public StmtExprMutator { + public: + ClusterCopyReplacer(const Buffer &dst, PrimExpr dst_block, + const Buffer &target_dst) + : dst_(dst), dst_block_(dst_block), target_dst_(target_dst) {} + + Stmt VisitStmt_(const BufferStoreNode *op) final { + if (op->buffer.same_as(dst_)) { + Array args; + args.push_back(target_dst_.access_ptr(2)); // The buffer var (handle) + args.push_back(op->value); // The value to store + args.push_back(dst_block_); // The destination block index + + // linearize the index. 
+ PrimExpr linearized_index = op->indices[0]; + if (op->indices.size() > 1) { + PrimExpr multiplier = 1; + linearized_index = 0; + for (int i = op->indices.size() - 1; i >= 0; --i) { + linearized_index = linearized_index + op->indices[i] * multiplier; + if (i > 0) { + multiplier = multiplier * op->buffer->shape[i]; + } + } + } + + args.push_back(linearized_index); + + Buffer target_buffer = target_dst_; + if (target_dst_.same_as(dst_)) { + target_buffer = op->buffer; + } + + // Guard against out-of-bounds remote stores when the thread count + // exceeds the copy extent (e.g. 128 threads for a 64-element region). + PrimExpr total_elems = 1; + for (const PrimExpr &s : target_buffer->shape) { + total_elems = total_elems * s; + } + + Stmt remote_store = Evaluate( + Call(DataType::Handle(), ptx_cluster_store(), + {target_buffer.access_ptr(2), args[1], args[2], args[3]})); + + return IfThenElse(args[3] < total_elems, remote_store, Stmt()); + } + return StmtExprMutator::VisitStmt_(op); + } + + private: + const Buffer &dst_; + PrimExpr dst_block_; + const Buffer &target_dst_; + }; + + Buffer target_dst = dst; + if (T.buffer_remap.count(dst)) { + target_dst = T.buffer_remap[dst]; + } + + return ClusterCopyReplacer(dst, dst_block.value(), + target_dst)(vectorized_thread_loop); +} + // Lowers copy to a bulk TMA (Tensor Memory Accelerator) transfer. // Falls back to LowerNormalCopy if preconditions are not satisfied. 
Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, @@ -1377,6 +1568,10 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, Call create_descriptor = Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs()); + // Check for cluster multicast mask annotation (0 means no multicast) + int64_t cluster_mask = GetClusterMask(); + bool use_multicast = is_load && (cluster_mask > 0); + Array args; args.reserve(desc.rank + 4); args.push_back(create_descriptor); @@ -1389,6 +1584,24 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, for (auto e : desc.smem_box) total_elements *= e; + // Helper lambda to build multicast args by inserting the cluster_mask + // value between smem_ptr (args[2]) and global coordinates. + // Regular tma_load args layout: [desc, mbar, smem_ptr, coord0..., eviction] + // Multicast tma_load args layout: [desc, mbar, smem_ptr, mask, coord0..., + // eviction] + auto build_multicast_args = [&](const Array ®ular_args) { + Array mc_args; + mc_args.reserve(regular_args.size() + 1); + mc_args.push_back(regular_args[0]); // descriptor + mc_args.push_back(regular_args[1]); // mbarrier placeholder + mc_args.push_back(regular_args[2]); // smem_ptr + mc_args.push_back( + IntImm(DataType::Int(32), cluster_mask)); // multicast mask + for (size_t i = 3; i < regular_args.size(); i++) + mc_args.push_back(regular_args[i]); // coords + eviction_policy + return mc_args; + }; + if ((*inner_box_dim) != instruction_dim) { Var loop_var("i"); int loop_extent = (*inner_box_dim) / instruction_dim; @@ -1406,6 +1619,33 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, args.push_back(GetEvictionPolicy()); tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, Evaluate(Call(DataType::Handle(), op, args))); + + if (use_multicast) { + // Build multicast args using the same loop_var (safe since branches are + // mutually exclusive in the outer IfThenElse) + 
Array mc_args = build_multicast_args(args); + Stmt multicast_copy = For( + loop_var, 0, loop_extent, ForKind::kUnrolled, + Evaluate(Call(DataType::Handle(), tma_load_multicast(), mc_args))); + + // 3-way split based on cluster mask: + // block_rank == min_rank_in_mask -> tma_load_multicast (then branch) + // block_rank NOT in mask -> regular tma_load (elif branch) + // block_rank in mask, not min -> do nothing (no else) + int min_cta_rank = static_cast( + __builtin_ctzll(static_cast(cluster_mask))); + PrimExpr block_rank = + Call(DataType::Int(32), get_cluster_block_rank(), {}); + PrimExpr mask_imm = IntImm(DataType::Int(32), cluster_mask); + // (mask >> block_rank) & 1 == 0 ⟺ block_rank is NOT set in the mask + PrimExpr not_in_mask = EQ(bitwise_and(right_shift(mask_imm, block_rank), + IntImm(DataType::Int(32), 1)), + IntImm(DataType::Int(32), 0)); + Stmt regular_or_noop = IfThenElse(not_in_mask, tma_copy, std::nullopt); + tma_copy = + IfThenElse(EQ(block_rank, IntImm(DataType::Int(32), min_cta_rank)), + multicast_copy, regular_or_noop); + } } else { PrimExpr shared_addr = shared_tensor.access_ptr( is_load ? 
2 : 1, DataType::Handle(), 1, shared_offset, total_elements); @@ -1417,6 +1657,29 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, args.push_back(need_reduce); args.push_back(GetEvictionPolicy()); tma_copy = Evaluate(Call(DataType::Handle(), op, args)); + + if (use_multicast) { + Array mc_args = build_multicast_args(args); + Stmt multicast_copy = + Evaluate(Call(DataType::Handle(), tma_load_multicast(), mc_args)); + + // 3-way split based on cluster mask: + // block_rank == min_rank_in_mask -> tma_load_multicast (then branch) + // block_rank NOT in mask -> regular tma_load (elif branch) + // block_rank in mask, not min -> do nothing (no else) + int min_cta_rank = static_cast( + __builtin_ctzll(static_cast(cluster_mask))); + PrimExpr block_rank = + Call(DataType::Int(32), get_cluster_block_rank(), {}); + PrimExpr mask_imm = IntImm(DataType::Int(32), cluster_mask); + PrimExpr not_in_mask = EQ(bitwise_and(right_shift(mask_imm, block_rank), + IntImm(DataType::Int(32), 1)), + IntImm(DataType::Int(32), 0)); + Stmt regular_or_noop = IfThenElse(not_in_mask, tma_copy, std::nullopt); + tma_copy = + IfThenElse(EQ(block_rank, IntImm(DataType::Int(32), min_cta_rank)), + multicast_copy, regular_or_noop); + } } tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); @@ -1733,7 +1996,7 @@ void CopyNode::CollectFragmentLayouts(const PrimExpr &expr, // eviction_policy // - Marked as opaque since it has side effects (memory writes) TIR_REGISTER_TL_TILE_OP(Copy, copy) - .set_num_inputs(5) + .set_num_inputs(6) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/copy.h b/src/op/copy.h index 3fe1b813c..cbe9343d2 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -115,6 +115,8 @@ class CopyNode : public TileOperatorNode { public: Buffer src, dst; // Source and destination buffers Array src_range, dst_range; // Ranges for each dimension in src and dst + Optional dst_block; // Destination block index for cluster 
copy + Map annotations; // Annotations for the copy operation // Supported annotation keys: // - "coalesced_width": IntImm, width for coalesced memory access @@ -136,6 +138,7 @@ class CopyNode : public TileOperatorNode { .def_ro("dst", &CopyNode::dst) .def_ro("src_range", &CopyNode::src_range) .def_ro("dst_range", &CopyNode::dst_range) + .def_ro("dst_block", &CopyNode::dst_block) .def_ro("annotations", &CopyNode::annotations); } @@ -158,6 +161,28 @@ class CopyNode : public TileOperatorNode { return 0; // default: evict_normal } + // Returns the cluster multicast mask (0 means no multicast / regular TMA + // load) + int64_t GetClusterMask() const { + if (auto val = annotations.Get("cluster_mask")) { + if (auto int_val = val->as()) { + return int_val->value; + } + } + return 0; + } + + // Returns the mbarrier BufferLoad (as PrimExpr) used for tma_store_cluster, + // or an empty Optional when the "barrier" annotation is absent. + Optional GetBarrier() const { + if (auto val = annotations.Get("barrier")) { + if (val->as()) { + return Downcast(val.value()); + } + } + return Optional(); + } + /*! * \brief Lower the copy operator to a TIR statement. * \param T Arguments for lowering. @@ -258,6 +283,7 @@ class CopyNode : public TileOperatorNode { */ Stmt LowerTmemCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; + Stmt LowerClusterCopy(const LowerArgs &T, arith::Analyzer *analyzer) const; /*! * \brief Generate lowering for normal copy. 
*/ diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 4d996e54c..ea4b2eb88 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1730,6 +1730,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->stream << ss.str(); this->stream << ");\n"; }; + auto print_mbarrier_obj = [&](PrimExpr barrier_id) { + std::ostringstream ss; + if (barrier_id.as()) { + ss << mbarrier_name_ << "[" << barrier_id << "]"; + } else { + ss << this->PrintExpr(barrier_id); + } + return ss.str(); + }; if (op->op.same_as(builtin::ptx_cp_async())) { // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, // args[3] = predicate (optional) @@ -1795,18 +1804,26 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ICHECK_EQ(op->args.size(), 1); std::string barrier_id = this->PrintExpr(op->args[0]); os << mbarrier_name_ + "[" + barrier_id + "]"; + } else if (op->op.same_as(tl::mbarrier_arrive())) { + ICHECK_EQ(op->args.size(), 3); + this->PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto cta_id = this->PrintExpr(op->args[1]); + auto pred = this->PrintExpr(op->args[2]); + this->stream << "tl::mbarrier_arrive(" << mbarrier_obj << ", " << cta_id + << ", " << pred << ");\n"; } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { if (op->args.size() == 1) { this->PrintIndent(); - auto mbarrier_obj = this->PrintExpr(op->args[0]); - this->stream << mbarrier_obj << ".arrive();\n"; + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + this->stream << "tl::mbarrier_arrive(" << mbarrier_obj << ");\n"; } else if (op->args.size() == 3) { this->PrintIndent(); auto mbarrier_obj = this->PrintExpr(op->args[0]); auto cta_id = this->PrintExpr(op->args[1]); auto pred = this->PrintExpr(op->args[2]); - this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred - << ");\n"; + this->stream << "tl::mbarrier_arrive(" << mbarrier_obj << ", " << cta_id + << ", " 
<< pred << ");\n"; } else { LOG(FATAL) << "Invalid parameter for tl::arrive_barrier " << op->args.size(); @@ -1816,13 +1833,14 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); auto mbarrier_obj = this->PrintExpr(op->args[0]); auto arrive_count = this->PrintExpr(op->args[1]); - this->stream << mbarrier_obj << ".init(" << arrive_count << ");\n"; + this->stream << "tl::mbarrier_init(" << mbarrier_obj << ", " << arrive_count + << ");\n"; } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { if (op->args.size() == 2) { this->PrintIndent(); auto mbarrier_obj = this->PrintExpr(op->args[0]); auto transaction_bytes = this->PrintExpr(op->args[1]); - this->stream << mbarrier_obj << ".arrive_and_expect_tx(" + this->stream << "tl::mbarrier_arrive_expect_tx(" << mbarrier_obj << ", " << transaction_bytes << ");\n"; } else if (op->args.size() == 4) { this->PrintIndent(); @@ -1830,7 +1848,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { auto transaction_bytes = this->PrintExpr(op->args[1]); auto cta_id = this->PrintExpr(op->args[2]); auto pred = this->PrintExpr(op->args[3]); - this->stream << mbarrier_obj << ".arrive_and_expect_tx(" + this->stream << "tl::mbarrier_arrive_expect_tx(" << mbarrier_obj << ", " << transaction_bytes << ", " << cta_id << ", " << pred << ");\n"; } else { @@ -1848,14 +1866,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); auto mbarrier_obj = this->PrintExpr(op->args[0]); auto transaction_bytes = this->PrintExpr(op->args[1]); - this->stream << mbarrier_obj << ".expect_transaction(" << transaction_bytes - << ");\n"; + this->stream << "tl::mbarrier_expect_tx(" << mbarrier_obj << ", " + << transaction_bytes << ");\n"; } else if (op->op.same_as(tl::mbarrier_wait_parity())) { ICHECK_EQ(op->args.size(), 2); this->PrintIndent(); auto mbarrier_obj = this->PrintExpr(op->args[0]); auto phase = 
this->PrintExpr(op->args[1]); - this->stream << mbarrier_obj << ".wait(" << phase << ");\n"; + this->stream << "tl::mbarrier_wait(" << mbarrier_obj << ", " << phase + << ");\n"; } else if (op->op.same_as(tl::ptx_init_tensor_memory())) { print_extern_call_stmt("tl::tmem_allocate"); } else if (op->op.same_as(tl::ptx_deallocate_tensor_memory())) { @@ -1885,6 +1904,35 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ss << ");\n"; this->PrintIndent(); this->stream << ss.str(); + } else if (op->op.same_as(tl::tma_load_multicast())) { + // args layout: [descriptor, mbarrier, smem_ptr, multicast_mask, + // coord_0, ..., coord_{n-1}, eviction_policy] + std::ostringstream ss; + ICHECK_GE(op->args.size(), 5) + << "tma_load_multicast requires at least 5 args"; + auto eviction_policy = + this->eviction_policy_names_ + [op->args[op->args.size() - 1].as()->value]; + if (eviction_policy != "EVICT_NORMAL") { + ss << "tl::tma_load_multicast("; + } else { + ss << "tl::tma_load_multicast("; + } + // descriptor + ss << this->PrintExpr(op->args[0]) << ", "; + // mbarrier + ss << this->PrintExpr(op->args[1]) << ", "; + ss << this->PrintExpr(op->args[2]) << ", "; + // multicast_mask cast to uint16_t + ss << "(uint16_t)(" << this->PrintExpr(op->args[3]) << ")"; + // global coordinates (args[4..N-2]) + for (size_t i = 4; i < op->args.size() - 1; i++) { + ss << ", " << this->PrintExpr(op->args[i]); + } + ss << ");\n"; + this->PrintIndent(); + this->stream << ss.str(); } else if (op->op.same_as(tl::tma_load_im2col())) { std::stringstream ss; auto eviction_policy = @@ -2131,6 +2179,76 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { replacer.register_rule("(C_ptr)", c_ref); replacer.register_rule("(C_offset)", c_bias); this->stream << replacer.rewrite(mma_call); + } else if (op->op.same_as(tl::tma_store_cluster())) { + // args: [dst_ptr, src_ptr, dst_cta, size_bytes, bar_ref] + // dst_ptr – address_of(dst_buf[offset]) (void*) + // 
src_ptr – address_of(src_buf[offset]) (void*) + // dst_cta – destination CTA rank (int) + // size_bytes– bytes to transfer (uint32_t) + // bar_ref – mbarrier element (uint64_t&) + ICHECK_EQ(op->args.size(), 5U) << "tma_store_cluster requires 5 args"; + this->PrintIndent(); + this->stream << "tl::tma_store_cluster("; + this->stream << this->PrintExpr(op->args[0]) << ", "; + this->stream << this->PrintExpr(op->args[1]) << ", "; + this->stream << "(int)(" << this->PrintExpr(op->args[2]) << "), "; + this->stream << "(uint32_t)(" << this->PrintExpr(op->args[3]) << "), "; + this->stream << this->PrintExpr(op->args[4]) << ");\n"; + + } else if (op->op.same_as(tl::ptx_cluster_store())) { + // arg 0: buffer var (handle) + // arg 1: value to store + // arg 2: dst block index + // arg 3: linearized index + + ICHECK_EQ(op->args.size(), 4U); + + std::string buffer_var = this->PrintExpr(op->args[0]); + std::string value = this->PrintExpr(op->args[1]); + std::string dst_block = this->PrintExpr(op->args[2]); + std::string index = this->PrintExpr(op->args[3]); + + // We need to cast the buffer var to the correct type if it's not already. + // But `buffer_var` here is likely just the name of the variable. + // We need to handle the type of the value being stored. + + // The generated code should look like: + // { + // namespace cg = cooperative_groups; + // cg::cluster_group cluster = cg::this_cluster(); + // auto* dst_ptr = cluster.map_shared_rank(&buffer[index], dst_block); + // *dst_ptr = value; + // } + + // However, `buffer[index]` might be an expression. + // We need to get the address of `buffer[index]`. + // If `buffer` is a pointer, then `&buffer[index]` is `buffer + index`. + + // We need to know the type of the value to cast the pointer correctly if + // needed. But `map_shared_rank` is a template or returns a pointer to the + // same type? 
`T* map_shared_rank(T* addr, unsigned int rank)` + + this->need_cooperative_groups_ = true; + this->PrintIndent(); + this->stream << "{\n"; + int cluster_scope = this->BeginScope(); + this->PrintIndent(); + this->stream << "namespace cg = cooperative_groups;\n"; + this->PrintIndent(); + this->stream << "cg::cluster_group cluster = cg::this_cluster();\n"; + this->PrintIndent(); + this->stream << "auto* dst_ptr = cluster.map_shared_rank(&" << buffer_var + << "[" << index << "], " << dst_block << ");\n"; + this->PrintIndent(); + this->stream << "*dst_ptr = " << value << ";\n"; + this->EndScope(cluster_scope); + this->PrintIndent(); + this->stream << "}\n"; + + } else if (op->op.same_as(tl::cluster_sync())) { + this->need_cooperative_groups_ = true; + this->PrintIndent(); + this->stream << "cooperative_groups::this_cluster().sync();\n"; } else if (op->op.same_as(tl::ptx_mma_sm70())) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col @@ -2998,6 +3116,20 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << PrintExpr(op->args[i]); } os << ")"; + } else if (op->op.same_as(tl::get_cluster_id())) { + ICHECK_EQ(op->args.size(), 0) << "tl.get_cluster_id expects no arguments."; + this->need_cooperative_groups_ = true; + os << "cooperative_groups::this_grid().cluster_rank()"; + } else if (op->op.same_as(tl::get_cluster_block_rank())) { + ICHECK_EQ(op->args.size(), 0) + << "tl.get_cluster_block_rank expects no arguments."; + this->need_cooperative_groups_ = true; + os << "cooperative_groups::this_cluster().block_rank()"; + } else if (op->op.same_as(tl::get_cluster_block_nums())) { + ICHECK_EQ(op->args.size(), 0) + << "tl.get_cluster_block_nums expects no arguments."; + this->need_cooperative_groups_ = true; + os << "cooperative_groups::this_cluster().num_blocks()"; } else if (op->op.same_as(tl::tl_shuffle_elect())) { os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; } else if (op->op.same_as(tl::initialize_wgmma_descriptor())) { 
diff --git a/src/tl_templates/cuda/barrier.h b/src/tl_templates/cuda/barrier.h index 79a57f7df..b8fc5d841 100644 --- a/src/tl_templates/cuda/barrier.h +++ b/src/tl_templates/cuda/barrier.h @@ -8,6 +8,10 @@ using Barrier = cutlass::arch::ClusterTransactionBarrier; namespace tl { +TL_DEVICE void mbarrier_init(Barrier &barrier, uint32_t arrive_count) { + barrier.init(arrive_count); +} + TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); asm volatile("mbarrier.init.shared.b64 [%1], %0;" @@ -31,6 +35,10 @@ TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit) { return waitComplete; } +TL_DEVICE void mbarrier_wait(Barrier &barrier, int phase_bit) { + barrier.wait(phase_bit); +} + TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) { if (mbarrier_try_wait(smem_barrier, phase_bit) == 0) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); @@ -65,11 +73,17 @@ TL_DEVICE void mbarrier_test_wait(uint64_t &smem_barrier, int phase_bit) { "r"(phase_bit)); } +TL_DEVICE void mbarrier_arrive(Barrier &barrier) { barrier.arrive(); } + TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr)); } +TL_DEVICE void mbarrier_arrive(Barrier &barrier, int cta_id, uint32_t pred) { + barrier.arrive(cta_id, pred); +} + TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id, uint32_t pred) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); @@ -84,6 +98,11 @@ TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id, } } +TL_DEVICE void mbarrier_expect_tx(Barrier &barrier, + uint32_t transaction_bytes) { + barrier.expect_transaction(transaction_bytes); +} + TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier, uint32_t transaction_bytes) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); @@ -92,6 +111,17 
@@ TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier, : "r"(transaction_bytes), "r"(smem_int_ptr)); } +TL_DEVICE void mbarrier_arrive_expect_tx(Barrier &barrier, + uint32_t transaction_bytes) { + barrier.arrive_and_expect_tx(transaction_bytes); +} + +TL_DEVICE void mbarrier_arrive_expect_tx(Barrier &barrier, + uint32_t transaction_bytes, int cta_id, + uint32_t pred) { + barrier.arrive_and_expect_tx(transaction_bytes, cta_id, pred); +} + TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier, uint32_t transaction_bytes) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); diff --git a/src/tl_templates/cuda/copy_sm90.h b/src/tl_templates/cuda/copy_sm90.h index 3d5b3f414..53145933f 100644 --- a/src/tl_templates/cuda/copy_sm90.h +++ b/src/tl_templates/cuda/copy_sm90.h @@ -38,6 +38,165 @@ TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr, :); } +// Generic SM-to-SM async bulk copy via cp.async.bulk.shared::cluster +TL_DEVICE void tma_store_cluster(void *dst, void *src, int dst_cta, + uint32_t size_bytes, uint64_t &bar) { + uint32_t mbarrier_ptr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); + + uint32_t neighbor_addr_dst; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_addr_dst) + : "r"(dst_ptr), "r"(dst_cta)); + + uint32_t neighbor_addr_mbarrier; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_addr_mbarrier) + : "r"(mbarrier_ptr), "r"(dst_cta)); + + // Arrive at the remote barrier and announce the expected TX byte count. + // This satisfies one arrival (matching the mbarrier_init count) and tells + // the barrier how many bytes the subsequent cp.async.bulk will transfer. 
+ asm volatile("mbarrier.arrive.expect_tx.shared::cluster.b64 _, [%0], %1;\n" + : + : "r"(neighbor_addr_mbarrier), "r"(size_bytes) + : "memory"); + + asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory"); + asm volatile("cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_" + "tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(neighbor_addr_dst), "r"(src_ptr), "r"(size_bytes), + "r"(neighbor_addr_mbarrier) + : "memory"); +} + +template +TL_DEVICE void +tma_load_multicast(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, uint16_t multicast_mask, + int32_t const &crd0) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4, %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), + "h"(multicast_mask), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void +tma_load_multicast(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, uint16_t multicast_mask, + int32_t const &crd0, int32_t const &crd1) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5, %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), 
"r"(smem_int_mbar), "r"(crd0), + "r"(crd1), "h"(multicast_mask), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_load_multicast(const CUtensorMap &descriptor, + BarrierType &smem_mbar, + void const *const smem_ptr, + uint16_t multicast_mask, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), + "r"(crd1), "r"(crd2), "h"(multicast_mask), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void +tma_load_multicast(const CUtensorMap &descriptor, BarrierType &smem_mbar, + void const *const smem_ptr, uint16_t multicast_mask, + int32_t const &crd0, int32_t const &crd1, + int32_t const &crd2, int32_t const &crd3) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), + "r"(crd1), "r"(crd2), "r"(crd3), "h"(multicast_mask), "l"(cache_hint) + : "memory"); +} + +template +TL_DEVICE void tma_load_multicast(const CUtensorMap &descriptor, + BarrierType 
&smem_mbar, + void const *const smem_ptr, + uint16_t multicast_mask, int32_t const &crd0, + int32_t const &crd1, int32_t const &crd2, + int32_t const &crd3, int32_t const &crd4) { + uint64_t gmem_int_desc = reinterpret_cast(&descriptor); + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + asm volatile( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::" + "bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), + "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "h"(multicast_mask), + "l"(cache_hint) + : "memory"); +} + template TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar, diff --git a/src/transform/common/thread_sync_types.h b/src/transform/common/thread_sync_types.h index bbcf4c2b4..be9099dcc 100644 --- a/src/transform/common/thread_sync_types.h +++ b/src/transform/common/thread_sync_types.h @@ -28,7 +28,13 @@ enum class ReservedNamedBarriers : uint8_t { kSyncThreads = 0, kReduce_0 = 1, kReduce_1 = 2, - kFirstUsedBarrier = kReduce_1 + 1 + // TileLang convention for 256-thread CTA split into two 128-thread groups. + // Producer: threadIdx.x in [128, 255] + // Consumer: threadIdx.x in [0, 127] + // These must be distinct to avoid mixing barrier states and deadlocks. 
+ kProducer = kReduce_1 + 1, + kConsumer = kProducer + 1, + kFirstUsedBarrier = kConsumer + 1 }; } // namespace tl diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index 1c3e461bd..8d0a38d25 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -63,13 +63,17 @@ class TmaTraitsCollector : public StmtExprVisitor { private: void VisitExpr_(const CallNode *call) final { - if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_load_multicast())) { auto arg0 = call->args[0].as(); if (call->op.same_as(tma_load()) && arg0 && !arg0.value()->op.same_as(create_tma_descriptor())) { // 1D TMA load has tvm_access_ptr of shared tensor in its args[0] bulk_copy_bytes = call->args[3] * loop_extents; } else { + // Descriptor-based TMA load: smem_ptr is always at args[2] + // (for tma_load_multicast, args[3] is the mask, but args[2] is still + // the smem access_ptr) Call access_ptr = Downcast(call->args[2]); ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); int type_bytes = access_ptr->args[0]->dtype.bytes(); @@ -86,6 +90,17 @@ class TmaTraitsCollector : public StmtExprVisitor { loop_extents = old_loop_evtents; } + // For if/else branches (mutually exclusive), count only the then branch. + // Both branches always transfer the same number of bytes (same tile size), + // so counting either one gives the correct mbarrier expect_tx byte count. 
+ void VisitStmt_(const IfThenElseNode *op) final { + if (op->else_case.defined()) { + StmtExprVisitor::VisitStmt(op->then_case); + } else { + StmtExprVisitor::VisitStmt_(op); + } + } + PrimExpr bulk_copy_bytes = 0; PrimExpr loop_extents = 1; }; @@ -135,7 +150,21 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer { if (op->condition.as()) { flag = op->condition.as()->op.same_as(tl_shuffle_elect()); } - if (op->condition.as() || flag) { + // Only trigger for EQ conditions that involve the thread variable + // (directly or offset like threadIdx.x - 128 == 0 after warp-spec + // remapping). Runtime function-call conditions such as + // get_cluster_block_rank() == min_rank must NOT trigger byte injection: + // they don't involve any thread variable. + bool is_thread_eq = false; + if (auto eq = op->condition.as()) { + if (thread_var_.defined()) { + auto f_uses_thread = [this](const VarNode *v) { + return v == thread_var_->var.get(); + }; + is_thread_eq = UsesVar(op->condition, f_uses_thread); + } + } + if (is_thread_eq || flag) { Stmt ret = IRMutatorWithAnalyzer::VisitStmt_(op); if (visited_tma_load_) { @@ -163,8 +192,10 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const CallNode *op) { - if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || + op->op.same_as(tma_load_multicast())) { auto arg0 = op->args[0].as(); + // Only the 1D non-descriptor tma_load has its mbarrier at args[2] bool is_1d_tma_load = arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && op->op.same_as(tma_load()); @@ -203,7 +234,8 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { void VisitStmt_(const EvaluateNode *op) final { if (const auto *call = op->value.as()) { - if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + 
call->op.same_as(tma_load_multicast())) { pending_tma_ops_.push_back(tvm::ffi::GetRef(call)); } else if (call->op.same_as(mbarrier_expect_tx())) { pending_tma_ops_.push_back(tvm::ffi::GetRef(call)); @@ -506,7 +538,8 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const CallNode *op) { - if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || + op->op.same_as(tma_load_multicast())) { auto call_ref = tvm::ffi::GetRef(op); if (!tma_op_to_barrier_id_.count(call_ref)) { // For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id) @@ -515,7 +548,8 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { auto arg0 = op->args[0].as(); bool is_1d_tma_load = arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && - !arg0.value()->op.same_as(create_tma_im2col_descriptor()); + !arg0.value()->op.same_as(create_tma_im2col_descriptor()) && + op->op.same_as(tma_load()); if (is_1d_tma_load && op->args.size() >= 3) { if (const auto *imm = op->args[2].as()) { Array new_args = op->args; @@ -532,10 +566,13 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { auto arg0 = op->args[0].as(); auto is_1d_tma_load = arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && - !arg0.value()->op.same_as(create_tma_im2col_descriptor()); + !arg0.value()->op.same_as(create_tma_im2col_descriptor()) && + op->op.same_as(tma_load()); if (is_1d_tma_load) { new_args.Set(2, barrier_id); } else { + // Descriptor-based tma_load, tma_load_multicast, and tma_load_im2col + // all have the mbarrier at args[1] new_args.Set(1, barrier_id); } return Call(op->dtype, op->op, new_args, op->annotations); diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index e43657ade..af119e0bb 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -12,7 +12,6 @@ #include "../op/builtin.h" #include 
"../runtime/runtime.h" -#include "./common/mbarrier.h" namespace tvm { namespace tl { @@ -138,6 +137,12 @@ class LowerHopperIntrin : public StmtExprMutator { return AttrStmt(op->node, op->attr_key, op->value, body); } else { Array stmt_seq; + if (num_managed_barriers_ > 0) { + auto alloc_mbarrier = + Evaluate(Call(DataType::Handle(), builtin::create_barriers(), + {num_managed_barriers_})); + stmt_seq.push_back(alloc_mbarrier); + } auto stmts = prefetch_calls_; stmts.insert(stmts.end(), init_mbarrier_calls_.begin(), @@ -171,16 +176,9 @@ class LowerHopperIntrin : public StmtExprMutator { Stmt result = SeqStmt(stmt_seq); - if (!init_mbarrier_calls_.empty()) { - mbarrier_buffer_ = CreateMBarrierBuffer( - injected_mbarrier_name_, init_mbarrier_calls_.size()); - result = DeclBuffer(mbarrier_buffer_, result); - result = Allocate(mbarrier_buffer_->data, mbarrier_buffer_->dtype, - mbarrier_buffer_->shape, const_true(), result); - } - prefetch_calls_.clear(); init_mbarrier_calls_.clear(); + num_managed_barriers_ = 0; return AttrStmt(op->node, op->attr_key, op->value, result); } } @@ -206,8 +204,9 @@ class LowerHopperIntrin : public StmtExprMutator { } return var; } else if (call->op.same_as(create_list_of_mbarrier())) { - ICHECK(init_mbarrier_calls_.empty()); + // ICHECK(init_mbarrier_calls_.empty()); int num_barriers = static_cast(call->args.size()); + num_managed_barriers_ += num_barriers; for (int i = 0; i < num_barriers; i++) { PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {i}); init_mbarrier_calls_.push_back(Evaluate( @@ -215,19 +214,46 @@ class LowerHopperIntrin : public StmtExprMutator { {mbarrier, call->args[i]}))); } return 0; + } else if (call->op.same_as(builtin::ptx_init_barrier_thread_count())) { + init_mbarrier_calls_.push_back(Evaluate(tvm::ffi::GetRef(call))); + return 0; } else { return StmtExprMutator::VisitExpr_(call); } } + Stmt VisitStmt_(const IfThenElseNode *op) final { + Stmt new_stmt = StmtExprMutator::VisitStmt_(op); + if (const auto 
*if_node = new_stmt.as()) { + if (IsNoOp(if_node->then_case) && (!if_node->else_case.defined() || + IsNoOp(if_node->else_case.value()))) { + return Evaluate(0); + } + } + return new_stmt; + } + + bool IsNoOp(const Stmt &stmt) { + if (const auto *eval = stmt.as()) { + return is_const_int(eval->value, 0); + } else if (const auto *seq = stmt.as()) { + for (const auto &s : seq->seq) { + if (!IsNoOp(s)) + return false; + } + return true; + } + return false; + } + private: Array prefetch_calls_; Array init_mbarrier_calls_; + int num_managed_barriers_ = 0; std::unordered_map desc_map_; LowerHopperIntrin(bool disable_shuffle_elect) : disable_shuffle_elect_(disable_shuffle_elect) {} bool disable_shuffle_elect_; - Buffer mbarrier_buffer_; }; using namespace tir::transform; diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index cff2f8c7f..fb94df56c 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -650,6 +650,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const tir::CallNode *op) final { if (op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_load_im2col()) || + op->op.same_as(tl::tma_load_multicast()) || op->op.same_as(tl::tma_store())) { // skip tma related calls, as they were transformed implicitly. has_tma_ = true; @@ -658,6 +659,16 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { in_tma_context_ = false; return call; } + if (op->op.same_as(tl::tma_store_cluster())) { + // SM-to-SM bulk async copy: only suppress swizzle transformation on the + // access pointers (in_tma_context_), but do NOT set has_tma_ because + // this operation does not use the TMA hardware engine and must not + // trigger warp specialization. 
+ in_tma_context_ = true; + auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + in_tma_context_ = false; + return call; + } if (is_ptx_) { return Downcast(op); diff --git a/src/transform/multi_version_buffer_rewriter.cc b/src/transform/multi_version_buffer_rewriter.cc index 4338842cf..791ca0faf 100644 --- a/src/transform/multi_version_buffer_rewriter.cc +++ b/src/transform/multi_version_buffer_rewriter.cc @@ -141,6 +141,21 @@ class MultiVersionBufferRewriter : public StmtExprMutator { private: MultiVersionBufferRewriter() = default; + void EnsureRaggedPrefixBuffer() { + if (ragged_prefix_buf_.defined()) { + return; + } + Array shape = {IntImm(DataType::Int(32), 1)}; + ragged_prefix_buf_ = + decl_buffer(shape, DataType::Int(32), "tl_mvb_ragged_prefix", "local"); + } + + PrimExpr LoadRaggedPrefix() { + EnsureRaggedPrefixBuffer(); + Array zero_indices = {0}; + return BufferLoad(ragged_prefix_buf_, zero_indices); + } + Array GetVersionedBuffers(const Array &seq_stmt, const Array &scoped_buffers) { Array pipeline_stmts; @@ -292,6 +307,37 @@ class MultiVersionBufferRewriter : public StmtExprMutator { return stmt; } + Stmt VisitStmt_(const AttrStmtNode *op) final { + // Make sure any ragged prefix allocation lives inside the device thread + // scope (threadIdx.x). Allocating at PrimFunc scope can be lifted into an + // extra kernel parameter by downstream lowering. 
+ stmt_stack_.push_back(op); + Stmt body = this->VisitStmt(op->body); + stmt_stack_.pop_back(); + + bool is_thread_extent = (op->attr_key == tir::attr::thread_extent); + bool is_threadidx_x = false; + if (is_thread_extent) { + if (const auto *iv = op->node.as()) { + is_threadidx_x = (iv->thread_tag == "threadIdx.x"); + } + } + + if (needs_ragged_prefix_ && is_threadidx_x && !inserted_ragged_prefix_) { + inserted_ragged_prefix_ = true; + EnsureRaggedPrefixBuffer(); + Array zero_indices = {0}; + Stmt init = BufferStore(ragged_prefix_buf_, IntImm(DataType::Int(32), 0), + zero_indices); + Stmt seq = SeqStmt({init, body}); + seq = DeclBuffer(ragged_prefix_buf_, seq); + body = Allocate(ragged_prefix_buf_->data, ragged_prefix_buf_->dtype, + ragged_prefix_buf_->shape, const_true(), seq); + } + + return AttrStmt(op->node, op->attr_key, op->value, body); + } + Stmt VisitStmt_(const ForNode *op) final { stmt_stack_.push_back(op); loop_stack_.emplace_back(op->loop_var, op->extent); @@ -392,6 +438,28 @@ class MultiVersionBufferRewriter : public StmtExprMutator { Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages); buffer_remap_.Set(buffer, new_buffer); } + PrimExpr version_index_before = version_index_; + + // For ragged (dynamic extent) pipelined loops, rectangular linearization + // using loop extents is invalid. Use a runtime prefix counter so + // ping-pong buffer selection stays consistent across outer iterations. 
+ bool is_dynamic_extent = !op->extent.as(); + if (is_dynamic_extent) { + needs_ragged_prefix_ = true; + version_index_ = + FloorMod(LoadRaggedPrefix() + (op->loop_var - op->min), num_stages); + Stmt for_node = StmtExprMutator::VisitStmt_(op); + version_index_ = version_index_before; + + Array zero_indices = {0}; + Stmt update = BufferStore(ragged_prefix_buf_, + LoadRaggedPrefix() + op->extent, zero_indices); + + loop_stack_.pop_back(); + stmt_stack_.pop_back(); + return SeqStmt({for_node, update}); + } + PrimExpr linear_index = loop_stack_[0].first; for (size_t i = 1; i < loop_stack_.size(); ++i) { linear_index = @@ -399,6 +467,7 @@ class MultiVersionBufferRewriter : public StmtExprMutator { } version_index_ = FloorMod(linear_index, num_stages); auto for_node = StmtExprMutator::VisitStmt_(op); + version_index_ = version_index_before; loop_stack_.pop_back(); stmt_stack_.pop_back(); @@ -472,6 +541,9 @@ class MultiVersionBufferRewriter : public StmtExprMutator { } PrimExpr version_index_; + Buffer ragged_prefix_buf_; + bool needs_ragged_prefix_ = false; + bool inserted_ragged_prefix_ = false; std::vector> loop_stack_; // Track ancestor statements to query whether an LCA is inside the current // loop. 
diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index 717dce27f..b06a15b1f 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -238,21 +238,29 @@ class BufferRegionCollector : public StmtExprVisitor { this->VisitExpr(op->args[i]); } } else if (op->op.same_as(tl::mbarrier_wait_parity())) { - ICHECK(args[0].as()); - Buffer mbar_buf = args[0].as()->buffer; - auto buffer_reads = - chain_builder_.mbar_to_buffer_reads_.find(mbar_buf.get()); - auto buffer_writes = - chain_builder_.mbar_to_buffer_writes_.find(mbar_buf.get()); - if (buffer_reads != chain_builder_.mbar_to_buffer_reads_.end()) { - reads_.insert(reads_.end(), buffer_reads->second.begin(), - buffer_reads->second.end()); - } - if (buffer_writes != chain_builder_.mbar_to_buffer_writes_.end()) { - writes_.insert( - writes_.end(), - chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).begin(), - chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).end()); + // mbarrier_wait_parity may take either a BufferLoad (preferred, allows + // linking to associated async dependencies) or a target-specific handle + // expression (e.g. tl.get_mbarrier(id)). For the latter case, we cannot + // associate the barrier with a concrete Buffer, so we conservatively + // fall back to normal traversal. 
+ if (const auto *load = args[0].as()) { + Buffer mbar_buf = load->buffer; + auto buffer_reads = + chain_builder_.mbar_to_buffer_reads_.find(mbar_buf.get()); + auto buffer_writes = + chain_builder_.mbar_to_buffer_writes_.find(mbar_buf.get()); + if (buffer_reads != chain_builder_.mbar_to_buffer_reads_.end()) { + reads_.insert(reads_.end(), buffer_reads->second.begin(), + buffer_reads->second.end()); + } + if (buffer_writes != chain_builder_.mbar_to_buffer_writes_.end()) { + writes_.insert( + writes_.end(), + chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).begin(), + chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).end()); + } + } else { + StmtExprVisitor::VisitExpr_(op); } } else { StmtExprVisitor::VisitExpr_(op); diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index 117e513ee..d416f80b1 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -348,11 +348,33 @@ class ThreadPartialSyncRewriter : public IRMutatorWithAnalyzer { if (barrier_id_map_.count(key)) { return {barrier_id_map_[key], thread_count_map_[key]}; } + size_t thread_count = extent_tx * extent_ty * extent_tz; + + // Special-case: enforce distinct, fixed barrier IDs for the common + // producer/consumer split (256-thread CTA -> two 128-thread groups). + // This avoids accidentally mixing the same barrier id between the two + // halves, which can deadlock when both halves use bar.sync. 
+ if (thread_count == 128 && key.ty_min == key.ty_max && + key.tz_min == key.tz_max) { + if (key.tx_min == 0 && key.tx_max == 127) { + size_t barrier_id = + static_cast(ReservedNamedBarriers::kConsumer); + barrier_id_map_[key] = barrier_id; + thread_count_map_[key] = thread_count; + return {barrier_id, thread_count}; + } + if (key.tx_min == 128 && key.tx_max == 255) { + size_t barrier_id = + static_cast(ReservedNamedBarriers::kProducer); + barrier_id_map_[key] = barrier_id; + thread_count_map_[key] = thread_count; + return {barrier_id, thread_count}; + } + } size_t barrier_id = barrier_id_map_.size() + static_cast(ReservedNamedBarriers::kFirstUsedBarrier); - size_t thread_count = extent_tx * extent_ty * extent_tz; barrier_id_map_[key] = barrier_id; thread_count_map_[key] = thread_count; @@ -887,7 +909,8 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { if (auto opt = op->op.as()) { const Op &call_op = opt.value(); return call_op.same_as(tl::tma_load()) || - call_op.same_as(tl::tma_load_im2col()); + call_op.same_as(tl::tma_load_im2col()) || + call_op.same_as(tl::tma_load_multicast()); } return false; }(); diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index d9af1ae16..44583c338 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -123,7 +123,8 @@ class ProducerUsedBufferFinder : public StmtExprVisitor { } void VisitExpr_(const CallNode *op) final { - if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || + op->op.same_as(tma_load_multicast())) { for (auto arg : op->args) { // Collect buffers from args, including through let bindings CollectBuffersFromExpr(arg); @@ -157,7 +158,8 @@ class WarpSpecializedRoleMarker : public StmtVisitor { void VisitStmt_(const EvaluateNode *op) final { Role role = Role::kConsumer; if (auto call = op->value.as()) { - if 
(call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_load_multicast())) { role = Role::kProducer; has_bulk_copy_ = true; } @@ -178,6 +180,16 @@ class WarpSpecializedRoleMarker : public StmtVisitor { SetRole(op, Role::kBoth); return; } + + // Keep ragged-prefix bookkeeping in both producer and consumer. + // The MultiVersionBuffer rewriter inserts updates to a local buffer + // "tl_mvb_ragged_prefix" to linearize ragged pipelined loops; if we + // classify these as consumer-only, the producer path will miss the update + // and its ping-pong buffer selection will desynchronize. + if (op->buffer->name == "tl_mvb_ragged_prefix") { + SetRole(op, Role::kBoth); + return; + } if (!is_shared_store) { SetRole(op, Role::kConsumer); return; @@ -387,14 +399,18 @@ class ThreadIdxRewriter : public StmtExprMutator { UsesVar(op->condition, f_uses_thread_index) && !(UsesVar(op->then_case, f_uses_thread_index))) { auto eq_op = Downcast(op->condition); - if (eq_op->a.as() == thread_var_.get() || - eq_op->b.as() == thread_var_.get()) { - maybe_thread_opt_ = true; - } - auto then_case = StmtExprMutator::VisitStmt(op->then_case); - maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_ && has_tma_op_; + bool can_shuffle = (eq_op->a.as() == thread_var_.get() || + eq_op->b.as() == thread_var_.get()); + // Save state: visiting then_case may recursively reset maybe_thread_opt_ + // (e.g. nested IfThenElse from multicast cluster_rank check). 
+ bool saved_has_tma = has_tma_op_; has_tma_op_ = false; - if (maybe_thread_opt_) { + StmtExprMutator::VisitStmt(op->then_case); + bool local_has_tma = has_tma_op_; + // Restore has_tma_op_ for any outer context and clear maybe_thread_opt_ + has_tma_op_ = saved_has_tma; + maybe_thread_opt_ = false; + if (do_shuffle_ && can_shuffle && local_has_tma) { return IfThenElse( Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}), StmtExprMutator::VisitStmt(op->then_case), std::nullopt); @@ -406,6 +422,7 @@ class ThreadIdxRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const CallNode *op) final { if (op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_load_im2col()) || + op->op.same_as(tl::tma_load_multicast()) || op->op.same_as(tl::tma_store())) { has_tma_op_ = true; } @@ -638,7 +655,44 @@ class WSCodeEmitter : public StmtMutator { */ bool hasSimtCopy() const { return has_simt_copy_; } + /** + * @brief Emit code with any required ragged-pipeline bookkeeping. + * + * In particular, when encountering a pipelined loop (num_stages annotated) + * whose extent is dynamic (ragged), we maintain a per-thread running prefix + * counter to compute the correct stage/parity across outer iterations. 
+ */ + Stmt Emit(const Stmt &stmt) { + Stmt out = StmtMutator::operator()(stmt); + if (!needs_ragged_prefix_) { + return out; + } + EnsureRaggedPrefixBuffer(); + Array zero_indices = {0}; + Stmt init = BufferStore(ragged_prefix_buf_, IntImm(DataType::Int(32), 0), + zero_indices); + Stmt seq = SeqStmt({init, out}); + seq = DeclBuffer(ragged_prefix_buf_, seq); + return Allocate(ragged_prefix_buf_->data, ragged_prefix_buf_->dtype, + ragged_prefix_buf_->shape, const_true(), seq); + } + private: + void EnsureRaggedPrefixBuffer() { + if (ragged_prefix_buf_.defined()) { + return; + } + Array shape = {IntImm(DataType::Int(32), 1)}; + ragged_prefix_buf_ = + decl_buffer(shape, DataType::Int(32), "tl_ws_ragged_prefix", "local"); + } + + PrimExpr LoadRaggedPrefix() { + EnsureRaggedPrefixBuffer(); + Array zero_indices = {0}; + return BufferLoad(ragged_prefix_buf_, zero_indices); + } + template < typename NodeType> /** * @brief Filter a statement by its producer/consumer @@ -888,16 +942,40 @@ class WSCodeEmitter : public StmtMutator { num_stages_ = num_stages; pipeline_info_ = pipeline_info; - PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min; - for (size_t i = 1; i < loop_stack_.size(); ++i) { - linear_index = linear_index * loop_stack_[i].extent + - (loop_stack_[i].loop_var - loop_stack_[i].min); - } - stage_ = FloorMod(linear_index, num_stages); - parity_ = FloorMod( - parity_before * op->extent + FloorDiv(linear_index, num_stages), 2); + + // Default (rectangular) linearization assumes loop extents are invariant. + // For pipelined loops with a dynamic (ragged) extent, this can produce an + // incorrect stage/parity mapping across outer iterations. We instead use a + // runtime prefix counter to linearize the iteration space. 
+ bool is_pipelined_loop = static_cast(num_stages_anno); + bool is_dynamic_extent = is_pipelined_loop && !op->extent.as(); + + PrimExpr linear_index; + if (is_dynamic_extent) { + needs_ragged_prefix_ = true; + PrimExpr base = LoadRaggedPrefix(); + linear_index = base + (op->loop_var - op->min); + stage_ = FloorMod(linear_index, num_stages); + parity_ = FloorMod(FloorDiv(linear_index, num_stages), 2); + } else { + linear_index = loop_stack_[0].loop_var - loop_stack_[0].min; + for (size_t i = 1; i < loop_stack_.size(); ++i) { + linear_index = linear_index * loop_stack_[i].extent + + (loop_stack_[i].loop_var - loop_stack_[i].min); + } + stage_ = FloorMod(linear_index, num_stages); + parity_ = FloorMod( + parity_before * op->extent + FloorDiv(linear_index, num_stages), 2); + } auto result = FilterByRole(op); + if (is_dynamic_extent) { + Array zero_indices = {0}; + PrimExpr new_prefix = LoadRaggedPrefix() + op->extent; + Stmt update = BufferStore(ragged_prefix_buf_, new_prefix, zero_indices); + result = SeqStmt({result, update}); + } + Stmt grouped_for_node; if (result.as() && group_anno && !group_info_array.empty() && !is_emitting_producer_) { @@ -1152,6 +1230,36 @@ class WSCodeEmitter : public StmtMutator { PipelineInfo pipeline_info_; friend class WarpSpecializedRewriter; bool has_simt_copy_ = false; + Buffer ragged_prefix_buf_; + bool needs_ragged_prefix_ = false; +}; + +class UserBarrierInitExtractor : public StmtMutator { +public: + std::vector init_stmts; + + Stmt VisitStmt_(const IfThenElseNode *op) final { + if (IsOnlyInit(op->then_case)) { + init_stmts.push_back(GetRef(op)); + return Evaluate(0); + } + return StmtMutator::VisitStmt_(op); + } + + bool IsOnlyInit(const Stmt &stmt) { + if (const auto *eval = stmt.as()) { + if (const auto *call = eval->value.as()) { + return call->op.same_as(builtin::ptx_init_barrier_thread_count()); + } + } else if (const auto *seq = stmt.as()) { + for (const auto &s : seq->seq) { + if (!IsOnlyInit(s)) + return false; + } + 
return true; + } + return false; + } }; class WarpSpecializedRewriter : public StmtExprMutator { @@ -1234,10 +1342,18 @@ class WarpSpecializedRewriter : public StmtExprMutator { return block_realize; } + UserBarrierInitExtractor extractor; + Stmt body_without_inits = extractor(block->body); + + // Re-run marker on the new body to ensure all nodes are mapped + WarpSpecializedRoleMarker marker_new(buffer_data_to_buffer_); + marker_new.Prepare(body_without_inits); + marker_new(body_without_inits); + if (disable_warp_specialized_) { WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_, - marker, true); - auto code = mbarrier_emitter(block->body); + marker_new, true); + auto code = mbarrier_emitter.Emit(body_without_inits); int num_barriers = mbarrier_emitter.num_barriers_; Array barrier_num_threads; barrier_num_threads.reserve(num_barriers); @@ -1247,15 +1363,22 @@ class WarpSpecializedRewriter : public StmtExprMutator { } Stmt init_barrier = Evaluate(Call( DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads)); - block.CopyOnWrite()->body = SeqStmt({init_barrier, code}); + std::vector all_inits; + all_inits.push_back(init_barrier); + all_inits.insert(all_inits.end(), extractor.init_stmts.begin(), + extractor.init_stmts.end()); + // Avoid constructing SeqStmt with a single element (disallowed in TVM). 
+ Stmt init_seq = SeqStmt::Flatten(all_inits); + block.CopyOnWrite()->body = SeqStmt({init_seq, code}); block_realize.CopyOnWrite()->block = block; return block_realize; } - WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker); - WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker, - false); - Stmt producer_code = producer(block->body); - Stmt consumer_code = consumer(block->body); + WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, + marker_new); + WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, + marker_new, false); + Stmt producer_code = producer.Emit(body_without_inits); + Stmt consumer_code = consumer.Emit(body_without_inits); PrimExpr consumer_thread_extent = thread_iv_->dom->extent; PrimExpr producer_thread_extent = thread_iv_->dom->extent; // Need one warp-group for bulk-copy only case @@ -1288,6 +1411,10 @@ class WarpSpecializedRewriter : public StmtExprMutator { Stmt init_barrier = Evaluate(Call( DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads)); + std::vector all_inits; + all_inits.push_back(init_barrier); + all_inits.insert(all_inits.end(), extractor.init_stmts.begin(), + extractor.init_stmts.end()); Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent), producer_code, consumer_code); // Add an attr here to handle the partial thread count in ThreadSync pass. @@ -1295,7 +1422,9 @@ class WarpSpecializedRewriter : public StmtExprMutator { Downcast(consumer_thread_extent)}; body = AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, body); - block.CopyOnWrite()->body = SeqStmt({init_barrier, body}); + // Avoid SeqStmt of length 1 by flattening any single-init sequence. 
+ Stmt init_seq = SeqStmt::Flatten(all_inits); + block.CopyOnWrite()->body = SeqStmt({init_seq, body}); block_realize.CopyOnWrite()->block = block; return block_realize; } diff --git a/src/transform/warp_specialized_rewriter.h b/src/transform/warp_specialized_rewriter.h index 01a2474a8..cba6d6c26 100644 --- a/src/transform/warp_specialized_rewriter.h +++ b/src/transform/warp_specialized_rewriter.h @@ -30,7 +30,7 @@ using arith::IRVisitorWithAnalyzer; class WarpSpecializedDetector : public IRVisitorWithAnalyzer { public: - // return true means this aws will be disabled + // return true means auto warp specialization will be disabled static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) { WarpSpecializedDetector detector; detector.VisitStmt(stmt); @@ -39,7 +39,11 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer { "specialization is manually enabled"; return true; } - if (detector.has_tma_op_ && detector.has_mbarrier_op_) { + // When mbarrier ops coexist with TMA loads but tma_store_cluster is also + // present, the barriers are for SM-to-SM cluster copy synchronisation and + // should not block auto warp specialisation. 
+ if (detector.has_tma_op_ && detector.has_mbarrier_op_ && + !detector.has_cluster_copy_) { LOG(WARNING) << "Auto warp specialization will be disabled because TMA " "and mbarrier are both present"; return true; @@ -51,6 +55,7 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer { has_tma_op_ = false; has_mbarrier_op_ = false; has_warp_specialization_ = false; + has_cluster_copy_ = false; } private: @@ -68,9 +73,13 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer { void VisitExpr_(const CallNode *op) final { if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || + op->op.same_as(tma_load_multicast()) || op->op.same_as(set_max_nreg())) { has_tma_op_ = true; } + if (op->op.same_as(tma_store_cluster())) { + has_cluster_copy_ = true; + } IRVisitorWithAnalyzer::VisitExpr_(op); } @@ -93,6 +102,7 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer { IterVar thread_var_; bool has_mbarrier_op_{false}; bool has_warp_specialization_{false}; + bool has_cluster_copy_{false}; }; } // namespace tl diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py new file mode 100644 index 000000000..1fbff65c5 --- /dev/null +++ b/testing/python/cuda/test_tma_dsmem.py @@ -0,0 +1,92 @@ +""" +Demo / regression test for SM-to-SM bulk async copy via tl::tma_store_cluster. + +T.copy with dst_block + barrier now lowers to a single +tl::tma_store_cluster call instead of a SIMT element-by-element loop. + +Expected generated producer code (block 0): + if (((int)threadIdx.x) == 0) { + tl::tma_store_cluster(&s_dst[0], &s_src[0], 1, + (uint32_t)(512), s_barrier[0]); + } + +Block 1 waits on its own s_barrier and then reads the result. 
+""" + +import torch +import tilelang +import tilelang.language as T +import numpy as np + + +@tilelang.jit(verbose=True, execution_backend="cython") +def make_store_cluster_kernel(N: int): + @T.prim_func + def kernel( + A: T.Tensor((N,), "float32"), + B: T.Tensor((N,), "float32"), + ): + # 2 CTAs in a cluster of size 2 + with T.Kernel(2, threads=128, cluster_size=2) as pid: + s_src = T.alloc_shared((N,), "float32") + s_dst = T.alloc_shared((N,), "float32") + s_barrier = T.alloc_shared((1,), "uint64") + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + + # Every CTA initialises its own barrier: expect 1 arrival + # carrying N*4 bytes (the cp.async.bulk signals on completion). + if T.get_thread_binding() == 0: + T.mbarrier_init(s_barrier[0], 1) + + T.cluster_sync() + + if pid == 0: + # Load A into s_src + for i in T.Parallel(N): + s_src[i] = A[i] + + # Bulk-async copy s_src (local) → s_dst (remote, block 1) + # using tl::tma_store_cluster, signalling block 1's barrier. + T.copy(s_src, s_dst, dst_block=1, remote_barrier=s_barrier[0]) + + if pid == 1: + # Wait until block 0 finishes writing to our s_dst. 
+ T.mbarrier_wait_parity(s_barrier[0], 0) + + # Store result to global memory + for i in T.Parallel(N): + B[i] = s_dst[i] + + return kernel + + +def main(): + major, minor = torch.cuda.get_device_capability() + if major < 9: + print(f"Skipping: requires Compute Capability 9.0+, found {major}.{minor}") + return + + N = 128 + A = torch.arange(N, dtype=torch.float32, device="cuda") + B = torch.zeros(N, dtype=torch.float32, device="cuda") + + kernel = make_store_cluster_kernel(N) + kernel(A, B) + + result = B.cpu().numpy() + expected = A.cpu().numpy() + + print("Result (first 8):", result[:8]) + print("Expected(first 8):", expected[:8]) + + if np.allclose(result, expected): + print("PASS: tma_store_cluster copy successful") + else: + diff = np.abs(result - expected).max() + print(f"FAIL: max diff = {diff}") + + +if __name__ == "__main__": + main() diff --git a/testing/python/cuda/test_tma_multicast_demo.py b/testing/python/cuda/test_tma_multicast_demo.py new file mode 100644 index 000000000..61aea88bb --- /dev/null +++ b/testing/python/cuda/test_tma_multicast_demo.py @@ -0,0 +1,103 @@ +""" +TMA multicast validation demo. + +Verification logic: +- cluster_size=4, cluster_mask=0b0011 (bits 0 and 1 set, i.e. CTA ranks 0 and 1 are in the mask) +- CTA rank 0: issues tma_load_multicast, broadcasting its A tile to both rank 0 and rank 1 +- CTA rank 1: does not issue a load; passively receives the multicast data (same tile as rank 0) +- CTA ranks 2, 3: not in the mask, each performs a regular tma_load for its own tile + +Therefore within the same cluster: +- B at rank 0's region = A at rank 0's region +- B at rank 1's region = A at rank 0's region (multicast result, identical to rank 0) +- B at ranks 2, 3 regions = A at ranks 2, 3 respective regions + +The test verifies multicast by checking that rank 1's B region equals rank 0's A region. 
+""" + +import torch +import tilelang +import tilelang.language as T + + +def make_tma_multicast_demo_kernel(M, N, block_M, block_N, cluster_mask): + """ + Build the TMA multicast demo kernel. + + cluster_mask: multicast bitmask. A set bit means the corresponding CTA + participates in multicast (receives the tile from the min-rank CTA). + e.g. 0b0011 means ranks 0 and 1 are in the mask; rank 0 issues + the multicast, rank 1 passively receives. + """ + + @T.prim_func + def kernel( + A: T.Tensor((M, N), "float16"), + B: T.Tensor((M, N), "float16"), + ): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=128, + cluster_size=4, + ) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), "float16") + T.copy(A[by * block_M, bx * block_N], A_shared, cluster_mask=cluster_mask) + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return kernel + + +def test_tma_multicast_demo(): + """Verify TMA multicast: rank 1's B region should equal rank 0's A region within the same cluster.""" + M, N = 1024, 1024 + block_M, block_N = 128, 64 + # mask=0b0011: rank 0 multicasts, rank 1 receives, ranks 2/3 each do regular tma_load + cluster_mask = 0b0011 + + kernel = make_tma_multicast_demo_kernel(M, N, block_M, block_N, cluster_mask) + mod = tilelang.compile( + kernel, + out_idx=[1], + verbose=True, + execution_backend="cython", + ) + + print("--- TMA Multicast Demo Kernel Source ---") + print(mod.get_kernel_source()) + + A = torch.randn(M, N, device="cuda", dtype=torch.float16) + B = mod(A) + + # Within a cluster: the first 4 blocks in the grid are (0,0),(1,0),(2,0),(3,0) -> by=0, bx=0,1,2,3 + # rank 0 -> bx=0: A[0:block_M, 0:block_N] -> B[0:block_M, 0:block_N] + # rank 1 -> bx=1: multicast receives A[0:block_M, 0:block_N] -> B[0:block_M, block_N:2*block_N] + # rank 2 -> bx=2: A[0:block_M, 2*block_N:3*block_N] -> B[0:block_M, 2*block_N:3*block_N] + # rank 3 -> bx=3: A[0:block_M, 3*block_N:4*block_N] -> B[0:block_M, 3*block_N:4*block_N] + + # Multicast 
check: rank 1's B region should equal rank 0's A region + B_rank1 = B[0:block_M, block_N : 2 * block_N] + A_rank0 = A[0:block_M, 0:block_N] + torch.testing.assert_close(B_rank1, A_rank0, rtol=1e-2, atol=1e-2) + print("PASS: Multicast verified (B[rank1_region] == A[rank0_region])") + + # rank 0 itself: B should equal A + torch.testing.assert_close(B[0:block_M, 0:block_N], A[0:block_M, 0:block_N], rtol=1e-2, atol=1e-2) + # ranks 2, 3: each B region equals its own A region + torch.testing.assert_close( + B[0:block_M, 2 * block_N : 3 * block_N], + A[0:block_M, 2 * block_N : 3 * block_N], + rtol=1e-2, + atol=1e-2, + ) + torch.testing.assert_close( + B[0:block_M, 3 * block_N : 4 * block_N], + A[0:block_M, 3 * block_N : 4 * block_N], + rtol=1e-2, + atol=1e-2, + ) + print("PASS: TMA multicast demo passed") + + +if __name__ == "__main__": + test_tma_multicast_demo() diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index fe6214fc7..98f382685 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -274,7 +274,12 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # MergeSharedMemoryAllocations must be applied after SplitHostDevice # because the merged allocation site is at the beginning of each device function enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) - mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) + shared_align_bytes = 128 if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target) else 16 + + mod = tilelang.transform.MergeSharedMemoryAllocations( + enable_aggressive_merge=enable_aggressive_merge, + align_bytes=shared_align_bytes, + )(mod) if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target): mod = tilelang.transform.InjectFenceProxy()(mod) else: diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index dccb3b264..2269ea649 100644 --- a/tilelang/language/builtin.py +++ 
b/tilelang/language/builtin.py @@ -29,6 +29,13 @@ def _normalize_index_arg(value: int | PrimExpr | None) -> PrimExpr | None: raise TypeError(f"Expect warp sizing argument to be int or PrimExpr, but got {type(value)}.") +def _get_mbarrier(barrier_id: int | PrimExpr): + """Create an intermediate mbarrier handle from barrier id for internal lowering only.""" + raise NotImplementedError( + "Direct mbarrier handle creation from id is not supported in the frontend. Use T.alloc_barrier to create mbarriers instead." + ) + + def _mbar_to_buffer_load(mbar: BarrierType) -> BufferLoad: """Convert a memory barrier to a buffer load. @@ -405,6 +412,25 @@ def mbarrier_wait_parity(mbarrier: BarrierType, parity: int | Var): return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity) +def mbarrier_init(mbarrier: int | PrimExpr | tir.Call, arrive_count: int | PrimExpr): + """Initialize a memory barrier. + + Args: + mbarrier: The memory barrier to initialize + arrive_count: The expected arrival count + """ + if isinstance(mbarrier, (tir.Call, tir.BufferLoad)): + mbarrier = mbarrier + elif isinstance(mbarrier, (tir.PrimExpr, int)): + mbarrier = _get_mbarrier(mbarrier) + elif isinstance(mbarrier, tir.Buffer): + mbarrier = tir.BufferLoad(mbarrier, [0]) + else: + raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}") + + return tir.call_intrin("handle", tir.op.Op.get("tir.ptx_init_barrier_thread_count"), mbarrier, arrive_count) + + def mbarrier_arrive(mbarrier: BarrierType): """Arrive at memory barrier. @@ -578,6 +604,37 @@ def get_warp_group_idx( return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args) +def get_cluster_id() -> PrimExpr: + """Return the cluster id (rank) of the current block within the cluster. + + This lowers to the intrinsic `tl.get_cluster_id` and is emitted for CUDA + as `cooperative_groups::this_grid().cluster_rank()`. 
+ """ + return tir.call_intrin("int32", tir.op.Op.get("tl.get_cluster_id")) + + +def get_cluster_block_rank() -> PrimExpr: + """Return the block rank within the current cluster. + + Lowers to `tl.get_cluster_block_rank` and emits + `cooperative_groups::this_cluster().block_rank()` in CUDA codegen. + """ + return tir.call_intrin("int32", tir.op.Op.get("tl.get_cluster_block_rank")) + + +def cluster_block_nums() -> PrimExpr: + """Return the number of blocks in the cluster. + + Lowers to `tl.get_cluster_block_nums` and emits + `cooperative_groups::this_cluster().num_blocks()` in CUDA codegen. + """ + return tir.call_intrin("int32", tir.op.Op.get("tl.get_cluster_block_nums")) + + +def cluster_sync(): + return tir.call_intrin("handle", tir.op.Op.get("tl.cluster_sync")) + + def shuffle_elect(thread_extent: int) -> PrimExpr: """Elect exactly one lane within a logical thread group. diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py index fc69b3bbc..eef5c0fe7 100644 --- a/tilelang/language/copy_op.py +++ b/tilelang/language/copy_op.py @@ -20,6 +20,9 @@ def copy( eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, annotations: dict | None = None, loop_layout: Any | None = None, + dst_block: int | tir.PrimExpr | None = None, + cluster_mask: int | None = None, + remote_barrier: tir.BufferLoad | None = None, ) -> tir.PrimExpr | tir.Stmt: """Copy data between memory regions. @@ -36,6 +39,20 @@ def copy( (only valid for normal SIMT copy; incompatible with TMA/LDSM/STSM/TMem). When provided, it is attached to the outermost parallel loop generated by this copy. + dst_block (Optional[Union[int, tir.PrimExpr]], optional): Destination block index for cluster copy. Defaults to None. + cluster_mask (Optional[int], keyword-only): Bitmask specifying which CTAs in the cluster + receive the TMA multicast broadcast. 
When set, the CTA whose rank equals the lowest + set bit in the mask issues ``tma_load_multicast`` (sending data to *all* masked CTAs + simultaneously); every other CTA falls back to a regular ``tma_load`` for its own + shared memory. A value of ``None`` (default) disables multicast and always uses a + regular TMA load. + remote_barrier (Optional[tir.BufferLoad], keyword-only): Shared-memory mbarrier element used for + SM-to-SM cluster copy (``dst_block`` must also be set). When provided, the copy is + performed with a single bulk-async ``tl::tma_store_cluster`` call instead of the + default element-by-element SIMT loop. The barrier must be at the **same** shared + memory offset in every CTA of the cluster (which is automatically true when all CTAs + run the same kernel and declare the same shared-memory layout). The destination CTA + should wait on its local copy of this barrier after the copy completes. Raises: TypeError: If copy extents cannot be deduced from arguments @@ -98,6 +115,15 @@ def copy( if loop_layout is not None and "parallel_loop_layout" not in ann: ann["parallel_loop_layout"] = loop_layout + if dst_block is not None: + ann["dst_block"] = dst_block + + if "cluster_mask" not in ann and cluster_mask is not None: + ann["cluster_mask"] = cluster_mask + + if "barrier" not in ann and remote_barrier is not None: + ann["barrier"] = remote_barrier + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, annotations=ann if ann else None) From 01c73b3d7ae04c315eb27ec40f03a3042d9b1a49 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Sat, 7 Mar 2026 04:22:43 +0800 Subject: [PATCH 02/51] [Docs] Add programming guide for Cluster TMA features Add docs/programming_guides/cluster_tma.md covering the two new T.copy extensions introduced in the t_copy_extend feature branch: - TMA multicast (cluster_mask): explains how a single TMA transaction broadcasts one global tile to multiple CTAs in a cluster simultaneously, with API usage, a per-rank 
behaviour table, and a complete code example. - SM-to-SM cluster copy (dst_block / remote_barrier): documents the fast path (cp.async.bulk.shared::cluster + mbarrier, single-thread async DMA) and the slow path (map_shared_rank element-by-element SIMT fallback), including the synchronisation contract for source and destination CTAs. Also covers cluster helper builtins (T.get_cluster_block_rank, T.cluster_sync, T.mbarrier_init/arrive/wait_parity, etc.) and a split-K sketch combining both features end-to-end. --- docs/programming_guides/cluster_tma.md | 300 +++++++++++++++++++++++++ 1 file changed, 300 insertions(+) create mode 100644 docs/programming_guides/cluster_tma.md diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md new file mode 100644 index 000000000..ae9040483 --- /dev/null +++ b/docs/programming_guides/cluster_tma.md @@ -0,0 +1,300 @@ +# Cluster TMA: Multicast and SM-to-SM Copy + +Authors: +- Jingkai He +- Guangda Sun <2012661711@qq.com> + + +This page describes two advanced data-movement features that are available on +NVIDIA Hopper (SM90) and later: **TMA multicast** and **SM-to-SM cluster +copy**. Both features are exposed through extensions to the existing `T.copy` +operator and require a kernel launched with `cluster_size > 1`. + +Requirements: +- CUDA Compute Capability ≥ 9.0 (Hopper / Blackwell / RTX 5090) + +--- + +## Background: Thread Block Clusters + +A *thread block cluster* is a group of CTAs that share a common virtual address +space for their shared-memory regions and can communicate without going through +global memory. Within a cluster, each CTA has a *block rank* (0-indexed +position inside the cluster), and all CTAs can observe each other's shared +memory via the `shared::cluster` address space. 
+ +```python +with T.Kernel(grid_x, grid_y, threads=128, cluster_size=4) as (bx, by): + rank = T.get_cluster_block_rank() # 0..3 within this cluster + cid = T.get_cluster_id() # which cluster am I in + nctas = T.get_cluster_block_nums() # always equals cluster_size (4) + T.cluster_sync() # barrier across all CTAs in cluster +``` + +--- + +## Feature 1 — TMA Multicast (`cluster_mask`) + +### What it does + +Normally each CTA issues its own TMA load, fetching a tile from global memory +into its private shared memory. With multicast, **a single TMA transaction +broadcasts one global tile to every participating CTA simultaneously**, saving +repeated DRAM traffic when multiple CTAs in a cluster need the same data (e.g., +the same K-panel in a split-K GEMM). + +``` +Global memory ──TMA multicast──▶ shared memory (rank 0) + └─▶ shared memory (rank 1) (same tile, no extra DRAM read) + TMA load ──▶ shared memory (rank 2) (independent tile) + TMA load ──▶ shared memory (rank 3) (independent tile) +``` + +### API + +```python +T.copy(src_global, dst_shared, cluster_mask=) +``` + +`cluster_mask` is a bitmask where each set bit identifies a CTA rank that +participates in the multicast. The CTA whose rank equals the lowest set bit +in the mask issues `cp.async.bulk.tensor … multicast::cluster`; every other +CTA in the mask receives the data passively (no instruction issued). CTAs +outside the mask perform a regular TMA load for their own tile. + +### Example + +```python +import tilelang +import tilelang.language as T + +def make_tma_multicast_kernel(M, N, block_M, block_N, cluster_mask): + @T.prim_func + def kernel( + A: T.Tensor((M, N), "float16"), + B: T.Tensor((M, N), "float16"), + ): + # 4 CTAs per cluster; ranks 0 and 1 share the same tile via multicast. 
+ with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=128, + cluster_size=4, + ) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), "float16") + + # cluster_mask=0b0011: ranks 0 and 1 participate. + # Rank 0 issues tma_load_multicast; rank 1 receives passively. + # Ranks 2 and 3 each issue a regular tma_load. + T.copy(A[by * block_M, bx * block_N], A_shared, + cluster_mask=cluster_mask) + + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return kernel +``` + +Running the kernel above with `cluster_mask = 0b0011`: + +| Rank | Action | `B` slice receives | +|------|--------|--------------------| +| 0 | issues multicast load | A tile at rank-0 address | +| 1 | passively receives | **same** A tile as rank 0 | +| 2 | regular TMA load | A tile at rank-2 address | +| 3 | regular TMA load | A tile at rank-3 address | + +### Notes + +- The compiler lowers `cluster_mask != 0` to + `cp.async.bulk.tensor.Nd.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster` + for the issuing CTA and a standard `cp.async.bulk.tensor` for the rest. +- Software-pipelining (`T.Pipelined`) is fully supported; the warp-specialized + rewriter recognises `tma_load_multicast` as a producer operation. +- `cluster_mask` is a compile-time constant; dynamic masks are not supported. + +--- + +## Feature 2 — SM-to-SM Cluster Copy (`dst_block`) + +### What it does + +SM-to-SM copy lets one CTA **push data directly from its own shared memory +into another CTA's shared memory** within the same cluster, without a round +trip through global memory. 
This is useful for patterns such as: + +- Partial result exchange (e.g., split-K partial sums across SM boundaries) +- Producer–consumer pipelines where the producer fills a neighbor's buffer +- All-to-all collective communication within a cluster + +Two sub-variants are provided depending on whether an mbarrier is supplied: + +| Variant | Parameter | Hardware instruction | Threads used | +|---------|-----------|---------------------|--------------| +| **Fast path** | `dst_block` + `remote_barrier` | `cp.async.bulk.shared::cluster` | 1 (async DMA) | +| **Slow path** | `dst_block` only | `map_shared_rank` + scalar stores | all (SIMT loop) | + +### Fast path — bulk async copy with mbarrier + +```python +T.copy(src_shared, dst_shared, dst_block=, remote_barrier=) +``` + +A single elected thread issues one `cp.async.bulk.shared::cluster` instruction. +The hardware DMA engine transfers the entire tile asynchronously and signals +the destination CTA's mbarrier on completion. The destination CTA waits with +`T.mbarrier_wait_parity`. + +Steps: +1. Both CTAs allocate the **same** shared memory layout so their mbarriers live + at the same offset. +2. Every CTA initialises its own barrier for 1 arrival. +3. The source CTA (`pid == 0` below) calls `T.copy(... dst_block=1, remote_barrier=...)`. +4. The destination CTA (`pid == 1`) waits on its local barrier copy. + +```python +import tilelang +import tilelang.language as T + +@tilelang.jit(verbose=True, execution_backend="cython") +def make_cluster_copy_kernel(N: int): + @T.prim_func + def kernel( + A: T.Tensor((N,), "float32"), + B: T.Tensor((N,), "float32"), + ): + with T.Kernel(2, threads=128, cluster_size=2) as pid: + s_src = T.alloc_shared((N,), "float32") + s_dst = T.alloc_shared((N,), "float32") + s_barrier = T.alloc_shared((1,), "uint64") + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + + # Each CTA initialises its own barrier: 1 expected arrival. 
+ if T.get_thread_binding() == 0: + T.mbarrier_init(s_barrier[0], 1) + + T.cluster_sync() + + if pid == 0: + # Load A into local shared memory. + for i in T.Parallel(N): + s_src[i] = A[i] + + # Async-push s_src → s_dst in CTA 1, signal CTA 1's barrier. + T.copy(s_src, s_dst, dst_block=1, + remote_barrier=s_barrier[0]) + + if pid == 1: + # Wait until CTA 0 finishes writing. + T.mbarrier_wait_parity(s_barrier[0], 0) + + for i in T.Parallel(N): + B[i] = s_dst[i] + + return kernel +``` + +Generated producer code (single-thread guard, one PTX instruction): + +```cuda +if (((int)threadIdx.x) == 0) { + tl::tma_store_cluster(&s_dst[0], &s_src[0], 1, + (uint32_t)(N * 4), s_barrier[0]); +} +``` + +### Slow path — element-by-element SIMT fallback + +Omit `remote_barrier` to use the slow path: + +```python +T.copy(s_src, s_dst, dst_block=1) +``` + +This lowers to a SIMT parallel loop where every thread writes one (or a few) +elements into the remote CTA's shared memory via +`cooperative_groups::map_shared_rank`. Because `map_shared_rank` returns a +scalar pointer, vectorised writes are not possible. Use this path only when an +mbarrier is unavailable or when the tile is too small to justify barrier +overhead. + +### Synchronisation contract + +| | Fast path | Slow path | +|-|-----------|-----------| +| Source CTA | no wait needed; copy is async | effectively sync after the loop | +| Destination CTA | `T.mbarrier_wait_parity(barrier, parity)` | external `T.cluster_sync()` or equivalent | + +### Notes + +- Both paths require `src` and `dst` to be in `shared` or `shared.dyn` scope. +- The mbarrier must be allocated with `T.alloc_shared((count,), "uint64")` and + initialised with `T.mbarrier_init` before use. +- `T.cluster_sync()` after allocation but before the copy is required to ensure + all CTAs have reached the barrier-init barrier before any data is pushed. +- `dst_block` may be a compile-time integer or a runtime `tir.PrimExpr`. 
+ +--- + +## Cluster Helper Builtins + +| Builtin | Return | Description | +|---------|--------|-------------| +| `T.get_cluster_id()` | `int32` | Index of this cluster in the grid | +| `T.get_cluster_block_rank()` | `int32` | Block rank (0-indexed) within the cluster | +| `T.get_cluster_block_nums()` | `int32` | Total number of CTAs in the cluster | +| `T.cluster_sync()` | — | Barrier synchronisation across all cluster CTAs | +| `T.mbarrier_init(bar, count)` | — | Initialise an mbarrier for `count` arrivals | +| `T.mbarrier_arrive(bar)` | — | Signal one arrival on an mbarrier | +| `T.mbarrier_wait_parity(bar, parity)` | — | Wait until `bar` flips to `parity` | + +--- + +## Putting It Together: Split-K Sketch + +A common pattern combining both features: multicast the shared K-panel to +all cluster CTAs (saving DRAM bandwidth), then reduce partial sums with +SM-to-SM copy (saving global-memory round trips). + +```python +@T.prim_func +def split_k_gemm(A, B, C): + with T.Kernel(grid_x, grid_y, threads=256, cluster_size=4) as (bx, by): + rank = T.get_cluster_block_rank() + A_s = T.alloc_shared((BM, BK), "float16") + B_s = T.alloc_shared((BK, BN), "float16") + C_f = T.alloc_fragment((BM, BN), "float32") + C_s = T.alloc_shared((BM, BN), "float32") + barrier = T.alloc_shared((1,), "uint64") + T.clear(C_f) + + # Phase 1: each CTA loads its K-slice; A is multicast to rank 0 and 1. + for ko in T.Pipelined(T.ceildiv(K, BK * 4), num_stages=3): + k_off = (rank + ko * 4) * BK + T.copy(A[by * BM, k_off], A_s, cluster_mask=0b0011) + T.copy(B[k_off, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) + + # Phase 2: push partial sums to rank 0 via SM-to-SM copy. + T.copy(C_f, C_s) + if T.get_thread_binding() == 0: + T.mbarrier_init(barrier[0], 1) + T.cluster_sync() + + if rank != 0: + T.copy(C_s, C_s, dst_block=0, remote_barrier=barrier[0]) + if rank == 0: + T.mbarrier_wait_parity(barrier[0], 0) + # accumulate and store ... 
+ T.copy(C_s, C[by * BM, bx * BN]) +``` + +--- + +## See Also + +- `testing/python/cuda/test_tma_multicast_demo.py` — multicast validation +- `testing/python/cuda/test_tma_dsmem.py` — SM-to-SM copy validation +- Programming Guides → Instructions — complete `T.copy` parameter reference +- Programming Guides → Control Flow — `T.Pipelined` and warp-specialized pipelines From e8b3bd0c8b849c1c56aa7650ce32514e5d2f0632 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Sat, 7 Mar 2026 05:53:25 +0800 Subject: [PATCH 03/51] change cluster_size -> cluster_dims --- testing/python/cuda/test_tma_dsmem.py | 2 +- testing/python/cuda/test_tma_multicast_demo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index 1fbff65c5..e038fb3eb 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -27,7 +27,7 @@ def kernel( B: T.Tensor((N,), "float32"), ): # 2 CTAs in a cluster of size 2 - with T.Kernel(2, threads=128, cluster_size=2) as pid: + with T.Kernel(2, threads=128, cluster_dims=(2,1,1)) as pid: s_src = T.alloc_shared((N,), "float32") s_dst = T.alloc_shared((N,), "float32") s_barrier = T.alloc_shared((1,), "uint64") diff --git a/testing/python/cuda/test_tma_multicast_demo.py b/testing/python/cuda/test_tma_multicast_demo.py index 61aea88bb..2952201ec 100644 --- a/testing/python/cuda/test_tma_multicast_demo.py +++ b/testing/python/cuda/test_tma_multicast_demo.py @@ -39,7 +39,7 @@ def kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128, - cluster_size=4, + cluster_dims=(4,1,1), ) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), "float16") T.copy(A[by * block_M, bx * block_N], A_shared, cluster_mask=cluster_mask) From deab40214dd25d01b0f54ea1d5ceb74c33ffad93 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Sat, 7 Mar 2026 11:14:13 +0800 Subject: [PATCH 04/51] fix merge conflict --- src/op/builtin.cc | 5 ----- src/op/builtin.h | 
6 ------ src/target/codegen_cuda.cc | 4 ---- 3 files changed, 15 deletions(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index c91a7c228..0eb238181 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -576,10 +576,5 @@ TIR_DEFINE_TL_BUILTIN(tma_store_cluster) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(cluster_sync) - .set_num_inputs(0) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - } // namespace tl } // namespace tvm diff --git a/src/op/builtin.h b/src/op/builtin.h index 8b54f1f8a..7f8f8f93d 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -1006,12 +1006,6 @@ TVM_DLL const Op &ptx_cluster_store(); */ TVM_DLL const Op &tma_store_cluster(); -/*! - * \brief tilelang intrinsic for cluster sync. - * - * This op is used to represent a cluster sync operation in tilelang. - */ -TVM_DLL const Op &cluster_sync(); } // namespace tl } // namespace tvm diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 9e251e5a8..579175921 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2271,10 +2271,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); this->stream << "}\n"; - } else if (op->op.same_as(tl::cluster_sync())) { - this->need_cooperative_groups_ = true; - this->PrintIndent(); - this->stream << "cooperative_groups::this_cluster().sync();\n"; } else if (op->op.same_as(tl::ptx_mma_sm70())) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col From 2ef8eba830955246258dd3cac5133cc4114c9d21 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Sat, 7 Mar 2026 12:06:46 +0800 Subject: [PATCH 05/51] fix TILELANG CHECK bug --- src/tl_templates/cuda/cluster.h | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/tl_templates/cuda/cluster.h b/src/tl_templates/cuda/cluster.h index b61d2f70e..9d45b4e50 100644 --- a/src/tl_templates/cuda/cluster.h +++ 
b/src/tl_templates/cuda/cluster.h @@ -1,6 +1,7 @@ #pragma once #include "common.h" +#include // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ @@ -15,7 +16,7 @@ TL_DEVICE void cluster_arrive_relaxed() { #if defined(CLUSTER_ENABLED) asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : :); #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); #endif } @@ -23,7 +24,7 @@ TL_DEVICE void cluster_arrive() { #if defined(CLUSTER_ENABLED) asm volatile("barrier.cluster.arrive.aligned;\n" : :); #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); #endif } @@ -31,7 +32,7 @@ TL_DEVICE void cluster_wait() { #if defined(CLUSTER_ENABLED) asm volatile("barrier.cluster.wait.aligned;\n" : :); #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); #endif } @@ -40,7 +41,7 @@ TL_DEVICE void cluster_sync() { cluster_arrive(); cluster_wait(); #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); #endif } @@ -53,7 +54,8 @@ TL_DEVICE dim3 cluster_grid_dims() { asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) :); return {x, y, z}; #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); + return {0, 0, 0}; #endif } @@ -66,7 +68,8 @@ TL_DEVICE dim3 cluster_id_in_grid() { asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) :); return {x, y, z}; #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); + return {0, 0, 0}; #endif } @@ -79,7 +82,8 @@ TL_DEVICE dim3 cluster_shape() { asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) :); return {x, y, z}; #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); + return {0, 0, 
0}; #endif } @@ -92,7 +96,8 @@ TL_DEVICE dim3 block_id_in_cluster() { asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) :); return {x, y, z}; #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); + return {0, 0, 0}; #endif } @@ -103,7 +108,8 @@ TL_DEVICE uint32_t block_rank_in_cluster() { asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :); return rank; #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); + return 0; #endif } @@ -116,7 +122,8 @@ TL_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) { : "r"(smemAddr), "r"(rank)); return result; #else - TILELANG_CHECK(false, "CLUSTER_ENABLED is not defined"); + assert(false && "CLUSTER_ENABLED is not defined"); + return 0; #endif } From 29a7971e45caf818b55c5321195629620cc1ef81 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Sat, 7 Mar 2026 12:26:20 +0800 Subject: [PATCH 06/51] unify cluster function --- docs/programming_guides/cluster_tma.md | 11 +++-------- src/op/builtin.cc | 5 ----- src/op/builtin.h | 7 ------- src/op/copy.cc | 4 ++-- src/target/codegen_cuda.cc | 5 ----- src/tl_templates/cuda/cluster.h | 1 + src/transform/inject_tma_barrier.cc | 2 +- tilelang/language/builtin.py | 8 -------- 8 files changed, 7 insertions(+), 36 deletions(-) diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md index ae9040483..cf7e4fc81 100644 --- a/docs/programming_guides/cluster_tma.md +++ b/docs/programming_guides/cluster_tma.md @@ -1,10 +1,5 @@ # Cluster TMA: Multicast and SM-to-SM Copy -Authors: -- Jingkai He -- Guangda Sun <2012661711@qq.com> - - This page describes two advanced data-movement features that are available on NVIDIA Hopper (SM90) and later: **TMA multicast** and **SM-to-SM cluster copy**. 
Both features are exposed through extensions to the existing `T.copy` @@ -25,7 +20,7 @@ memory via the `shared::cluster` address space. ```python with T.Kernel(grid_x, grid_y, threads=128, cluster_size=4) as (bx, by): - rank = T.get_cluster_block_rank() # 0..3 within this cluster + rank = T.block_rank_in_cluster() # 0..3 within this cluster cid = T.get_cluster_id() # which cluster am I in nctas = T.get_cluster_block_nums() # always equals cluster_size (4) T.cluster_sync() # barrier across all CTAs in cluster @@ -242,7 +237,7 @@ overhead. | Builtin | Return | Description | |---------|--------|-------------| | `T.get_cluster_id()` | `int32` | Index of this cluster in the grid | -| `T.get_cluster_block_rank()` | `int32` | Block rank (0-indexed) within the cluster | +| `T.block_rank_in_cluster()` | `int32` | Block rank (0-indexed) within the cluster | | `T.get_cluster_block_nums()` | `int32` | Total number of CTAs in the cluster | | `T.cluster_sync()` | — | Barrier synchronisation across all cluster CTAs | | `T.mbarrier_init(bar, count)` | — | Initialise an mbarrier for `count` arrivals | @@ -261,7 +256,7 @@ SM-to-SM copy (saving global-memory round trips). 
@T.prim_func def split_k_gemm(A, B, C): with T.Kernel(grid_x, grid_y, threads=256, cluster_size=4) as (bx, by): - rank = T.get_cluster_block_rank() + rank = T.block_rank_in_cluster() A_s = T.alloc_shared((BM, BK), "float16") B_s = T.alloc_shared((BK, BN), "float16") C_f = T.alloc_fragment((BM, BN), "float32") diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 0eb238181..775c660c9 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -303,11 +303,6 @@ TIR_DEFINE_TL_BUILTIN(get_cluster_id) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_TL_BUILTIN(get_cluster_block_rank) - .set_num_inputs(0) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kPure)); - TIR_DEFINE_TL_BUILTIN(get_cluster_block_nums) .set_num_inputs(0) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 7f8f8f93d..600707909 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -517,13 +517,6 @@ TVM_DLL const Op &warpgroup_fence_operand(); */ TVM_DLL const Op &get_cluster_id(); -/*! - * \brief Return the block rank within the current cluster. - * - * get_cluster_block_rank() - * - */ -TVM_DLL const Op &get_cluster_block_rank(); /*! * \brief Return the number of blocks in the cluster. 
diff --git a/src/op/copy.cc b/src/op/copy.cc index d18c37c80..9c471cff3 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1924,7 +1924,7 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, int min_cta_rank = static_cast( __builtin_ctzll(static_cast(cluster_mask))); PrimExpr block_rank = - Call(DataType::Int(32), get_cluster_block_rank(), {}); + Call(DataType::Int(32), block_rank_in_cluster(), {}); PrimExpr mask_imm = IntImm(DataType::Int(32), cluster_mask); // (mask >> block_rank) & 1 == 0 ⟺ block_rank is NOT set in the mask PrimExpr not_in_mask = EQ(bitwise_and(right_shift(mask_imm, block_rank), @@ -1959,7 +1959,7 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, int min_cta_rank = static_cast( __builtin_ctzll(static_cast(cluster_mask))); PrimExpr block_rank = - Call(DataType::Int(32), get_cluster_block_rank(), {}); + Call(DataType::Int(32), block_rank_in_cluster(), {}); PrimExpr mask_imm = IntImm(DataType::Int(32), cluster_mask); PrimExpr not_in_mask = EQ(bitwise_and(right_shift(mask_imm, block_rank), IntImm(DataType::Int(32), 1)), diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 579175921..6a968838b 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -3142,11 +3142,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ICHECK_EQ(op->args.size(), 0) << "tl.get_cluster_id expects no arguments."; this->need_cooperative_groups_ = true; os << "cooperative_groups::this_grid().cluster_rank()"; - } else if (op->op.same_as(tl::get_cluster_block_rank())) { - ICHECK_EQ(op->args.size(), 0) - << "tl.get_cluster_block_rank expects no arguments."; - this->need_cooperative_groups_ = true; - os << "cooperative_groups::this_cluster().block_rank()"; } else if (op->op.same_as(tl::get_cluster_block_nums())) { ICHECK_EQ(op->args.size(), 0) << "tl.get_cluster_block_nums expects no arguments."; diff --git a/src/tl_templates/cuda/cluster.h 
b/src/tl_templates/cuda/cluster.h index 9d45b4e50..bdc56d6ca 100644 --- a/src/tl_templates/cuda/cluster.h +++ b/src/tl_templates/cuda/cluster.h @@ -16,6 +16,7 @@ TL_DEVICE void cluster_arrive_relaxed() { #if defined(CLUSTER_ENABLED) asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : :); #else +// TILELANG_CHECK is defined as a CUDA error-checking macro that takes a single cudaError_t argument, so we use assert instead assert(false && "CLUSTER_ENABLED is not defined"); #endif } diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index fd4142c8b..c789ae4a9 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -163,7 +163,7 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer { // Only trigger for EQ conditions that involve the thread variable // (directly or offset like threadIdx.x - 128 == 0 after warp-spec // remapping). Runtime function-call conditions such as - // get_cluster_block_rank() == min_rank must NOT trigger byte injection: + // block_rank_in_cluster() == min_rank must NOT trigger byte injection: // they don't involve any thread variable. bool is_thread_eq = false; if (auto eq = op->condition.as()) { diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 3510737e7..a05ec808e 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -632,14 +632,6 @@ def get_cluster_id() -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.get_cluster_id")) -def get_cluster_block_rank() -> PrimExpr: - """Return the block rank within the current cluster. - - Lowers to `tl.get_cluster_block_rank` and emits - `cooperative_groups::this_cluster().block_rank()` in CUDA codegen. - """ - return tir.call_intrin("int32", tir.op.Op.get("tl.get_cluster_block_rank")) - def cluster_block_nums() -> PrimExpr: """Return the number of blocks in the cluster. 
From 5d57c11486b94399ee24191b2a35103c6dcee0b3 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Sat, 7 Mar 2026 17:24:02 +0800 Subject: [PATCH 07/51] fix pre-commit errors --- src/op/builtin.h | 1 - src/tl_templates/cuda/cluster.h | 3 ++- testing/python/cuda/test_tma_dsmem.py | 2 +- testing/python/cuda/test_tma_multicast_demo.py | 2 +- tilelang/language/builtin.py | 1 - 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/op/builtin.h b/src/op/builtin.h index 600707909..7151936a6 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -517,7 +517,6 @@ TVM_DLL const Op &warpgroup_fence_operand(); */ TVM_DLL const Op &get_cluster_id(); - /*! * \brief Return the number of blocks in the cluster. * diff --git a/src/tl_templates/cuda/cluster.h b/src/tl_templates/cuda/cluster.h index bdc56d6ca..f27e8c492 100644 --- a/src/tl_templates/cuda/cluster.h +++ b/src/tl_templates/cuda/cluster.h @@ -16,7 +16,8 @@ TL_DEVICE void cluster_arrive_relaxed() { #if defined(CLUSTER_ENABLED) asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : :); #else -// TILELANG_CHECK is defined as a CUDA error-checking macro that takes a single cudaError_t argument, so we use assert instead + // TILELANG_CHECK is defined as a CUDA error-checking macro that takes a + // single cudaError_t argument, so we use assert instead assert(false && "CLUSTER_ENABLED is not defined"); #endif } diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index e038fb3eb..27f31c2e7 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -27,7 +27,7 @@ def kernel( B: T.Tensor((N,), "float32"), ): # 2 CTAs in a cluster of size 2 - with T.Kernel(2, threads=128, cluster_dims=(2,1,1)) as pid: + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: s_src = T.alloc_shared((N,), "float32") s_dst = T.alloc_shared((N,), "float32") s_barrier = T.alloc_shared((1,), "uint64") diff --git a/testing/python/cuda/test_tma_multicast_demo.py 
b/testing/python/cuda/test_tma_multicast_demo.py index 2952201ec..523f28c8d 100644 --- a/testing/python/cuda/test_tma_multicast_demo.py +++ b/testing/python/cuda/test_tma_multicast_demo.py @@ -39,7 +39,7 @@ def kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128, - cluster_dims=(4,1,1), + cluster_dims=(4, 1, 1), ) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), "float16") T.copy(A[by * block_M, bx * block_N], A_shared, cluster_mask=cluster_mask) diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index a05ec808e..10b28428c 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -632,7 +632,6 @@ def get_cluster_id() -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.get_cluster_id")) - def cluster_block_nums() -> PrimExpr: """Return the number of blocks in the cluster. From c2efb8879eb0960ea1b540ba65b1ee4a6bc94fae Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 9 Mar 2026 10:52:37 +0800 Subject: [PATCH 08/51] docs(cluster_tma): clarify multicast non-issuer behavior in lowering note --- docs/programming_guides/cluster_tma.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md index cf7e4fc81..d1a242812 100644 --- a/docs/programming_guides/cluster_tma.md +++ b/docs/programming_guides/cluster_tma.md @@ -102,7 +102,9 @@ Running the kernel above with `cluster_mask = 0b0011`: - The compiler lowers `cluster_mask != 0` to `cp.async.bulk.tensor.Nd.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster` - for the issuing CTA and a standard `cp.async.bulk.tensor` for the rest. + for the issuing CTA; CTAs in the mask but not elected as issuer receive + passively, and only CTAs outside the mask issue a standard + `cp.async.bulk.tensor`. 
- Software-pipelining (`T.Pipelined`) is fully supported; the warp-specialized rewriter recognises `tma_load_multicast` as a producer operation. - `cluster_mask` is a compile-time constant; dynamic masks are not supported. From 258c948b0585ae4a9e4735def2d71ceb2a895d4a Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 9 Mar 2026 10:57:54 +0800 Subject: [PATCH 09/51] fix(copy): gate dst_block cluster lowering on target capability --- src/op/copy.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/op/copy.cc b/src/op/copy.cc index 9c471cff3..3229ba1ae 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -932,6 +932,9 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { T.layout_map, analyzer, /*buffer_oob=*/false, /*in_pipeline=*/T.in_pipeline); if (dst_block.defined()) { + ICHECK(TargetHasBulkCopy(target)) + << "T.copy with dst_block requires cluster-copy support (CUDA SM90+). " + << "Got target=" << target; return LowerClusterCopy(T, analyzer); } if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { From d2d7f1e134b8cc779a4087b203bc3b6cb5ecb587 Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 9 Mar 2026 15:57:49 +0800 Subject: [PATCH 10/51] fix(cuda/cluster): hard-trap when cluster intrinsics are unavailable instead of assert(false)+dummy return. 
--- src/tl_templates/cuda/cluster.h | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/tl_templates/cuda/cluster.h b/src/tl_templates/cuda/cluster.h index bdc56d6ca..f2503915d 100644 --- a/src/tl_templates/cuda/cluster.h +++ b/src/tl_templates/cuda/cluster.h @@ -1,7 +1,6 @@ #pragma once #include "common.h" -#include // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ @@ -12,12 +11,13 @@ namespace tl { +TL_DEVICE void cluster_unsupported_trap() { asm volatile("trap;"); } + TL_DEVICE void cluster_arrive_relaxed() { #if defined(CLUSTER_ENABLED) asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : :); #else -// TILELANG_CHECK is defined as a CUDA error-checking macro that takes a single cudaError_t argument, so we use assert instead - assert(false && "CLUSTER_ENABLED is not defined"); + cluster_unsupported_trap(); #endif } @@ -25,7 +25,7 @@ TL_DEVICE void cluster_arrive() { #if defined(CLUSTER_ENABLED) asm volatile("barrier.cluster.arrive.aligned;\n" : :); #else - assert(false && "CLUSTER_ENABLED is not defined"); + cluster_unsupported_trap(); #endif } @@ -33,7 +33,7 @@ TL_DEVICE void cluster_wait() { #if defined(CLUSTER_ENABLED) asm volatile("barrier.cluster.wait.aligned;\n" : :); #else - assert(false && "CLUSTER_ENABLED is not defined"); + cluster_unsupported_trap(); #endif } @@ -42,7 +42,7 @@ TL_DEVICE void cluster_sync() { cluster_arrive(); cluster_wait(); #else - assert(false && "CLUSTER_ENABLED is not defined"); + cluster_unsupported_trap(); #endif } @@ -55,7 +55,7 @@ TL_DEVICE dim3 cluster_grid_dims() { asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) :); return {x, y, z}; #else - assert(false && "CLUSTER_ENABLED is not defined"); + cluster_unsupported_trap(); return {0, 0, 0}; #endif } @@ -69,7 +69,7 @@ TL_DEVICE dim3 cluster_id_in_grid() { asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) :); return {x, y, z}; #else - assert(false && "CLUSTER_ENABLED is not defined"); + 
cluster_unsupported_trap(); return {0, 0, 0}; #endif } @@ -83,7 +83,7 @@ TL_DEVICE dim3 cluster_shape() { asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) :); return {x, y, z}; #else - assert(false && "CLUSTER_ENABLED is not defined"); + cluster_unsupported_trap(); return {0, 0, 0}; #endif } @@ -97,7 +97,7 @@ TL_DEVICE dim3 block_id_in_cluster() { asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) :); return {x, y, z}; #else - assert(false && "CLUSTER_ENABLED is not defined"); + cluster_unsupported_trap(); return {0, 0, 0}; #endif } @@ -109,7 +109,7 @@ TL_DEVICE uint32_t block_rank_in_cluster() { asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :); return rank; #else - assert(false && "CLUSTER_ENABLED is not defined"); + cluster_unsupported_trap(); return 0; #endif } @@ -123,7 +123,7 @@ TL_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) { : "r"(smemAddr), "r"(rank)); return result; #else - assert(false && "CLUSTER_ENABLED is not defined"); + cluster_unsupported_trap(); return 0; #endif } From 2c67281ece848980e939ad42120064ce759f40d8 Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 9 Mar 2026 19:09:39 +0800 Subject: [PATCH 11/51] fix inject_tma & lower_hopper & lower_tile bugs --- src/transform/inject_tma_barrier.cc | 34 ++++++++++++++++++++-------- src/transform/lower_hopper_intrin.cc | 4 +++- src/transform/lower_tile_op.cc | 12 ++++------ 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index c789ae4a9..c8334a64e 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -160,18 +161,33 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer { if (op->condition.as()) { flag = op->condition.as()->op.same_as(tl_shuffle_elect()); } - // Only trigger for EQ conditions that involve the thread variable - // (directly 
or offset like threadIdx.x - 128 == 0 after warp-spec - // remapping). Runtime function-call conditions such as - // block_rank_in_cluster() == min_rank must NOT trigger byte injection: - // they don't involve any thread variable. + // Only trigger for EQ conditions that select exactly one threadIdx.x + // value. The condition must be of the form (threadIdx.x - c == 0), i.e., + // linear in threadIdx.x with coefficient ±1. + // + // Predicates like (threadIdx.x % 32) == 0 or (threadIdx.x & 31) == 0 + // also contain threadIdx.x inside an EQNode but are satisfied by multiple + // threads. Injecting mbarrier_expect_tx under such guards overcounts + // expected transactions and deadlocks the corresponding mbarrier_wait. + // + // Runtime conditions such as get_cluster_block_rank() == min_rank do not + // involve the thread variable at all and must also be excluded. bool is_thread_eq = false; if (auto eq = op->condition.as()) { if (thread_var_.defined()) { - auto f_uses_thread = [this](const VarNode *v) { - return v == thread_var_->var.get(); - }; - is_thread_eq = UsesVar(op->condition, f_uses_thread); + // Compute lhs - rhs and check if it is affine in threadIdx.x with + // coefficient exactly ±1. DetectLinearEquation returns [coef, offset] + // when the expression equals coef*var + offset; an empty array means + // the expression is not linear in the variable (e.g. involves mod or + // bitwise ops). 
+ PrimExpr diff = eq->a - eq->b; + Array coefs = + arith::DetectLinearEquation(diff, {thread_var_->var}); + if (coefs.size() == 2) { + if (auto imm = coefs[0].as()) { + is_thread_eq = (imm->value == 1 || imm->value == -1); + } + } } } if (is_thread_eq || flag) { diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index af119e0bb..827c72f3c 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -206,9 +206,11 @@ class LowerHopperIntrin : public StmtExprMutator { } else if (call->op.same_as(create_list_of_mbarrier())) { // ICHECK(init_mbarrier_calls_.empty()); int num_barriers = static_cast(call->args.size()); + int barrier_base = num_managed_barriers_; num_managed_barriers_ += num_barriers; for (int i = 0; i < num_barriers; i++) { - PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {i}); + PrimExpr mbarrier = + Call(DataType::Handle(), get_mbarrier(), {barrier_base + i}); init_mbarrier_calls_.push_back(Evaluate( Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(), {mbarrier, call->args[i]}))); diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 5bc46fc02..acab1db3e 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -661,14 +661,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return call; } if (op->op.same_as(tl::tma_store_cluster())) { - // SM-to-SM bulk async copy: only suppress swizzle transformation on the - // access pointers (in_tma_context_), but do NOT set has_tma_ because - // this operation does not use the TMA hardware engine and must not - // trigger warp specialization. - in_tma_context_ = true; - auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); - in_tma_context_ = false; - return call; + // SM-to-SM bulk async copy does not use a tensor-map descriptor, so + // shared-memory swizzle must still be reflected in pointer/index + // remapping. 
+ return Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); } if (is_ptx_) { From db56e9278bf950230b3d68a3f90417d94092e10c Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 9 Mar 2026 19:41:16 +0800 Subject: [PATCH 12/51] [warp specialize]Fix mbarrier retarget for tma_load_multicast in warp-specialized producer path --- src/transform/warp_specialized_rewriter.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 0ae6f202c..19c23af74 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -365,14 +365,17 @@ class MbarrierRewriter : public StmtExprMutator { private: PrimExpr VisitExpr_(const CallNode *op) final { auto call = Downcast(StmtExprMutator::VisitExpr_(op)); - if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_load_multicast())) { auto mbar = makeGetBarrier(producer_barrier_idx_); auto arg0 = call->args[0].as(); // Check if this is a 1D TMA load auto is_1d_tma_load = arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && call->op.same_as(tma_load()); - if (is_1d_tma_load) { + if (call->op.same_as(tma_load_multicast())) { + call.CopyOnWrite()->args.Set(1, mbar); + } else if (is_1d_tma_load) { call.CopyOnWrite()->args.Set(2, mbar); } else { Call access_ptr = Downcast(call->args[2]); From 8af6bd0a0dc4c146959d53604d25dd7dbadf3c4b Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 9 Mar 2026 20:53:47 +0800 Subject: [PATCH 13/51] fix(tma-barrier): correct arrive thread count under equality-guarded threadIdx --- src/transform/inject_tma_barrier.cc | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index c8334a64e..b2d91b0a1 100644 --- 
a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -400,7 +400,23 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { return 1; } auto bound = analyzer_.const_int_bound(thread_var_); - int64_t extent = bound->max_value - bound->min_value + 1; + int64_t min_val = bound->min_value; + int64_t max_val = bound->max_value; + + // TVM's const_int_bound does not tighten the upper bound under equality + // constraints (e.g., "tx - 128 == 0" leaves the range as [128, 255]). + // Use CanProve to detect when the thread variable is provably pinned to + // a single value, which indicates a single-thread arrive guard. + if (min_val != arith::ConstIntBound::kNegInf && min_val < max_val) { + PrimExpr thread_expr = + thread_var_; // IterVar -> Var via operator PrimExpr() + if (analyzer_.CanProve(thread_expr <= + IntImm(thread_expr.dtype(), min_val))) { + return 1; + } + } + + int64_t extent = max_val - min_val + 1; return static_cast(std::max(extent, 1)); } From f2cb665442d5616ac752051f7768b09ef229da78 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 21:53:12 +0800 Subject: [PATCH 14/51] fix: overlapped TMA barrier IDs --- src/transform/lower_hopper_intrin.cc | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index 827c72f3c..d1aac0879 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -19,6 +19,11 @@ namespace tl { using namespace tir; #if (CUDA_MAJOR_VERSION >= 12) +// Barrier IDs 0–2 are reserved: 0 by InjectTmaBarrier (descriptor-based TMA +// loads) and 1–2 by the backend for internal synchronization (e.g. AllReduce). +// User-managed barriers must start after this reserved range. 
+static constexpr int kReservedBarriers = 3; + class LowerHopperIntrin : public StmtExprMutator { public: static PrimFunc Substitute(PrimFunc &f, bool disable_shuffle_elect) { @@ -138,9 +143,11 @@ class LowerHopperIntrin : public StmtExprMutator { } else { Array stmt_seq; if (num_managed_barriers_ > 0) { - auto alloc_mbarrier = - Evaluate(Call(DataType::Handle(), builtin::create_barriers(), - {num_managed_barriers_})); + // Size must cover reserved slots [0, kReservedBarriers) plus all + // user-managed slots so that IDs never alias. + auto alloc_mbarrier = Evaluate( + Call(DataType::Handle(), builtin::create_barriers(), + {num_managed_barriers_ + kReservedBarriers})); stmt_seq.push_back(alloc_mbarrier); } @@ -206,7 +213,8 @@ class LowerHopperIntrin : public StmtExprMutator { } else if (call->op.same_as(create_list_of_mbarrier())) { // ICHECK(init_mbarrier_calls_.empty()); int num_barriers = static_cast(call->args.size()); - int barrier_base = num_managed_barriers_; + // Offset by kReservedBarriers so user IDs begin after the reserved range. + int barrier_base = num_managed_barriers_ + kReservedBarriers; num_managed_barriers_ += num_barriers; for (int i = 0; i < num_barriers; i++) { PrimExpr mbarrier = From febfd78264383b40d80e3a2f6ef3742009821fef Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:11:18 +0800 Subject: [PATCH 15/51] fix: multicast checks --- src/op/copy.cc | 27 +++++++++++++++++++++++++++ src/transform/lower_hopper_intrin.cc | 6 +++--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 3229ba1ae..0dfb869de 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -931,6 +931,21 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto copy_inst = GetCopyInst(target, disable_tma_lower || GetDisableTMA(), T.layout_map, analyzer, /*buffer_oob=*/false, /*in_pipeline=*/T.in_pipeline); + // cluster_mask is only honored in the descriptor-based TMA path (kBulkLoad). 
+ // Any other instruction type silently drops the multicast semantics and + // causes masked CTAs to fall back to per-CTA loads, potentially changing + // results when different ranks require different source coordinates. + { + int64_t cluster_mask = GetClusterMask(); + ICHECK(cluster_mask == 0 || copy_inst == CopyInst::kBulkLoad) + << "cluster_mask=0x" << std::hex << cluster_mask + << " requires descriptor-based TMA (kBulkLoad), but this copy was " + "routed to copy_inst=" + << static_cast(copy_inst) + << ". Ensure the copy meets TMA bulk-load constraints. src=" + << src->name << " (scope=" << src.scope() << "), dst=" << dst->name + << " (scope=" << dst.scope() << ")."; + } if (dst_block.defined()) { ICHECK(TargetHasBulkCopy(target)) << "T.copy with dst_block requires cluster-copy support (CUDA SM90+). " @@ -1996,6 +2011,18 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, ICHECK(copy_inst == CopyInst::kBulkLoad1D || copy_inst == CopyInst::kBulkStore1D); + // 1D TMA uses cp.async.bulk which has no multicast variant; a descriptor- + // based bulk load (kBulkLoad) is required instead. + { + int64_t cluster_mask = GetClusterMask(); + ICHECK(cluster_mask == 0) + << "cluster_mask=0x" << std::hex << cluster_mask + << " requires descriptor-based TMA (kBulkLoad); the 1D bulk-copy path " + "(kBulkLoad1D) does not support multicast. 
src=" + << src->name << " (scope=" << src.scope() << "), dst=" << dst->name + << " (scope=" << dst.scope() << ")."; + } + // Add 1D TMA copy when the global and shared memory is contiguous // Check if shared_tensor->name is present in T.buffer_var_gemm // (Array) to avoid use 1D TMA copy for swizzled layout diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index d1aac0879..92f67acbc 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -145,9 +145,9 @@ class LowerHopperIntrin : public StmtExprMutator { if (num_managed_barriers_ > 0) { // Size must cover reserved slots [0, kReservedBarriers) plus all // user-managed slots so that IDs never alias. - auto alloc_mbarrier = Evaluate( - Call(DataType::Handle(), builtin::create_barriers(), - {num_managed_barriers_ + kReservedBarriers})); + auto alloc_mbarrier = + Evaluate(Call(DataType::Handle(), builtin::create_barriers(), + {num_managed_barriers_ + kReservedBarriers})); stmt_seq.push_back(alloc_mbarrier); } From bf38cfa7a3e6842bc5768eabfed48fca5bbc5dab Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:12:03 +0800 Subject: [PATCH 16/51] minor fix --- tilelang/language/copy_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py index 64cc59aa6..e27da8f3b 100644 --- a/tilelang/language/copy_op.py +++ b/tilelang/language/copy_op.py @@ -134,7 +134,7 @@ def copy( if loop_layout is not None and "parallel_loop_layout" not in ann: ann["parallel_loop_layout"] = loop_layout - if dst_block is not None: + if "dst_block" not in ann and dst_block is not None: ann["dst_block"] = dst_block if "cluster_mask" not in ann and cluster_mask is not None: From 4008456bf83d0d99b1616a87621581eaec0fc776 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:19:48 +0800 Subject: [PATCH 17/51] minor fix --- docs/programming_guides/cluster_tma.md | 14 +++++++------- 
tilelang/language/copy_op.py | 16 ++++++++++++---- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md index d1a242812..10ff99c28 100644 --- a/docs/programming_guides/cluster_tma.md +++ b/docs/programming_guides/cluster_tma.md @@ -3,7 +3,7 @@ This page describes two advanced data-movement features that are available on NVIDIA Hopper (SM90) and later: **TMA multicast** and **SM-to-SM cluster copy**. Both features are exposed through extensions to the existing `T.copy` -operator and require a kernel launched with `cluster_size > 1`. +operator and require a kernel launched with thread block cluster, i.e., with `cluster_dims != (1, 1, 1)`. Requirements: - CUDA Compute Capability ≥ 9.0 (Hopper / Blackwell / RTX 5090) @@ -19,10 +19,10 @@ position inside the cluster), and all CTAs can observe each other's shared memory via the `shared::cluster` address space. ```python -with T.Kernel(grid_x, grid_y, threads=128, cluster_size=4) as (bx, by): +with T.Kernel(grid_x, grid_y, threads=128, cluster_dims=(4, 1, 1)) as (bx, by): rank = T.block_rank_in_cluster() # 0..3 within this cluster cid = T.get_cluster_id() # which cluster am I in - nctas = T.get_cluster_block_nums() # always equals cluster_size (4) + nctas = T.get_cluster_block_nums() T.cluster_sync() # barrier across all CTAs in cluster ``` @@ -38,7 +38,7 @@ broadcasts one global tile to every participating CTA simultaneously**, saving repeated DRAM traffic when multiple CTAs in a cluster need the same data (e.g., the same K-panel in a split-K GEMM). 
-``` +```text Global memory ──TMA multicast──▶ shared memory (rank 0) └─▶ shared memory (rank 1) (same tile, no extra DRAM read) TMA load ──▶ shared memory (rank 2) (independent tile) @@ -74,7 +74,7 @@ def make_tma_multicast_kernel(M, N, block_M, block_N, cluster_mask): T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128, - cluster_size=4, + cluster_dims=(4, 1, 1) ) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), "float16") @@ -159,7 +159,7 @@ def make_cluster_copy_kernel(N: int): A: T.Tensor((N,), "float32"), B: T.Tensor((N,), "float32"), ): - with T.Kernel(2, threads=128, cluster_size=2) as pid: + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: s_src = T.alloc_shared((N,), "float32") s_dst = T.alloc_shared((N,), "float32") s_barrier = T.alloc_shared((1,), "uint64") @@ -257,7 +257,7 @@ SM-to-SM copy (saving global-memory round trips). ```python @T.prim_func def split_k_gemm(A, B, C): - with T.Kernel(grid_x, grid_y, threads=256, cluster_size=4) as (bx, by): + with T.Kernel(grid_x, grid_y, threads=256, cluster_dims=(4, 1, 1)) as (bx, by): rank = T.block_rank_in_cluster() A_s = T.alloc_shared((BM, BK), "float16") B_s = T.alloc_shared((BK, BN), "float16") diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py index e27da8f3b..1a13b997d 100644 --- a/tilelang/language/copy_op.py +++ b/tilelang/language/copy_op.py @@ -78,10 +78,18 @@ def copy( dst_block (Optional[Union[int, tir.PrimExpr]], optional): Destination block index for cluster copy. Defaults to None. cluster_mask (Optional[int], keyword-only): Bitmask specifying which CTAs in the cluster - receive the TMA multicast broadcast. When set, the CTA whose rank equals the lowest - set bit in the mask issues ``tma_load_multicast`` (sending data to *all* masked CTAs - simultaneously); every other CTA falls back to a regular ``tma_load`` for its own - shared memory. 
A value of ``None`` (default) disables multicast and always uses a + participate in a TMA multicast broadcast. The hardware delivers the data to every + masked CTA's shared memory in a single transfer. At runtime the kernel splits into + three cases based on each CTA's rank within the cluster: + + * **Leader** (rank == lowest set bit in mask): issues ``tma_load_multicast``, which + fills the shared memory of *all* masked CTAs simultaneously. + * **Masked peer** (rank is set in mask, but not the lowest): does *nothing* — its + shared memory is written passively by the leader's multicast. + * **Unmasked CTA** (rank is not set in mask): issues a regular ``tma_load`` for its + own shared memory independently. + + A value of ``None`` (default) disables multicast and every CTA issues its own regular TMA load. remote_barrier (Optional[tir.BufferLoad], keyword-only): Shared-memory mbarrier element used for SM-to-SM cluster copy (``dst_block`` must also be set). When provided, the copy is From cb5305fc0445ed786f1de19b330fa0d5e7bdcdf1 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:27:27 +0800 Subject: [PATCH 18/51] fix: tma_load_multicast() in the stage-local mbarrier rewrite --- src/transform/warp_specialized_rewriter.cc | 11 ++++++++--- testing/python/cuda/test_tma_dsmem.py | 22 ++++++++++------------ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 0ae6f202c..748a1656b 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -365,14 +365,19 @@ class MbarrierRewriter : public StmtExprMutator { private: PrimExpr VisitExpr_(const CallNode *op) final { auto call = Downcast(StmtExprMutator::VisitExpr_(op)); - if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_load_multicast())) 
{ auto mbar = makeGetBarrier(producer_barrier_idx_); auto arg0 = call->args[0].as(); - // Check if this is a 1D TMA load + // Check if this is a 1D TMA load (raw address, no descriptor, no multicast) auto is_1d_tma_load = arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && call->op.same_as(tma_load()); - if (is_1d_tma_load) { + if (call->op.same_as(tma_load_multicast())) { + // tma_load_multicast layout: [desc, mbar, smem_ptr, mask, coords..., eviction] + call.CopyOnWrite()->args.Set(1, mbar); + } else if (is_1d_tma_load) { + // 1D bulk copy layout: [smem_ptr, global_ptr, mbar, bytes, eviction] call.CopyOnWrite()->args.Set(2, mbar); } else { Call access_ptr = Downcast(call->args[2]); diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index 27f31c2e7..7ad404a7a 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -13,6 +13,7 @@ Block 1 waits on its own s_barrier and then reads the result. """ +import pytest import torch import tilelang import tilelang.language as T @@ -62,11 +63,12 @@ def kernel( return kernel -def main(): +def test_tma_store_cluster(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required") major, minor = torch.cuda.get_device_capability() if major < 9: - print(f"Skipping: requires Compute Capability 9.0+, found {major}.{minor}") - return + pytest.skip(f"requires Compute Capability 9.0+, found {major}.{minor}") N = 128 A = torch.arange(N, dtype=torch.float32, device="cuda") @@ -78,15 +80,11 @@ def main(): result = B.cpu().numpy() expected = A.cpu().numpy() - print("Result (first 8):", result[:8]) - print("Expected(first 8):", expected[:8]) - - if np.allclose(result, expected): - print("PASS: tma_store_cluster copy successful") - else: - diff = np.abs(result - expected).max() - print(f"FAIL: max diff = {diff}") + diff = np.abs(result - expected).max() + assert np.allclose(result, expected), ( + f"tma_store_cluster copy failed: max diff = 
{diff}" + ) if __name__ == "__main__": - main() + test_tma_store_cluster() From 310ccbbc089a12a748ffbe6225e2bed5f539ec87 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:29:58 +0800 Subject: [PATCH 19/51] minor fix --- testing/python/cuda/test_tma_multicast_demo.py | 6 ++++++ tilelang/engine/phase.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/testing/python/cuda/test_tma_multicast_demo.py b/testing/python/cuda/test_tma_multicast_demo.py index 523f28c8d..3e5cb3725 100644 --- a/testing/python/cuda/test_tma_multicast_demo.py +++ b/testing/python/cuda/test_tma_multicast_demo.py @@ -15,6 +15,7 @@ The test verifies multicast by checking that rank 1's B region equals rank 0's A region. """ +import pytest import torch import tilelang import tilelang.language as T @@ -50,6 +51,11 @@ def kernel( def test_tma_multicast_demo(): """Verify TMA multicast: rank 1's B region should equal rank 0's A region within the same cluster.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required") + major, minor = torch.cuda.get_device_capability() + if major < 9: + pytest.skip(f"requires Compute Capability 9.0+, found {major}.{minor}") M, N = 1024, 1024 block_M, block_N = 128, 64 # mask=0b0011: rank 0 multicasts, rank 1 receives, ranks 2/3 each do regular tma_load diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index b24349f8c..3c5136a05 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -292,7 +292,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # MergeSharedMemoryAllocations must be applied after SplitHostDevice # because the merged allocation site is at the beginning of each device function enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) - shared_align_bytes = 128 if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target) else 16 + shared_align_bytes = 128 if allow_tma_lower(pass_ctx=pass_ctx, target=target) else 16 mod 
= tilelang.transform.MergeSharedMemoryAllocations( enable_aggressive_merge=enable_aggressive_merge, From c50f426f4f5002bf43a944ca3275f21f9b9b0c5d Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:30:27 +0800 Subject: [PATCH 20/51] format --- src/transform/warp_specialized_rewriter.cc | 6 ++++-- testing/python/cuda/test_tma_dsmem.py | 4 +--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 748a1656b..53df46da2 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -369,12 +369,14 @@ class MbarrierRewriter : public StmtExprMutator { call->op.same_as(tma_load_multicast())) { auto mbar = makeGetBarrier(producer_barrier_idx_); auto arg0 = call->args[0].as(); - // Check if this is a 1D TMA load (raw address, no descriptor, no multicast) + // Check if this is a 1D TMA load (raw address, no descriptor, no + // multicast) auto is_1d_tma_load = arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && call->op.same_as(tma_load()); if (call->op.same_as(tma_load_multicast())) { - // tma_load_multicast layout: [desc, mbar, smem_ptr, mask, coords..., eviction] + // tma_load_multicast layout: [desc, mbar, smem_ptr, mask, coords..., + // eviction] call.CopyOnWrite()->args.Set(1, mbar); } else if (is_1d_tma_load) { // 1D bulk copy layout: [smem_ptr, global_ptr, mbar, bytes, eviction] diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index 7ad404a7a..ea7c81ed6 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -81,9 +81,7 @@ def test_tma_store_cluster(): expected = A.cpu().numpy() diff = np.abs(result - expected).max() - assert np.allclose(result, expected), ( - f"tma_store_cluster copy failed: max diff = {diff}" - ) + assert np.allclose(result, expected), f"tma_store_cluster copy failed: max diff = {diff}" if __name__ 
== "__main__": From add3de98919ce54feac5b2361f57871b1213f7fd Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:34:39 +0800 Subject: [PATCH 21/51] minor fix --- docs/programming_guides/cluster_tma.md | 31 ++++++++++++++++++++------ 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md index 10ff99c28..315fa4e0c 100644 --- a/docs/programming_guides/cluster_tma.md +++ b/docs/programming_guides/cluster_tma.md @@ -273,18 +273,35 @@ def split_k_gemm(A, B, C): T.copy(B[k_off, bx * BN], B_s) T.gemm(A_s, B_s, C_f) - # Phase 2: push partial sums to rank 0 via SM-to-SM copy. - T.copy(C_f, C_s) - if T.get_thread_binding() == 0: - T.mbarrier_init(barrier[0], 1) + # Phase 2: push each rank's partial sums to rank 0 for accumulation. + # + # Use a per-rank staging slot so every non-zero rank writes to a + # distinct destination region — avoiding both a destination race and + # an arrival-count mismatch. Each CTA stores its own partial into + # C_parts[rank]; non-zero ranks then push that slot to the matching + # slot in rank 0's shared memory. + # + # Arrival count must equal the number of producers: cluster_size - 1. + C_parts = T.alloc_shared((4, BM, BN), "float32") # one slot per rank + T.copy(C_f, C_parts[rank]) + + # Only rank 0 needs its barrier initialised (it is the sole consumer). + # Arrival count = 3: ranks 1, 2, and 3 each signal exactly once. + if T.get_thread_binding() == 0 and rank == 0: + T.mbarrier_init(barrier[0], 3) T.cluster_sync() if rank != 0: - T.copy(C_s, C_s, dst_block=0, remote_barrier=barrier[0]) + # Push this rank's slot to the *same* slot index in rank 0's + # C_parts — different offsets, so no destination race. 
+ T.copy(C_parts[rank], C_parts[rank], + dst_block=0, remote_barrier=barrier[0]) + if rank == 0: - T.mbarrier_wait_parity(barrier[0], 0) + T.mbarrier_wait_parity(barrier[0], 0) # wakes after all 3 arrivals + # C_parts[0..3] in rank 0's smem now hold all four partial sums. # accumulate and store ... - T.copy(C_s, C[by * BM, bx * BN]) + T.copy(C_parts[0], C[by * BM, bx * BN]) ``` --- From 531c51f16e7236dda49fcd6d111e0df8c298c242 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:48:05 +0800 Subject: [PATCH 22/51] fix: Track allocation from every hoisted barrier init. --- src/transform/lower_hopper_intrin.cc | 30 ++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index 92f67acbc..ae64ff95e 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -142,12 +142,14 @@ class LowerHopperIntrin : public StmtExprMutator { return AttrStmt(op->node, op->attr_key, op->value, body); } else { Array stmt_seq; - if (num_managed_barriers_ > 0) { - // Size must cover reserved slots [0, kReservedBarriers) plus all - // user-managed slots so that IDs never alias. + if (num_required_barriers_ > 0) { + // num_required_barriers_ already accounts for kReservedBarriers + // and covers the highest ID referenced by any barrier init, + // whether from create_list_of_mbarrier() or a direct + // ptx_init_barrier_thread_count() call. 
auto alloc_mbarrier = Evaluate(Call(DataType::Handle(), builtin::create_barriers(), - {num_managed_barriers_ + kReservedBarriers})); + {num_required_barriers_})); stmt_seq.push_back(alloc_mbarrier); } @@ -186,6 +188,7 @@ class LowerHopperIntrin : public StmtExprMutator { prefetch_calls_.clear(); init_mbarrier_calls_.clear(); num_managed_barriers_ = 0; + num_required_barriers_ = 0; return AttrStmt(op->node, op->attr_key, op->value, result); } } @@ -216,6 +219,9 @@ class LowerHopperIntrin : public StmtExprMutator { // Offset by kReservedBarriers so user IDs begin after the reserved range. int barrier_base = num_managed_barriers_ + kReservedBarriers; num_managed_barriers_ += num_barriers; + // Track the total slots needed: highest assigned ID + 1. + num_required_barriers_ = + std::max(num_required_barriers_, barrier_base + num_barriers); for (int i = 0; i < num_barriers; i++) { PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {barrier_base + i}); @@ -225,6 +231,16 @@ class LowerHopperIntrin : public StmtExprMutator { } return 0; } else if (call->op.same_as(builtin::ptx_init_barrier_thread_count())) { + // args[0] is get_mbarrier(id); extract id to size the allocation. + if (const auto *mbar_call = call->args[0].as()) { + if (mbar_call->op.same_as(get_mbarrier())) { + if (const auto *id = mbar_call->args[0].as()) { + // Slots needed = id + 1 (IDs are 0-based). + num_required_barriers_ = + std::max(num_required_barriers_, (int)(id->value + 1)); + } + } + } init_mbarrier_calls_.push_back(Evaluate(tvm::ffi::GetRef(call))); return 0; } else { @@ -259,7 +275,13 @@ class LowerHopperIntrin : public StmtExprMutator { private: Array prefetch_calls_; Array init_mbarrier_calls_; + // Tracks the next free user-barrier slot (offset within the user range). int num_managed_barriers_ = 0; + // Tracks 1 + the highest barrier ID referenced by any init call, across + // both create_list_of_mbarrier() and direct ptx_init_barrier_thread_count() + // paths. 
Used to size create_barriers() so every get_mbarrier(id) has a + // backing slot. + int num_required_barriers_ = 0; std::unordered_map desc_map_; LowerHopperIntrin(bool disable_shuffle_elect) : disable_shuffle_elect_(disable_shuffle_elect) {} From 1086f249c1b2700409ac94db3eeda04e6a9ed1fe Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:56:36 +0800 Subject: [PATCH 23/51] fix: Scope the 128-byte alignment to kernels that actually use TMA. --- tilelang/engine/phase.py | 41 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 3c5136a05..be403340f 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -1,11 +1,40 @@ from __future__ import annotations from tvm import tir, IRModule from tvm.target import Target +import tvm import tilelang from tilelang.transform import PassContext from tilelang.contrib.nvcc import have_tma, is_hopper, have_pdl +def _mod_uses_tma_barriers(mod: IRModule) -> bool: + """Return True if any PrimFunc in *mod* contains a ``create_barriers`` call. + + ``LowerHopperIntrin`` emits ``create_barriers`` into every device function + that uses TMA/mbarrier and nowhere else. Checking for that call after + lowering therefore tells us whether a given kernel actually needs the + 128-byte shared-memory alignment that TMA descriptors require, without + forcing the alignment on every kernel that merely runs on a Hopper GPU. 
+ """ + found = [False] + + def _visit(node): + if found[0]: + return + if isinstance(node, tir.Call): + op = getattr(node, "op", None) + if op is not None and getattr(op, "name", "").endswith("create_barriers"): + found[0] = True + + for func in mod.functions.values(): + if found[0]: + break + if isinstance(func, tir.PrimFunc): + tvm.tir.stmt_functor.post_order_visit(func.body, _visit) + + return found[0] + + def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: # avoid circular import from tilelang.jit.adapter.utils import is_cuda_target @@ -292,7 +321,17 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # MergeSharedMemoryAllocations must be applied after SplitHostDevice # because the merged allocation site is at the beginning of each device function enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) - shared_align_bytes = 128 if allow_tma_lower(pass_ctx=pass_ctx, target=target) else 16 + # Use 128-byte alignment only when the lowered IR actually contains TMA/ + # mbarrier usage (signalled by a create_barriers call injected by + # LowerHopperIntrin). Target-level TMA availability alone is not + # sufficient: non-TMA kernels on Hopper would otherwise pick up + # unnecessary padding, inflating SMEM usage and reducing occupancy. + shared_align_bytes = ( + 128 + if allow_tma_lower(pass_ctx=pass_ctx, target=target) + and _mod_uses_tma_barriers(mod) + else 16 + ) mod = tilelang.transform.MergeSharedMemoryAllocations( enable_aggressive_merge=enable_aggressive_merge, From 1996788ef701d55a4e1a8a4888dcf1747c4e8b5a Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 22:56:49 +0800 Subject: [PATCH 24/51] fix: Assert the tl::tma_store_cluster lowering, not just the output. 
--- testing/python/cuda/test_tma_dsmem.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index ea7c81ed6..1ebff5a31 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -20,7 +20,6 @@ import numpy as np -@tilelang.jit(verbose=True, execution_backend="cython") def make_store_cluster_kernel(N: int): @T.prim_func def kernel( @@ -71,11 +70,23 @@ def test_tma_store_cluster(): pytest.skip(f"requires Compute Capability 9.0+, found {major}.{minor}") N = 128 - A = torch.arange(N, dtype=torch.float32, device="cuda") - B = torch.zeros(N, dtype=torch.float32, device="cuda") + prim_func = make_store_cluster_kernel(N) + mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") + + # Assert that the lowering actually produced tl::tma_store_cluster. + # The SIMT fallback (map_shared_rank + scalar stores) also copies data + # correctly, so a pure numerical check would miss a regression where + # T.copy(dst_block=..., remote_barrier=...) stops emitting the bulk-async + # cluster intrinsic. + src = mod.get_kernel_source() + assert "tl::tma_store_cluster" in src, ( + "Expected tl::tma_store_cluster in generated kernel source; " + "T.copy(dst_block=..., remote_barrier=...) 
may have regressed to the " + f"SIMT fallback.\nKernel source:\n{src}" + ) - kernel = make_store_cluster_kernel(N) - kernel(A, B) + A = torch.arange(N, dtype=torch.float32, device="cuda") + B = mod(A) result = B.cpu().numpy() expected = A.cpu().numpy() From 57fd7971eb659357586a14f19159fe91ba8d39a8 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 9 Mar 2026 23:10:37 +0800 Subject: [PATCH 25/51] format fix --- tilelang/engine/phase.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index be403340f..6fbd68ddf 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -326,12 +326,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # LowerHopperIntrin). Target-level TMA availability alone is not # sufficient: non-TMA kernels on Hopper would otherwise pick up # unnecessary padding, inflating SMEM usage and reducing occupancy. - shared_align_bytes = ( - 128 - if allow_tma_lower(pass_ctx=pass_ctx, target=target) - and _mod_uses_tma_barriers(mod) - else 16 - ) + shared_align_bytes = 128 if allow_tma_lower(pass_ctx=pass_ctx, target=target) and _mod_uses_tma_barriers(mod) else 16 mod = tilelang.transform.MergeSharedMemoryAllocations( enable_aggressive_merge=enable_aggressive_merge, From a39604d8cd4ed7bbafa9292f49c782c5e28eeaef Mon Sep 17 00:00:00 2001 From: jingkai he Date: Tue, 10 Mar 2026 00:18:44 +0800 Subject: [PATCH 26/51] fix: Don't assume both if arms transfer the same number of bytes. 
--- src/transform/inject_tma_barrier.cc | 39 ++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index c8334a64e..0b2fb593e 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -101,15 +101,42 @@ class TmaTraitsCollector : public StmtExprVisitor { loop_extents = old_loop_evtents; } - // For if/else branches (mutually exclusive), count only the then branch. - // Both branches always transfer the same number of bytes (same tile size), - // so counting either one gives the correct mbarrier expect_tx byte count. + // IfThenElse branches are mutually exclusive: exactly one executes at + // runtime, so the single unconditional mbarrier_expect_tx that the rewriter + // injects must match the byte count for *whichever* branch runs. When an + // else branch is present, collect both sides independently (preserving the + // loop_extents context), verify they carry the same byte count, and advance + // the running total by that count. Asymmetric arms (e.g. a tail partition + // or a passive multicast receiver that moves a different number of bytes) + // cannot be represented by a single expect_tx and will fail the check below. void VisitStmt_(const IfThenElseNode *op) final { - if (op->else_case.defined()) { - StmtExprVisitor::VisitStmt(op->then_case); - } else { + if (!op->else_case.defined()) { + // No else arm: standard traversal visits only the then branch. StmtExprVisitor::VisitStmt_(op); + return; } + + // Save the running total accumulated by outer context. + PrimExpr base = bulk_copy_bytes; + + // Collect then branch. + bulk_copy_bytes = 0; + StmtExprVisitor::VisitStmt(op->then_case); + PrimExpr then_bytes = bulk_copy_bytes; + + // Collect else branch. 
+ bulk_copy_bytes = 0; + StmtExprVisitor::VisitStmt(op->else_case.value()); + PrimExpr else_bytes = bulk_copy_bytes; + + ICHECK(StructuralEqual()(then_bytes, else_bytes)) + << "IfThenElse branches carry different TMA byte counts: " + << "then=" << then_bytes << " else=" << else_bytes + << ". A single unconditional mbarrier_expect_tx cannot represent both " + "paths. Ensure both arms of this branch transfer the same tile or " + "restructure so expect_tx is issued per-branch."; + + bulk_copy_bytes = base + then_bytes; } PrimExpr bulk_copy_bytes = 0; From d1c7bb84a960f2929c58b0ab1d9da3308989e06f Mon Sep 17 00:00:00 2001 From: jingkai he Date: Tue, 10 Mar 2026 01:07:03 +0800 Subject: [PATCH 27/51] fix: Don't wrap the kept pipelined loop in a SeqStmt before the For post-processing runs. --- src/transform/lower_hopper_intrin.cc | 13 +++++++++++ src/transform/warp_specialized_rewriter.cc | 25 +++++++++++++--------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index ae64ff95e..109feb1d1 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -243,6 +243,19 @@ class LowerHopperIntrin : public StmtExprMutator { } init_mbarrier_calls_.push_back(Evaluate(tvm::ffi::GetRef(call))); return 0; + } else if (call->op.same_as(get_mbarrier())) { + // All get_mbarrier(i) calls that reach this branch carry a pipeline + // stage index (0, 1, …) emitted by InjectPipeline, InjectTmaBarrier, + // or MbarrierRewriter. Shift by kReservedBarriers so they address + // the correct slot in the barrier array, which starts user-managed + // barriers after the reserved prefix. + // + // get_mbarrier calls that live inside create_list_of_mbarrier or + // ptx_init_barrier_thread_count are consumed by their own early-return + // branches above and never reach this handler, so there is no risk of + // double-shifting. 
+ return Call(call->dtype, get_mbarrier(), + {call->args[0] + kReservedBarriers}); } else { return StmtExprMutator::VisitExpr_(call); } diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 53df46da2..82f4adde8 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -1208,13 +1208,6 @@ class WSCodeEmitter : public StmtMutator { } auto result = FilterByRole(op); - if (is_dynamic_extent) { - Array zero_indices = {0}; - PrimExpr new_prefix = LoadRaggedPrefix() + op->extent; - Stmt update = BufferStore(ragged_prefix_buf_, new_prefix, zero_indices); - result = SeqStmt({result, update}); - } - Stmt grouped_for_node; if (result.as() && group_anno && !group_info_array.empty() && !is_emitting_producer_) { @@ -1237,12 +1230,24 @@ class WSCodeEmitter : public StmtMutator { for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order"); for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage"); } + // Choose between the grouped and plain loop before potentially wrapping. + Stmt final_node; if (is_emitting_producer_ || !group_anno || group_info_array.empty()) { - loop_stack_.pop_back(); - return for_node; + final_node = for_node; + } else { + final_node = grouped_for_node; } loop_stack_.pop_back(); - return grouped_for_node; + // Append the ragged-prefix update after the fully post-processed loop. + // This is deferred to here (rather than applied early before annotation + // stripping) so the ForNode remains visible to GroupOpRewriter above. 
+ if (is_dynamic_extent) { + Array zero_indices = {0}; + PrimExpr new_prefix = LoadRaggedPrefix() + op->extent; + Stmt update = BufferStore(ragged_prefix_buf_, new_prefix, zero_indices); + return SeqStmt({final_node, update}); + } + return final_node; } loop_stack_.pop_back(); return result; From 4244e0764e7343643ead3cab31eb339ba6cc6552 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Tue, 10 Mar 2026 01:11:09 +0800 Subject: [PATCH 28/51] rm mbarrier_init and use alloc_cluster_barrier --- docs/programming_guides/cluster_tma.md | 17 ++++------------- testing/python/cuda/test_tma_dsmem.py | 7 +------ tilelang/language/builtin.py | 19 ------------------- 3 files changed, 5 insertions(+), 38 deletions(-) diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md index 315fa4e0c..62313d2bd 100644 --- a/docs/programming_guides/cluster_tma.md +++ b/docs/programming_guides/cluster_tma.md @@ -162,15 +162,11 @@ def make_cluster_copy_kernel(N: int): with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: s_src = T.alloc_shared((N,), "float32") s_dst = T.alloc_shared((N,), "float32") - s_barrier = T.alloc_shared((1,), "uint64") + s_barrier = T.alloc_cluster_barrier([1]) T.fill(s_src, 0.0) T.fill(s_dst, 0.0) - # Each CTA initialises its own barrier: 1 expected arrival. - if T.get_thread_binding() == 0: - T.mbarrier_init(s_barrier[0], 1) - T.cluster_sync() if pid == 0: @@ -226,8 +222,7 @@ overhead. ### Notes - Both paths require `src` and `dst` to be in `shared` or `shared.dyn` scope. -- The mbarrier must be allocated with `T.alloc_shared((count,), "uint64")` and - initialised with `T.mbarrier_init` before use. +- The mbarrier must be allocated with `T.alloc_cluster_barrier([arrive_count])`. - `T.cluster_sync()` after allocation but before the copy is required to ensure all CTAs have reached the barrier-init barrier before any data is pushed. - `dst_block` may be a compile-time integer or a runtime `tir.PrimExpr`. 
@@ -242,7 +237,7 @@ overhead. | `T.block_rank_in_cluster()` | `int32` | Block rank (0-indexed) within the cluster | | `T.get_cluster_block_nums()` | `int32` | Total number of CTAs in the cluster | | `T.cluster_sync()` | — | Barrier synchronisation across all cluster CTAs | -| `T.mbarrier_init(bar, count)` | — | Initialise an mbarrier for `count` arrivals | +| `T.alloc_cluster_barrier([count])` | `Buffer` | Allocate and initialise an mbarrier for `count` arrivals | | `T.mbarrier_arrive(bar)` | — | Signal one arrival on an mbarrier | | `T.mbarrier_wait_parity(bar, parity)` | — | Wait until `bar` flips to `parity` | @@ -263,7 +258,7 @@ def split_k_gemm(A, B, C): B_s = T.alloc_shared((BK, BN), "float16") C_f = T.alloc_fragment((BM, BN), "float32") C_s = T.alloc_shared((BM, BN), "float32") - barrier = T.alloc_shared((1,), "uint64") + barrier = T.alloc_cluster_barrier([3]) T.clear(C_f) # Phase 1: each CTA loads its K-slice; A is multicast to rank 0 and 1. @@ -285,10 +280,6 @@ def split_k_gemm(A, B, C): C_parts = T.alloc_shared((4, BM, BN), "float32") # one slot per rank T.copy(C_f, C_parts[rank]) - # Only rank 0 needs its barrier initialised (it is the sole consumer). - # Arrival count = 3: ranks 1, 2, and 3 each signal exactly once. - if T.get_thread_binding() == 0 and rank == 0: - T.mbarrier_init(barrier[0], 3) T.cluster_sync() if rank != 0: diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index 1ebff5a31..5fe5893ae 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -30,16 +30,11 @@ def kernel( with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: s_src = T.alloc_shared((N,), "float32") s_dst = T.alloc_shared((N,), "float32") - s_barrier = T.alloc_shared((1,), "uint64") + s_barrier = T.alloc_cluster_barrier([1]) T.fill(s_src, 0.0) T.fill(s_dst, 0.0) - # Every CTA initialises its own barrier: expect 1 arrival - # carrying N*4 bytes (the cp.async.bulk signals on completion). 
- if T.get_thread_binding() == 0: - T.mbarrier_init(s_barrier[0], 1) - T.cluster_sync() if pid == 0: diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 10b28428c..2e8716198 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -424,25 +424,6 @@ def mbarrier_wait_parity(mbarrier: BarrierType, parity: int | Var): return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity) -def mbarrier_init(mbarrier: int | PrimExpr | tir.Call, arrive_count: int | PrimExpr): - """Initialize a memory barrier. - - Args: - mbarrier: The memory barrier to initialize - arrive_count: The expected arrival count - """ - if isinstance(mbarrier, (tir.Call, tir.BufferLoad)): - mbarrier = mbarrier - elif isinstance(mbarrier, (tir.PrimExpr, int)): - mbarrier = _get_mbarrier(mbarrier) - elif isinstance(mbarrier, tir.Buffer): - mbarrier = tir.BufferLoad(mbarrier, [0]) - else: - raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}") - - return tir.call_intrin("handle", tir.op.Op.get("tir.ptx_init_barrier_thread_count"), mbarrier, arrive_count) - - def mbarrier_arrive(mbarrier: BarrierType, cta_id: int | Var | None = None): """Arrive at memory barrier. From e5b00c53ec38a83a5ff600343deb33ab241e98bc Mon Sep 17 00:00:00 2001 From: jingkai he Date: Tue, 10 Mar 2026 01:50:52 +0800 Subject: [PATCH 29/51] fix: Don't make cluster copy a global exemption for unrelated TMA loads. 
--- src/transform/warp_specialized_rewriter.h | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/transform/warp_specialized_rewriter.h b/src/transform/warp_specialized_rewriter.h index 44881d1e6..68aed0882 100644 --- a/src/transform/warp_specialized_rewriter.h +++ b/src/transform/warp_specialized_rewriter.h @@ -39,11 +39,12 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer { "specialization is manually enabled"; return true; } - // When mbarrier ops coexist with TMA loads but tma_store_cluster is also - // present, the barriers are for SM-to-SM cluster copy synchronisation and - // should not block auto warp specialisation. - if (detector.has_tma_op_ && detector.has_mbarrier_op_ && - !detector.has_cluster_copy_) { + // TMA loads with mbarrier ops indicate a user-managed pipeline that is + // incompatible with auto warp specialisation. The cluster-copy exemption + // only applies when mbarrier ops exist *without* any regular TMA-load + // path (i.e. barriers are solely for SM-to-SM cluster copy sync), which + // naturally falls through here because has_tma_op_ is false in that case. 
+ if (detector.has_tma_op_ && detector.has_mbarrier_op_) { LOG(WARNING) << "Auto warp specialization will be disabled because TMA " "and mbarrier are both present"; return true; @@ -55,7 +56,6 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer { has_tma_op_ = false; has_mbarrier_op_ = false; has_warp_specialization_ = false; - has_cluster_copy_ = false; } private: @@ -78,9 +78,6 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer { op->op.same_as(set_max_nreg())) { has_tma_op_ = true; } - if (op->op.same_as(tma_store_cluster())) { - has_cluster_copy_ = true; - } IRVisitorWithAnalyzer::VisitExpr_(op); } @@ -103,7 +100,6 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer { IterVar thread_var_; bool has_mbarrier_op_{false}; bool has_warp_specialization_{false}; - bool has_cluster_copy_{false}; }; } // namespace tl From 25f0aa19538f1178a359bf5e5944f3cc2651cdac Mon Sep 17 00:00:00 2001 From: jingkai he Date: Tue, 10 Mar 2026 01:51:13 +0800 Subject: [PATCH 30/51] fix: Don't drop dependency modeling for handle-based mbarrier_wait_parity(). 
--- src/tl_templates/cuda/copy_sm90.h | 6 ++++-- src/transform/pipeline_planning.cc | 21 +++++++++++---------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/tl_templates/cuda/copy_sm90.h b/src/tl_templates/cuda/copy_sm90.h index 53145933f..dd89bfa30 100644 --- a/src/tl_templates/cuda/copy_sm90.h +++ b/src/tl_templates/cuda/copy_sm90.h @@ -39,9 +39,11 @@ TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr, } // Generic SM-to-SM async bulk copy via cp.async.bulk.shared::cluster +template TL_DEVICE void tma_store_cluster(void *dst, void *src, int dst_cta, - uint32_t size_bytes, uint64_t &bar) { - uint32_t mbarrier_ptr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t size_bytes, BarrierType &bar) { + uint32_t mbarrier_ptr = static_cast( + __cvta_generic_to_shared(reinterpret_cast(&bar))); uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index e36281c35..afaa3dac0 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -369,11 +369,6 @@ class BufferRegionCollector : public StmtExprVisitor { this->VisitExpr(op->args[i]); } } else if (op->op.same_as(tl::mbarrier_wait_parity())) { - // mbarrier_wait_parity may take either a BufferLoad (preferred, allows - // linking to associated async dependencies) or a target-specific handle - // expression (e.g. tl.get_mbarrier(id)). For the latter case, we cannot - // associate the barrier with a concrete Buffer, so we conservatively - // fall back to normal traversal. 
if (const auto *load = args[0].as()) { Buffer mbar_buf = load->buffer; auto buffer_reads = @@ -385,13 +380,19 @@ class BufferRegionCollector : public StmtExprVisitor { buffer_reads->second.end()); } if (buffer_writes != chain_builder_.mbar_to_buffer_writes_.end()) { - writes_.insert( - writes_.end(), - chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).begin(), - chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).end()); + writes_.insert(writes_.end(), buffer_writes->second.begin(), + buffer_writes->second.end()); } } else { - StmtExprVisitor::VisitExpr_(op); + // Handle-based mbarrier (e.g. get_mbarrier(id)): cannot resolve to a + // concrete Buffer. Conservatively attach all known async buffer + // dependencies so the wait is not treated as dependency-free. + for (const auto &[_, regions] : chain_builder_.mbar_to_buffer_reads_) { + reads_.insert(reads_.end(), regions.begin(), regions.end()); + } + for (const auto &[_, regions] : chain_builder_.mbar_to_buffer_writes_) { + writes_.insert(writes_.end(), regions.begin(), regions.end()); + } } } else { StmtExprVisitor::VisitExpr_(op); From f0d4169b41cfa281d3407b7c7a48614b547bef1c Mon Sep 17 00:00:00 2001 From: jingkai he Date: Tue, 10 Mar 2026 01:59:17 +0800 Subject: [PATCH 31/51] fix: Don't hoist extracted barrier-init if statements to the block root. 
--- src/transform/warp_specialized_rewriter.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 82f4adde8..aad010dca 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -1491,13 +1491,18 @@ class UserBarrierInitExtractor : public StmtMutator { std::vector init_stmts; Stmt VisitStmt_(const IfThenElseNode *op) final { - if (IsOnlyInit(op->then_case)) { + if (!op->else_case.defined() && IsOnlyInit(op->then_case)) { init_stmts.push_back(GetRef(op)); return Evaluate(0); } return StmtMutator::VisitStmt_(op); } + // Don't descend into scope-creating nodes; extracting an init from inside + // an Allocate/LetStmt would hoist it out of the variable's scope. + Stmt VisitStmt_(const AllocateNode *op) final { return GetRef(op); } + Stmt VisitStmt_(const LetStmtNode *op) final { return GetRef(op); } + bool IsOnlyInit(const Stmt &stmt) { if (const auto *eval = stmt.as()) { if (const auto *call = eval->value.as()) { From 8a947bb3b690c4c07a7194c3dee38340a694ce87 Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Tue, 10 Mar 2026 09:57:31 +0800 Subject: [PATCH 32/51] fix(copy): honor remapped dst layout in cluster-copy slow path --- src/op/copy.cc | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 0dfb869de..8dcfbc51c 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1570,8 +1570,9 @@ Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, class ClusterCopyReplacer : public StmtExprMutator { public: ClusterCopyReplacer(const Buffer &dst, PrimExpr dst_block, - const Buffer &target_dst) - : dst_(dst), dst_block_(dst_block), target_dst_(target_dst) {} + const Buffer &target_dst, Optional dst_layout) + : dst_(dst), dst_block_(dst_block), target_dst_(target_dst), + dst_layout_(dst_layout) {} Stmt VisitStmt_(const 
BufferStoreNode *op) final { if (op->buffer.same_as(dst_)) { @@ -1580,15 +1581,23 @@ Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, args.push_back(op->value); // The value to store args.push_back(dst_block_); // The destination block index - // linearize the index. - PrimExpr linearized_index = op->indices[0]; - if (op->indices.size() > 1) { + // Compute the physical linear index in the target buffer. + // When dst is remapped to a layout-transformed shared buffer, we must + // forward logical indices through that layout before flattening. + Array physical_indices = op->indices; + if (!target_dst_.same_as(dst_) && dst_layout_.defined()) { + physical_indices = dst_layout_.value()->Forward(op->indices); + } + + PrimExpr linearized_index = physical_indices[0]; + if (physical_indices.size() > 1) { PrimExpr multiplier = 1; linearized_index = 0; - for (int i = op->indices.size() - 1; i >= 0; --i) { - linearized_index = linearized_index + op->indices[i] * multiplier; + for (int i = physical_indices.size() - 1; i >= 0; --i) { + linearized_index = + linearized_index + physical_indices[i] * multiplier; if (i > 0) { - multiplier = multiplier * op->buffer->shape[i]; + multiplier = multiplier * target_dst_->shape[i]; } } } @@ -1620,6 +1629,7 @@ Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, const Buffer &dst_; PrimExpr dst_block_; const Buffer &target_dst_; + Optional dst_layout_; }; Buffer target_dst = dst; @@ -1627,8 +1637,13 @@ Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, target_dst = T.buffer_remap[dst]; } - return ClusterCopyReplacer(dst, dst_block.value(), - target_dst)(vectorized_thread_loop); + Optional dst_layout = std::nullopt; + if (T.layout_map.count(dst)) { + dst_layout = T.layout_map[dst]; + } + + return ClusterCopyReplacer(dst, dst_block.value(), target_dst, + dst_layout)(vectorized_thread_loop); } // Lowers copy to a bulk TMA (Tensor Memory Accelerator) transfer. 
From 6ca3e73abc1ee14aff85fb21a7743b94e2691bb2 Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Tue, 10 Mar 2026 10:44:36 +0800 Subject: [PATCH 33/51] fix(copy): gate cluster TMA fast path on provable contiguity --- src/op/copy.cc | 124 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 90 insertions(+), 34 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 8dcfbc51c..6207751f2 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1492,45 +1492,101 @@ Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, // the destination CTA waits on its local copy of the same mbarrier. // --------------------------------------------------------------------------- if (auto barrier_opt = GetBarrier()) { - PrimExpr barrier_load = barrier_opt.value(); - - // Compute linear offsets from the copy ranges (one offset per buffer). - auto compute_linear_offset = [](const Buffer &buf, - const Array &ranges) -> PrimExpr { - PrimExpr offset = 0; - PrimExpr stride = 1; - for (int i = static_cast(ranges.size()) - 1; i >= 0; --i) { - offset = offset + ranges[i]->min * stride; - if (i > 0) - stride = stride * buf->shape[i]; + // Bulk cluster copy issues a single flat byte-span transfer. This is only + // correct when both src/dst regions are contiguous in row-major storage. + auto is_contiguous_region = [&](const Buffer &buf, + const Array &ranges) -> bool { + ICHECK_EQ(buf->shape.size(), ranges.size()) + << "Buffer/range rank mismatch for " << buf->name; + + int pivot = -1; + for (int i = 0; i < static_cast(ranges.size()); ++i) { + if (!analyzer->CanProveEqual(ranges[i]->extent, 1)) { + pivot = i; + break; + } + } + // Scalar region is contiguous. + if (pivot == -1) { + return true; + } + + // Outer dimensions must be fixed (extent == 1). + for (int i = 0; i < pivot; ++i) { + if (!analyzer->CanProveEqual(ranges[i]->extent, 1)) { + return false; + } } - return offset; + + // Inner dimensions must be full-span [0, shape[i]) to avoid strides. 
+ for (int i = pivot + 1; i < static_cast(ranges.size()); ++i) { + if (!analyzer->CanProveEqual(ranges[i]->min, 0) || + !analyzer->CanProveEqual(ranges[i]->extent, buf->shape[i])) { + return false; + } + } + return true; }; - PrimExpr dst_offset = compute_linear_offset(dst, dst_range); - PrimExpr src_offset = compute_linear_offset(src, src_range); + bool src_contiguous = is_contiguous_region(src, src_range); + bool dst_contiguous = is_contiguous_region(dst, dst_range); - // Total number of elements to transfer. - PrimExpr total_elements = 1; + PrimExpr src_elements = 1; for (auto r : src_range) - total_elements = total_elements * r->extent; - PrimExpr size_bytes = - cast(DataType::UInt(32), total_elements * src->dtype.bytes()); - - // Build tvm_access_ptr arguments. These are processed by LowerTileOp's - // HandleAccessPtrAndOffset which, for TMA ops (in_tma_context_=true), - // keeps the raw linear offset without applying any swizzle transformation. - PrimExpr dst_ptr = - dst.access_ptr(2, DataType::Handle(), 1, dst_offset, total_elements); - PrimExpr src_ptr = - src.access_ptr(1, DataType::Handle(), 1, src_offset, total_elements); - - Stmt bulk_copy = Evaluate( - Call(DataType::Handle(), tma_store_cluster(), - {dst_ptr, src_ptr, dst_block.value(), size_bytes, barrier_load})); - - // Single-thread guard: only thread_bounds->min issues the instruction. - return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), bulk_copy); + src_elements = src_elements * r->extent; + PrimExpr dst_elements = 1; + for (auto r : dst_range) + dst_elements = dst_elements * r->extent; + bool element_match = analyzer->CanProveEqual(src_elements, dst_elements); + + if (!(src_contiguous && dst_contiguous && element_match)) { + LOG(WARNING) + << "Falling back to element-wise cluster copy: bulk cluster fast " + "path requires contiguous src/dst regions with matching element " + "counts. 
src=" + << src->name << ", dst=" << dst->name; + } else { + PrimExpr barrier_load = barrier_opt.value(); + + // Compute linear offsets from the copy ranges (one offset per buffer). + auto compute_linear_offset = [](const Buffer &buf, + const Array &ranges) -> PrimExpr { + PrimExpr offset = 0; + PrimExpr stride = 1; + for (int i = static_cast(ranges.size()) - 1; i >= 0; --i) { + offset = offset + ranges[i]->min * stride; + if (i > 0) + stride = stride * buf->shape[i]; + } + return offset; + }; + + PrimExpr dst_offset = compute_linear_offset(dst, dst_range); + PrimExpr src_offset = compute_linear_offset(src, src_range); + + // Total number of elements to transfer. + PrimExpr total_elements = 1; + for (auto r : src_range) + total_elements = total_elements * r->extent; + PrimExpr size_bytes = + cast(DataType::UInt(32), total_elements * src->dtype.bytes()); + + // Build tvm_access_ptr arguments. These are processed by LowerTileOp's + // HandleAccessPtrAndOffset which, for TMA ops (in_tma_context_=true), + // keeps the raw linear offset without applying any swizzle + // transformation. + PrimExpr dst_ptr = + dst.access_ptr(2, DataType::Handle(), 1, dst_offset, total_elements); + PrimExpr src_ptr = + src.access_ptr(1, DataType::Handle(), 1, src_offset, total_elements); + + Stmt bulk_copy = Evaluate(Call( + DataType::Handle(), tma_store_cluster(), + {dst_ptr, src_ptr, dst_block.value(), size_bytes, barrier_load})); + + // Single-thread guard: only thread_bounds->min issues the instruction. 
+ return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), bulk_copy); + } } // --------------------------------------------------------------------------- From 989044ee88f1f097eed97f93bc37850aa533bc7b Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Tue, 10 Mar 2026 11:13:13 +0800 Subject: [PATCH 34/51] fix(copy): add barrier completion for SIMT cluster-copy fallback --- src/op/copy.cc | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 6207751f2..821cffb65 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1698,8 +1698,27 @@ Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, dst_layout = T.layout_map[dst]; } - return ClusterCopyReplacer(dst, dst_block.value(), target_dst, - dst_layout)(vectorized_thread_loop); + Stmt simt_copy = ClusterCopyReplacer(dst, dst_block.value(), target_dst, + dst_layout)(vectorized_thread_loop); + + // When a remote_barrier is supplied but the fast path (tma_store_cluster) is + // unavailable (e.g. non-contiguous layout), the SIMT stores are not tracked + // by any hardware completion mechanism. Auto-generate: + // __syncthreads(); + // if (threadIdx.x == 0) { s_barrier[0].arrive(u); } + // so the destination CTA can still wait on the barrier as usual, without + // requiring the caller to insert these statements manually. + if (auto barrier_opt = GetBarrier()) { + Stmt sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + {StringImm("shared")})); + Stmt arrive = + Evaluate(Call(DataType::Handle(), ptx_arrive_cluster_barrier(), + {barrier_opt.value(), dst_block.value()})); + Stmt guarded_arrive = + IfThenElse(EQ(T.thread_var, T.thread_bounds->min), arrive); + return SeqStmt({simt_copy, sync, guarded_arrive}); + } + return simt_copy; } // Lowers copy to a bulk TMA (Tensor Memory Accelerator) transfer. 
From 2a26453f2044b0f5da9b299ad2df443d8120e329 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Wed, 11 Mar 2026 20:49:16 +0800 Subject: [PATCH 35/51] fix: remove dup cluster_sync --- tilelang/language/builtin.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 2e8716198..d59ef8eba 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -622,10 +622,6 @@ def cluster_block_nums() -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.get_cluster_block_nums")) -def cluster_sync(): - return tir.call_intrin("handle", tir.op.Op.get("tl.cluster_sync")) - - def shuffle_elect(thread_extent: int) -> PrimExpr: """Elect exactly one lane within a logical thread group. From c07a79c4d5ab49f55e70f498e8e53e85e094d8d9 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Wed, 11 Mar 2026 20:51:12 +0800 Subject: [PATCH 36/51] fix: remove not used _get_mbarrier --- tilelang/language/builtin.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index d59ef8eba..32313db82 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -29,13 +29,6 @@ def _normalize_index_arg(value: int | PrimExpr | None) -> PrimExpr | None: raise TypeError(f"Expect warp sizing argument to be int or PrimExpr, but got {type(value)}.") -def _get_mbarrier(barrier_id: int | PrimExpr): - """Create an intermediate mbarrier handle from barrier id for internal lowering only.""" - raise NotImplementedError( - "Direct mbarrier handle creation from id is not supported in the frontend. Use T.alloc_barrier to create mbarriers instead." - ) - - def _mbar_to_buffer_load(mbar: BarrierType) -> BufferLoad: """Convert a memory barrier to a buffer load. 
From a2a1fc6986b3fa92438a2908bb2db8952a8fef93 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Wed, 11 Mar 2026 21:01:04 +0800 Subject: [PATCH 37/51] fix: remove mbarrier related code in codegen_cuda.cc and barrier.h --- src/target/codegen_cuda.cc | 18 ++++++++---------- src/tl_templates/cuda/barrier.h | 30 ------------------------------ 2 files changed, 8 insertions(+), 40 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 6a968838b..d620e5287 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1819,8 +1819,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { auto mbarrier_obj = print_mbarrier_obj(op->args[0]); auto cta_id = this->PrintExpr(op->args[1]); auto pred = this->PrintExpr(op->args[2]); - this->stream << "tl::mbarrier_arrive(" << mbarrier_obj << ", " << cta_id - << ", " << pred << ");\n"; + this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred + << ");\n"; } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { ICHECK_EQ(op->args.size(), 1); this->PrintIndent(); @@ -1840,14 +1840,13 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); auto mbarrier_obj = this->PrintExpr(op->args[0]); auto arrive_count = this->PrintExpr(op->args[1]); - this->stream << "tl::mbarrier_init(" << mbarrier_obj << ", " << arrive_count - << ");\n"; + this->stream << mbarrier_obj << ".init(" << arrive_count << ");\n"; } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { if (op->args.size() == 2) { this->PrintIndent(); auto mbarrier_obj = this->PrintExpr(op->args[0]); auto transaction_bytes = this->PrintExpr(op->args[1]); - this->stream << "tl::mbarrier_arrive_expect_tx(" << mbarrier_obj << ", " + this->stream << mbarrier_obj << ".arrive_and_expect_tx(" << transaction_bytes << ");\n"; } else if (op->args.size() == 4) { this->PrintIndent(); @@ -1855,7 +1854,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode 
*op, std::ostream &os) { auto transaction_bytes = this->PrintExpr(op->args[1]); auto cta_id = this->PrintExpr(op->args[2]); auto pred = this->PrintExpr(op->args[3]); - this->stream << "tl::mbarrier_arrive_expect_tx(" << mbarrier_obj << ", " + this->stream << mbarrier_obj << ".arrive_and_expect_tx(" << transaction_bytes << ", " << cta_id << ", " << pred << ");\n"; } else { @@ -1873,15 +1872,14 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); auto mbarrier_obj = this->PrintExpr(op->args[0]); auto transaction_bytes = this->PrintExpr(op->args[1]); - this->stream << "tl::mbarrier_expect_tx(" << mbarrier_obj << ", " - << transaction_bytes << ");\n"; + this->stream << mbarrier_obj << ".expect_transaction(" << transaction_bytes + << ");\n"; } else if (op->op.same_as(tl::mbarrier_wait_parity())) { ICHECK_EQ(op->args.size(), 2); this->PrintIndent(); auto mbarrier_obj = this->PrintExpr(op->args[0]); auto phase = this->PrintExpr(op->args[1]); - this->stream << "tl::mbarrier_wait(" << mbarrier_obj << ", " << phase - << ");\n"; + this->stream << mbarrier_obj << ".wait(" << phase << ");\n"; } else if (op->op.same_as(tl::ptx_init_tensor_memory())) { print_extern_call_stmt("tl::tmem_allocate"); } else if (op->op.same_as(tl::ptx_deallocate_tensor_memory())) { diff --git a/src/tl_templates/cuda/barrier.h b/src/tl_templates/cuda/barrier.h index b8fc5d841..79a57f7df 100644 --- a/src/tl_templates/cuda/barrier.h +++ b/src/tl_templates/cuda/barrier.h @@ -8,10 +8,6 @@ using Barrier = cutlass::arch::ClusterTransactionBarrier; namespace tl { -TL_DEVICE void mbarrier_init(Barrier &barrier, uint32_t arrive_count) { - barrier.init(arrive_count); -} - TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); asm volatile("mbarrier.init.shared.b64 [%1], %0;" @@ -35,10 +31,6 @@ TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit) { return 
waitComplete; } -TL_DEVICE void mbarrier_wait(Barrier &barrier, int phase_bit) { - barrier.wait(phase_bit); -} - TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) { if (mbarrier_try_wait(smem_barrier, phase_bit) == 0) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); @@ -73,17 +65,11 @@ TL_DEVICE void mbarrier_test_wait(uint64_t &smem_barrier, int phase_bit) { "r"(phase_bit)); } -TL_DEVICE void mbarrier_arrive(Barrier &barrier) { barrier.arrive(); } - TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr)); } -TL_DEVICE void mbarrier_arrive(Barrier &barrier, int cta_id, uint32_t pred) { - barrier.arrive(cta_id, pred); -} - TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id, uint32_t pred) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); @@ -98,11 +84,6 @@ TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id, } } -TL_DEVICE void mbarrier_expect_tx(Barrier &barrier, - uint32_t transaction_bytes) { - barrier.expect_transaction(transaction_bytes); -} - TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier, uint32_t transaction_bytes) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); @@ -111,17 +92,6 @@ TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier, : "r"(transaction_bytes), "r"(smem_int_ptr)); } -TL_DEVICE void mbarrier_arrive_expect_tx(Barrier &barrier, - uint32_t transaction_bytes) { - barrier.arrive_and_expect_tx(transaction_bytes); -} - -TL_DEVICE void mbarrier_arrive_expect_tx(Barrier &barrier, - uint32_t transaction_bytes, int cta_id, - uint32_t pred) { - barrier.arrive_and_expect_tx(transaction_bytes, cta_id, pred); -} - TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier, uint32_t transaction_bytes) { uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); From 85cebfbbecd50585ddb1094179e8eadeba0bc777 Mon Sep 17 00:00:00 2001 From: 
jingkai he Date: Wed, 11 Mar 2026 21:20:56 +0800 Subject: [PATCH 38/51] fix: codegen_cuda.cc: reuse functions in cluster.h and remove dependency on cg --- docs/programming_guides/cluster_tma.md | 4 +--- src/target/codegen_cuda.cc | 8 ++++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md index 62313d2bd..bda710bc4 100644 --- a/docs/programming_guides/cluster_tma.md +++ b/docs/programming_guides/cluster_tma.md @@ -21,8 +21,7 @@ memory via the `shared::cluster` address space. ```python with T.Kernel(grid_x, grid_y, threads=128, cluster_dims=(4, 1, 1)) as (bx, by): rank = T.block_rank_in_cluster() # 0..3 within this cluster - cid = T.get_cluster_id() # which cluster am I in - nctas = T.get_cluster_block_nums() + nctas = T.get_cluster_block_nums() # total CTAs in this cluster T.cluster_sync() # barrier across all CTAs in cluster ``` @@ -233,7 +232,6 @@ overhead. | Builtin | Return | Description | |---------|--------|-------------| -| `T.get_cluster_id()` | `int32` | Index of this cluster in the grid | | `T.block_rank_in_cluster()` | `int32` | Block rank (0-indexed) within the cluster | | `T.get_cluster_block_nums()` | `int32` | Total number of CTAs in the cluster | | `T.cluster_sync()` | — | Barrier synchronisation across all cluster CTAs | diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index d620e5287..bbebf75e1 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -3138,13 +3138,13 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << ")"; } else if (op->op.same_as(tl::get_cluster_id())) { ICHECK_EQ(op->args.size(), 0) << "tl.get_cluster_id expects no arguments."; - this->need_cooperative_groups_ = true; - os << "cooperative_groups::this_grid().cluster_rank()"; + need_cluster_h_ = true; + os << "tl::block_rank_in_cluster()"; } else if (op->op.same_as(tl::get_cluster_block_nums())) { 
ICHECK_EQ(op->args.size(), 0) << "tl.get_cluster_block_nums expects no arguments."; - this->need_cooperative_groups_ = true; - os << "cooperative_groups::this_cluster().num_blocks()"; + need_cluster_h_ = true; + os << "([]{auto s=tl::cluster_shape();return (int)(s.x*s.y*s.z);}())"; } else if (op->op.same_as(tl::tl_shuffle_elect())) { os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; } else if (op->op.same_as(tl::initialize_wgmma_descriptor())) { From a5da2a3aaa38a5bc3dfc6fe3d1120600308d6068 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Thu, 12 Mar 2026 14:14:29 +0800 Subject: [PATCH 39/51] fix testing --- testing/python/cuda/test_tma_dsmem.py | 11 ++++------- testing/python/cuda/test_tma_multicast_demo.py | 11 ++++------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index 5fe5893ae..1c1c08734 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -13,10 +13,10 @@ Block 1 waits on its own s_barrier and then reads the result. 
""" -import pytest import torch import tilelang import tilelang.language as T +import tilelang.testing import numpy as np @@ -57,12 +57,9 @@ def kernel( return kernel +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_tma_store_cluster(): - if not torch.cuda.is_available(): - pytest.skip("CUDA is required") - major, minor = torch.cuda.get_device_capability() - if major < 9: - pytest.skip(f"requires Compute Capability 9.0+, found {major}.{minor}") N = 128 prim_func = make_store_cluster_kernel(N) @@ -91,4 +88,4 @@ def test_tma_store_cluster(): if __name__ == "__main__": - test_tma_store_cluster() + tilelang.testing.main() diff --git a/testing/python/cuda/test_tma_multicast_demo.py b/testing/python/cuda/test_tma_multicast_demo.py index 3e5cb3725..d7f4486aa 100644 --- a/testing/python/cuda/test_tma_multicast_demo.py +++ b/testing/python/cuda/test_tma_multicast_demo.py @@ -15,10 +15,10 @@ The test verifies multicast by checking that rank 1's B region equals rank 0's A region. 
""" -import pytest import torch import tilelang import tilelang.language as T +import tilelang.testing def make_tma_multicast_demo_kernel(M, N, block_M, block_N, cluster_mask): @@ -49,13 +49,10 @@ def kernel( return kernel +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_tma_multicast_demo(): """Verify TMA multicast: rank 1's B region should equal rank 0's A region within the same cluster.""" - if not torch.cuda.is_available(): - pytest.skip("CUDA is required") - major, minor = torch.cuda.get_device_capability() - if major < 9: - pytest.skip(f"requires Compute Capability 9.0+, found {major}.{minor}") M, N = 1024, 1024 block_M, block_N = 128, 64 # mask=0b0011: rank 0 multicasts, rank 1 receives, ranks 2/3 each do regular tma_load @@ -106,4 +103,4 @@ def test_tma_multicast_demo(): if __name__ == "__main__": - test_tma_multicast_demo() + tilelang.testing.main() From 99c800c91dd21230d198af5b237ccfcb91633f7e Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Thu, 12 Mar 2026 14:17:22 +0800 Subject: [PATCH 40/51] fix(transform): prevent ragged_prefix free var from breaking MakePackedAPI --- .../multi_version_buffer_rewriter.cc | 28 ++++++++++++++++++- src/transform/warp_specialized_rewriter.cc | 11 ++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/transform/multi_version_buffer_rewriter.cc b/src/transform/multi_version_buffer_rewriter.cc index 791ca0faf..571c69a1c 100644 --- a/src/transform/multi_version_buffer_rewriter.cc +++ b/src/transform/multi_version_buffer_rewriter.cc @@ -134,7 +134,9 @@ class MultiVersionBufferRewriter : public StmtExprMutator { Var buffer_var = buffer->data; rewriter.buffer_data_to_buffer_.Set(buffer_var, buffer); } - f.CopyOnWrite()->body = rewriter(f->body); + Stmt rewritten = rewriter(f->body); + rewritten = rewriter.FinalizeRaggedPrefixAllocation(std::move(rewritten)); + f.CopyOnWrite()->body = std::move(rewritten); return f; } @@ -156,6 +158,21 @@ class 
MultiVersionBufferRewriter : public StmtExprMutator { return BufferLoad(ragged_prefix_buf_, zero_indices); } + Stmt FinalizeRaggedPrefixAllocation(Stmt body) { + if (!needs_ragged_prefix_ || inserted_ragged_prefix_) { + return body; + } + EnsureRaggedPrefixBuffer(); + Array zero_indices = {0}; + Stmt init = BufferStore(ragged_prefix_buf_, IntImm(DataType::Int(32), 0), + zero_indices); + Stmt seq = SeqStmt({init, body}); + seq = DeclBuffer(ragged_prefix_buf_, seq); + inserted_ragged_prefix_ = true; + return Allocate(ragged_prefix_buf_->data, ragged_prefix_buf_->dtype, + ragged_prefix_buf_->shape, const_true(), seq); + } + Array GetVersionedBuffers(const Array &seq_stmt, const Array &scoped_buffers) { Array pipeline_stmts; @@ -352,6 +369,15 @@ class MultiVersionBufferRewriter : public StmtExprMutator { ICHECK(num_stages_anno->as()); int num_stages = static_cast(num_stages_anno->as()->value); + // A single-stage software pipeline does not require multi-versioned + // buffers or ragged-prefix bookkeeping; keep the loop unchanged. + if (num_stages <= 1) { + auto for_node = StmtExprMutator::VisitStmt_(op); + loop_stack_.pop_back(); + stmt_stack_.pop_back(); + return for_node; + } + Stmt pipeline_body_root{nullptr}; if (const auto *realize = op->body.as()) { const auto &block = realize->block; diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 48927c8bb..f12c24f54 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -252,6 +252,17 @@ class WarpSpecializedRoleMarker : public StmtVisitor { void VisitStmt_(const AllocateNode *op) final { StmtVisitor::VisitStmt_(op); + auto it = buffer_data_to_buffer_.find(op->buffer_var); + if (it != buffer_data_to_buffer_.end()) { + const Buffer &buf = (*it).second; + // Ragged-prefix bookkeeping buffers are referenced by both producer and + // consumer rewrites. Keep their allocation in both paths. 
+ if (buf->name == "tl_mvb_ragged_prefix" || + buf->name == "tl_ws_ragged_prefix") { + SetRole(op, Role::kBoth); + return; + } + } Role role = Role::kConsumer; SetRole(op, role); } From 831beb42ca0919b5dd127cf257cc2e1b8d6084d0 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 16 Mar 2026 10:47:04 +0800 Subject: [PATCH 41/51] minor fix --- src/op/builtin.cc | 4 ---- src/op/builtin.h | 7 ------- src/target/codegen_cuda.cc | 4 ---- tilelang/language/builtin.py | 7 ------- 4 files changed, 22 deletions(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 775c660c9..95094a408 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -298,10 +298,6 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(get_cluster_id) - .set_num_inputs(0) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kPure)); TIR_DEFINE_TL_BUILTIN(get_cluster_block_nums) .set_num_inputs(0) diff --git a/src/op/builtin.h b/src/op/builtin.h index 7151936a6..42a8bd11d 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -509,13 +509,6 @@ TVM_DLL const Op &warpgroup_wait(); */ TVM_DLL const Op &warpgroup_fence_operand(); -/*! - * \brief Return the cluster id (rank) of the current block within a cluster. - * - * get_cluster_id() - * - */ -TVM_DLL const Op &get_cluster_id(); /*! * \brief Return the number of blocks in the cluster. 
diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index bbebf75e1..cdd0a52fd 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -3136,10 +3136,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << PrintExpr(op->args[i]); } os << ")"; - } else if (op->op.same_as(tl::get_cluster_id())) { - ICHECK_EQ(op->args.size(), 0) << "tl.get_cluster_id expects no arguments."; - need_cluster_h_ = true; - os << "tl::block_rank_in_cluster()"; } else if (op->op.same_as(tl::get_cluster_block_nums())) { ICHECK_EQ(op->args.size(), 0) << "tl.get_cluster_block_nums expects no arguments."; diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 32313db82..c5fcc6082 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -597,13 +597,6 @@ def get_warp_group_idx( return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args) -def get_cluster_id() -> PrimExpr: - """Return the cluster id (rank) of the current block within the cluster. - - This lowers to the intrinsic `tl.get_cluster_id` and is emitted for CUDA - as `cooperative_groups::this_grid().cluster_rank()`. 
- """ - return tir.call_intrin("int32", tir.op.Op.get("tl.get_cluster_id")) def cluster_block_nums() -> PrimExpr: From b30e926d50b56784bd05173f4e7dbbb860461ddc Mon Sep 17 00:00:00 2001 From: jingkai he Date: Mon, 16 Mar 2026 10:48:22 +0800 Subject: [PATCH 42/51] format --- src/op/builtin.cc | 1 - src/op/builtin.h | 1 - tilelang/language/builtin.py | 2 -- 3 files changed, 4 deletions(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 95094a408..d4fdf12b9 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -298,7 +298,6 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); - TIR_DEFINE_TL_BUILTIN(get_cluster_block_nums) .set_num_inputs(0) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 42a8bd11d..d6b766aeb 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -509,7 +509,6 @@ TVM_DLL const Op &warpgroup_wait(); */ TVM_DLL const Op &warpgroup_fence_operand(); - /*! * \brief Return the number of blocks in the cluster. * diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index c5fcc6082..006adb03e 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -597,8 +597,6 @@ def get_warp_group_idx( return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args) - - def cluster_block_nums() -> PrimExpr: """Return the number of blocks in the cluster. 
From a8753c81649544cd1e1b289ac10196470631df8f Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 16 Mar 2026 18:51:35 +0800 Subject: [PATCH 43/51] fix inject tma barrier: fix crash and correct mbarrier thread-count injection --- src/transform/inject_tma_barrier.cc | 69 ++++++++++++++++++++++------- 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index 19a7b95a9..ed350f8af 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -270,6 +270,16 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { Map barrier_id_to_range() { return barrier_id_to_range_; } private: + int GetCurrentThreadExtent() { + if (!thread_var_.defined()) { + return 1; + } + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + int64_t extent = + const_int_bound->max_value - const_int_bound->min_value + 1; + return static_cast(std::max(extent, 1)); + } + void UpdateBarrierRange(const PrimExpr &barrier_id, const IntImm &extent) { if (barrier_id_to_range_.count(barrier_id)) { auto old_extent = barrier_id_to_range_[barrier_id]; @@ -293,9 +303,7 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { for (const auto &tma_call : pending_tma_ops_) { tma_op_to_barrier_id_.Set(tma_call, barrier_id); } - auto const_int_bound = analyzer_.const_int_bound(thread_var_); - auto extent = - const_int_bound->max_value - const_int_bound->min_value + 1; + int extent = GetCurrentThreadExtent(); UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); pending_tma_ops_.clear(); } else if (call->op.same_as(builtin::ptx_cp_async_barrier()) || @@ -308,16 +316,12 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { for (const auto &tma_call : pending_tma_ops_) { tma_op_to_barrier_id_.Set(tma_call, barrier_id); } - auto const_int_bound = analyzer_.const_int_bound(thread_var_); - auto extent = - const_int_bound->max_value - 
const_int_bound->min_value + 1; + int extent = GetCurrentThreadExtent(); UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); pending_tma_ops_.clear(); } else if (call->op.same_as(builtin::ptx_wait_barrier())) { PrimExpr barrier_id = call->args[0]; - auto const_int_bound = analyzer_.const_int_bound(thread_var_); - auto extent = - const_int_bound->max_value - const_int_bound->min_value + 1; + int extent = GetCurrentThreadExtent(); UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); } } @@ -408,6 +412,10 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { return barrier_thread_counts_; } + const std::unordered_set &expect_tx_barrier_ids() const { + return expect_tx_barrier_ids_; + } + private: PrimExpr NormalizeBarrierExpr(const PrimExpr &barrier_expr) const { if (const auto *call = barrier_expr.as()) { @@ -424,7 +432,7 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { return 1; } if (!thread_var_.defined()) { - return 1; + return default_thread_count_; } auto bound = analyzer_.const_int_bound(thread_var_); int64_t min_val = bound->min_value; @@ -447,8 +455,8 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { return static_cast(std::max(extent, 1)); } - void UpdateBarrierThreadCount(const PrimExpr &barrier_expr, - int thread_count) { + void UpdateBarrierThreadCount(const PrimExpr &barrier_expr, int thread_count, + bool from_expect_tx) { PrimExpr normalized_barrier_expr = NormalizeBarrierExpr(barrier_expr); if (const auto *imm = normalized_barrier_expr.as()) { @@ -459,6 +467,9 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { } else { it->second = std::max(it->second, thread_count); } + if (from_expect_tx) { + expect_tx_barrier_ids_.insert(id); + } return; } @@ -477,6 +488,9 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { } else { it->second = std::max(it->second, thread_count); } + if (from_expect_tx) { + expect_tx_barrier_ids_.insert(id); + } } 
} @@ -500,6 +514,12 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { bool is_elect_if = false; if (const auto *call = op->condition.as()) { is_elect_if = call->op.same_as(tl_shuffle_elect()); + if (is_elect_if && call->args.size() >= 1) { + if (const auto *imm = call->args[0].as()) { + default_thread_count_ = + std::max(default_thread_count_, static_cast(imm->value)); + } + } } if (is_elect_if) { bool old_inside = inside_elect_if_; @@ -515,28 +535,41 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { } void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tl_shuffle_elect()) && op->args.size() >= 1) { + if (const auto *imm = op->args[0].as()) { + default_thread_count_ = + std::max(default_thread_count_, static_cast(imm->value)); + } + } if (op->op.same_as(builtin::ptx_arrive_barrier()) || op->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || op->op.same_as(builtin::ptx_cp_async_barrier()) || op->op.same_as(tl::ptx_cp_async_barrier_noinc())) { ICHECK_GE(op->args.size(), 1); - UpdateBarrierThreadCount(op->args[0], GetCurrentThreadCount()); + bool from_expect_tx = + op->op.same_as(builtin::ptx_arrive_barrier_expect_tx()); + UpdateBarrierThreadCount(op->args[0], GetCurrentThreadCount(), + from_expect_tx); } IRVisitorWithAnalyzer::VisitExpr_(op); } IterVar thread_var_; bool inside_elect_if_{false}; + int default_thread_count_{1}; Map var_int_set_; std::unordered_map barrier_thread_counts_; + std::unordered_set expect_tx_barrier_ids_; }; class BarrierCreationRewriter : public StmtExprMutator { public: BarrierCreationRewriter(std::unordered_map barrier_thread_counts, + std::unordered_set expect_tx_barrier_ids, int ensure_min_count = 0, PrimExpr default_barrier_thread_count = 1) : barrier_thread_counts_(std::move(barrier_thread_counts)), + expect_tx_barrier_ids_(std::move(expect_tx_barrier_ids)), ensure_min_count_(ensure_min_count), default_barrier_thread_count_(std::move(default_barrier_thread_count)) { } @@ -554,7 
+587,11 @@ class BarrierCreationRewriter : public StmtExprMutator { for (size_t i{0}; i < cur_n; ++i) { auto it = barrier_thread_counts_.find(static_cast(i)); if (it != barrier_thread_counts_.end()) { - new_args.push_back(Integer(it->second)); + PrimExpr updated = Integer(it->second); + if (!expect_tx_barrier_ids_.count(static_cast(i))) { + updated = op->args[i] + updated; + } + new_args.push_back(updated); } else { new_args.push_back(op->args[i]); } @@ -576,6 +613,7 @@ class BarrierCreationRewriter : public StmtExprMutator { private: std::unordered_map barrier_thread_counts_; + std::unordered_set expect_tx_barrier_ids_; int ensure_min_count_{0}; PrimExpr default_barrier_thread_count_{1}; }; @@ -640,7 +678,8 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { // Default appended barriers to leader-only (=1), but prefer explicit // arrive-domain counts collected from actual arrive sites. auto barrier_creation_rewriter = BarrierCreationRewriter( - arrive_thread_count_collector.barrier_thread_counts(), ensure_min_count, + arrive_thread_count_collector.barrier_thread_counts(), + arrive_thread_count_collector.expect_tx_barrier_ids(), ensure_min_count, Integer(1)); f.CopyOnWrite()->body = barrier_creation_rewriter(f->body); return f; From 9abd5ec976574691891418cb5750fe24f1d7a46a Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 16 Mar 2026 20:16:04 +0800 Subject: [PATCH 44/51] revert inject tma --- src/transform/inject_tma_barrier.cc | 69 +++++++---------------------- 1 file changed, 15 insertions(+), 54 deletions(-) diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index ed350f8af..19a7b95a9 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -270,16 +270,6 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { Map barrier_id_to_range() { return barrier_id_to_range_; } private: - int GetCurrentThreadExtent() { - if (!thread_var_.defined()) { - return 1; - } - 
auto const_int_bound = analyzer_.const_int_bound(thread_var_); - int64_t extent = - const_int_bound->max_value - const_int_bound->min_value + 1; - return static_cast(std::max(extent, 1)); - } - void UpdateBarrierRange(const PrimExpr &barrier_id, const IntImm &extent) { if (barrier_id_to_range_.count(barrier_id)) { auto old_extent = barrier_id_to_range_[barrier_id]; @@ -303,7 +293,9 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { for (const auto &tma_call : pending_tma_ops_) { tma_op_to_barrier_id_.Set(tma_call, barrier_id); } - int extent = GetCurrentThreadExtent(); + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + auto extent = + const_int_bound->max_value - const_int_bound->min_value + 1; UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); pending_tma_ops_.clear(); } else if (call->op.same_as(builtin::ptx_cp_async_barrier()) || @@ -316,12 +308,16 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { for (const auto &tma_call : pending_tma_ops_) { tma_op_to_barrier_id_.Set(tma_call, barrier_id); } - int extent = GetCurrentThreadExtent(); + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + auto extent = + const_int_bound->max_value - const_int_bound->min_value + 1; UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); pending_tma_ops_.clear(); } else if (call->op.same_as(builtin::ptx_wait_barrier())) { PrimExpr barrier_id = call->args[0]; - int extent = GetCurrentThreadExtent(); + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + auto extent = + const_int_bound->max_value - const_int_bound->min_value + 1; UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); } } @@ -412,10 +408,6 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { return barrier_thread_counts_; } - const std::unordered_set &expect_tx_barrier_ids() const { - return expect_tx_barrier_ids_; - } - private: PrimExpr NormalizeBarrierExpr(const PrimExpr &barrier_expr) const { if 
(const auto *call = barrier_expr.as()) { @@ -432,7 +424,7 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { return 1; } if (!thread_var_.defined()) { - return default_thread_count_; + return 1; } auto bound = analyzer_.const_int_bound(thread_var_); int64_t min_val = bound->min_value; @@ -455,8 +447,8 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { return static_cast(std::max(extent, 1)); } - void UpdateBarrierThreadCount(const PrimExpr &barrier_expr, int thread_count, - bool from_expect_tx) { + void UpdateBarrierThreadCount(const PrimExpr &barrier_expr, + int thread_count) { PrimExpr normalized_barrier_expr = NormalizeBarrierExpr(barrier_expr); if (const auto *imm = normalized_barrier_expr.as()) { @@ -467,9 +459,6 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { } else { it->second = std::max(it->second, thread_count); } - if (from_expect_tx) { - expect_tx_barrier_ids_.insert(id); - } return; } @@ -488,9 +477,6 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { } else { it->second = std::max(it->second, thread_count); } - if (from_expect_tx) { - expect_tx_barrier_ids_.insert(id); - } } } @@ -514,12 +500,6 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { bool is_elect_if = false; if (const auto *call = op->condition.as()) { is_elect_if = call->op.same_as(tl_shuffle_elect()); - if (is_elect_if && call->args.size() >= 1) { - if (const auto *imm = call->args[0].as()) { - default_thread_count_ = - std::max(default_thread_count_, static_cast(imm->value)); - } - } } if (is_elect_if) { bool old_inside = inside_elect_if_; @@ -535,41 +515,28 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { } void VisitExpr_(const CallNode *op) final { - if (op->op.same_as(tl_shuffle_elect()) && op->args.size() >= 1) { - if (const auto *imm = op->args[0].as()) { - default_thread_count_ = - std::max(default_thread_count_, static_cast(imm->value)); - } - } if 
(op->op.same_as(builtin::ptx_arrive_barrier()) || op->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || op->op.same_as(builtin::ptx_cp_async_barrier()) || op->op.same_as(tl::ptx_cp_async_barrier_noinc())) { ICHECK_GE(op->args.size(), 1); - bool from_expect_tx = - op->op.same_as(builtin::ptx_arrive_barrier_expect_tx()); - UpdateBarrierThreadCount(op->args[0], GetCurrentThreadCount(), - from_expect_tx); + UpdateBarrierThreadCount(op->args[0], GetCurrentThreadCount()); } IRVisitorWithAnalyzer::VisitExpr_(op); } IterVar thread_var_; bool inside_elect_if_{false}; - int default_thread_count_{1}; Map var_int_set_; std::unordered_map barrier_thread_counts_; - std::unordered_set expect_tx_barrier_ids_; }; class BarrierCreationRewriter : public StmtExprMutator { public: BarrierCreationRewriter(std::unordered_map barrier_thread_counts, - std::unordered_set expect_tx_barrier_ids, int ensure_min_count = 0, PrimExpr default_barrier_thread_count = 1) : barrier_thread_counts_(std::move(barrier_thread_counts)), - expect_tx_barrier_ids_(std::move(expect_tx_barrier_ids)), ensure_min_count_(ensure_min_count), default_barrier_thread_count_(std::move(default_barrier_thread_count)) { } @@ -587,11 +554,7 @@ class BarrierCreationRewriter : public StmtExprMutator { for (size_t i{0}; i < cur_n; ++i) { auto it = barrier_thread_counts_.find(static_cast(i)); if (it != barrier_thread_counts_.end()) { - PrimExpr updated = Integer(it->second); - if (!expect_tx_barrier_ids_.count(static_cast(i))) { - updated = op->args[i] + updated; - } - new_args.push_back(updated); + new_args.push_back(Integer(it->second)); } else { new_args.push_back(op->args[i]); } @@ -613,7 +576,6 @@ class BarrierCreationRewriter : public StmtExprMutator { private: std::unordered_map barrier_thread_counts_; - std::unordered_set expect_tx_barrier_ids_; int ensure_min_count_{0}; PrimExpr default_barrier_thread_count_{1}; }; @@ -678,8 +640,7 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { // Default appended 
barriers to leader-only (=1), but prefer explicit // arrive-domain counts collected from actual arrive sites. auto barrier_creation_rewriter = BarrierCreationRewriter( - arrive_thread_count_collector.barrier_thread_counts(), - arrive_thread_count_collector.expect_tx_barrier_ids(), ensure_min_count, + arrive_thread_count_collector.barrier_thread_counts(), ensure_min_count, Integer(1)); f.CopyOnWrite()->body = barrier_creation_rewriter(f->body); return f; From 0ec43749ce0e1765636ecfadebd26e835297830b Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 16 Mar 2026 20:56:11 +0800 Subject: [PATCH 45/51] fix(inject_tma_barrier): avoid SIGSEGV and correctly infer barrier arrive thread counts --- src/transform/inject_tma_barrier.cc | 49 +++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index 19a7b95a9..9461908e5 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -270,6 +270,16 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { Map barrier_id_to_range() { return barrier_id_to_range_; } private: + int GetCurrentThreadExtent() { + if (!thread_var_.defined()) { + return 1; + } + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + int64_t extent = + const_int_bound->max_value - const_int_bound->min_value + 1; + return static_cast(std::max(extent, 1)); + } + void UpdateBarrierRange(const PrimExpr &barrier_id, const IntImm &extent) { if (barrier_id_to_range_.count(barrier_id)) { auto old_extent = barrier_id_to_range_[barrier_id]; @@ -293,9 +303,7 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { for (const auto &tma_call : pending_tma_ops_) { tma_op_to_barrier_id_.Set(tma_call, barrier_id); } - auto const_int_bound = analyzer_.const_int_bound(thread_var_); - auto extent = - const_int_bound->max_value - const_int_bound->min_value + 1; + int extent = GetCurrentThreadExtent(); 
UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); pending_tma_ops_.clear(); } else if (call->op.same_as(builtin::ptx_cp_async_barrier()) || @@ -308,16 +316,12 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { for (const auto &tma_call : pending_tma_ops_) { tma_op_to_barrier_id_.Set(tma_call, barrier_id); } - auto const_int_bound = analyzer_.const_int_bound(thread_var_); - auto extent = - const_int_bound->max_value - const_int_bound->min_value + 1; + int extent = GetCurrentThreadExtent(); UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); pending_tma_ops_.clear(); } else if (call->op.same_as(builtin::ptx_wait_barrier())) { PrimExpr barrier_id = call->args[0]; - auto const_int_bound = analyzer_.const_int_bound(thread_var_); - auto extent = - const_int_bound->max_value - const_int_bound->min_value + 1; + int extent = GetCurrentThreadExtent(); UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); } } @@ -424,7 +428,8 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { return 1; } if (!thread_var_.defined()) { - return 1; + int inferred = max_init_barrier_thread_count_ + elect_thread_count_; + return std::max(inferred, 1); } auto bound = analyzer_.const_int_bound(thread_var_); int64_t min_val = bound->min_value; @@ -500,6 +505,12 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { bool is_elect_if = false; if (const auto *call = op->condition.as()) { is_elect_if = call->op.same_as(tl_shuffle_elect()); + if (is_elect_if && !call->args.empty()) { + if (const auto *imm = call->args[0].as()) { + elect_thread_count_ = + std::max(elect_thread_count_, static_cast(imm->value)); + } + } } if (is_elect_if) { bool old_inside = inside_elect_if_; @@ -515,6 +526,22 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { } void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(create_list_of_mbarrier())) { + for (const PrimExpr &arg : op->args) { + if (const auto *imm = 
arg.as()) { + max_init_barrier_thread_count_ = std::max( + max_init_barrier_thread_count_, static_cast(imm->value)); + } + } + } else if (op->op.same_as(tl_shuffle_elect())) { + if (!op->args.empty()) { + if (const auto *imm = op->args[0].as()) { + elect_thread_count_ = + std::max(elect_thread_count_, static_cast(imm->value)); + } + } + } + if (op->op.same_as(builtin::ptx_arrive_barrier()) || op->op.same_as(builtin::ptx_arrive_barrier_expect_tx()) || op->op.same_as(builtin::ptx_cp_async_barrier()) || @@ -529,6 +556,8 @@ class ArriveThreadCountCollector : public IRVisitorWithAnalyzer { bool inside_elect_if_{false}; Map var_int_set_; std::unordered_map barrier_thread_counts_; + int elect_thread_count_{0}; + int max_init_barrier_thread_count_{0}; }; class BarrierCreationRewriter : public StmtExprMutator { From 79cfee7b8821213a53f705d55bbb3d463227a4df Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Mon, 16 Mar 2026 20:57:35 +0800 Subject: [PATCH 46/51] fix: add tma_load_multicast same as tma_load --- src/transform/inject_fence_proxy.cc | 4 ++-- src/transform/multi_version_buffer_rewriter.cc | 3 ++- src/transform/warp_specialized_rewriter.cc | 6 ++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transform/inject_fence_proxy.cc b/src/transform/inject_fence_proxy.cc index 14f130e78..97d1b2b3d 100644 --- a/src/transform/inject_fence_proxy.cc +++ b/src/transform/inject_fence_proxy.cc @@ -94,8 +94,8 @@ bool IsAsyncIntrinsic(const CallNode *call) { // TileLang async intrinsics if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || - call->op.same_as(tma_store()) || call->op.same_as(ptx_wgmma_ss()) || - call->op.same_as(ptx_wgmma_rs()) || + call->op.same_as(tma_load_multicast()) || call->op.same_as(tma_store()) || + call->op.same_as(ptx_wgmma_ss()) || call->op.same_as(ptx_wgmma_rs()) || call->op.same_as(ptx_tcgen05_mma_ss()) || call->op.same_as(ptx_tcgen05_mma_ts())) { return true; diff --git 
a/src/transform/multi_version_buffer_rewriter.cc b/src/transform/multi_version_buffer_rewriter.cc index 571c69a1c..082aacbfe 100644 --- a/src/transform/multi_version_buffer_rewriter.cc +++ b/src/transform/multi_version_buffer_rewriter.cc @@ -41,7 +41,8 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor { void VisitStmt_(const EvaluateNode *op) final { Role role = Role::kConsumer; if (auto call = op->value.as()) { - if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_load_multicast())) { role = Role::kProducer; has_bulk_copy_ = true; } diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index f12c24f54..91dd12880 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -29,7 +29,8 @@ class ProducerBufferDetector : public StmtExprVisitor { void clear() { has_producer_buffer_ = false; } void VisitExpr_(const CallNode *call) final { - if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_load_multicast())) { has_producer_buffer_ = true; } StmtExprVisitor::VisitExpr_(call); @@ -376,7 +377,8 @@ class MbarrierRewriter : public StmtExprMutator { private: PrimExpr VisitExpr_(const CallNode *op) final { auto call = Downcast(StmtExprMutator::VisitExpr_(op)); - if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) || + call->op.same_as(tma_load_multicast())) { auto mbar = makeGetBarrier(producer_barrier_idx_); auto arg0 = call->args[0].as(); // Check if this is a 1D TMA load (raw address, no descriptor, no From 00ea64bf7f5f8e36d992b7286613c3bcf30b93ca Mon Sep 17 00:00:00 2001 From: jingkai he Date: Tue, 17 Mar 2026 10:28:40 +0800 Subject: 
[PATCH 47/51] T.copy_cluster --- docs/programming_guides/cluster_tma.md | 22 ++--- testing/python/cuda/test_tma_dsmem.py | 6 +- .../python/cuda/test_tma_multicast_demo.py | 2 +- tilelang/language/__init__.py | 2 +- tilelang/language/copy_op.py | 97 +++++++++++++------ 5 files changed, 83 insertions(+), 46 deletions(-) diff --git a/docs/programming_guides/cluster_tma.md b/docs/programming_guides/cluster_tma.md index bda710bc4..5e91ec205 100644 --- a/docs/programming_guides/cluster_tma.md +++ b/docs/programming_guides/cluster_tma.md @@ -47,7 +47,7 @@ Global memory ──TMA multicast──▶ shared memory (rank 0) ### API ```python -T.copy(src_global, dst_shared, cluster_mask=) +T.copy_cluster(src_global, dst_shared, cluster_mask=) ``` `cluster_mask` is a bitmask where each set bit identifies a CTA rank that @@ -80,8 +80,8 @@ def make_tma_multicast_kernel(M, N, block_M, block_N, cluster_mask): # cluster_mask=0b0011: ranks 0 and 1 participate. # Rank 0 issues tma_load_multicast; rank 1 receives passively. # Ranks 2 and 3 each issue a regular tma_load. - T.copy(A[by * block_M, bx * block_N], A_shared, - cluster_mask=cluster_mask) + T.copy_cluster(A[by * block_M, bx * block_N], A_shared, + cluster_mask=cluster_mask) T.copy(A_shared, B[by * block_M, bx * block_N]) @@ -132,7 +132,7 @@ Two sub-variants are provided depending on whether an mbarrier is supplied: ### Fast path — bulk async copy with mbarrier ```python -T.copy(src_shared, dst_shared, dst_block=, remote_barrier=) +T.copy_cluster(src_shared, dst_shared, dst_block=, remote_barrier=) ``` A single elected thread issues one `cp.async.bulk.shared::cluster` instruction. @@ -144,7 +144,7 @@ Steps: 1. Both CTAs allocate the **same** shared memory layout so their mbarriers live at the same offset. 2. Every CTA initialises its own barrier for 1 arrival. -3. The source CTA (`pid == 0` below) calls `T.copy(... dst_block=1, remote_barrier=...)`. +3. The source CTA (`pid == 0` below) calls `T.copy_cluster(... 
dst_block=1, remote_barrier=...)`. 4. The destination CTA (`pid == 1`) waits on its local barrier copy. ```python @@ -174,8 +174,8 @@ def make_cluster_copy_kernel(N: int): s_src[i] = A[i] # Async-push s_src → s_dst in CTA 1, signal CTA 1's barrier. - T.copy(s_src, s_dst, dst_block=1, - remote_barrier=s_barrier[0]) + T.copy_cluster(s_src, s_dst, dst_block=1, + remote_barrier=s_barrier[0]) if pid == 1: # Wait until CTA 0 finishes writing. @@ -201,7 +201,7 @@ if (((int)threadIdx.x) == 0) { Omit `remote_barrier` to use the slow path: ```python -T.copy(s_src, s_dst, dst_block=1) +T.copy_cluster(s_src, s_dst, dst_block=1) ``` This lowers to a SIMT parallel loop where every thread writes one (or a few) @@ -262,7 +262,7 @@ def split_k_gemm(A, B, C): # Phase 1: each CTA loads its K-slice; A is multicast to rank 0 and 1. for ko in T.Pipelined(T.ceildiv(K, BK * 4), num_stages=3): k_off = (rank + ko * 4) * BK - T.copy(A[by * BM, k_off], A_s, cluster_mask=0b0011) + T.copy_cluster(A[by * BM, k_off], A_s, cluster_mask=0b0011) T.copy(B[k_off, bx * BN], B_s) T.gemm(A_s, B_s, C_f) @@ -283,8 +283,8 @@ def split_k_gemm(A, B, C): if rank != 0: # Push this rank's slot to the *same* slot index in rank 0's # C_parts — different offsets, so no destination race. - T.copy(C_parts[rank], C_parts[rank], - dst_block=0, remote_barrier=barrier[0]) + T.copy_cluster(C_parts[rank], C_parts[rank], + dst_block=0, remote_barrier=barrier[0]) if rank == 0: T.mbarrier_wait_parity(barrier[0], 0) # wakes after all 3 arrivals diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index 1c1c08734..e19cf74c5 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -44,7 +44,7 @@ def kernel( # Bulk-async copy s_src (local) → s_dst (remote, block 1) # using tl::tma_store_cluster, signalling block 1's barrier. 
- T.copy(s_src, s_dst, dst_block=1, remote_barrier=s_barrier[0]) + T.copy_cluster(s_src, s_dst, dst_block=1, remote_barrier=s_barrier[0]) if pid == 1: # Wait until block 0 finishes writing to our s_dst. @@ -68,12 +68,12 @@ def test_tma_store_cluster(): # Assert that the lowering actually produced tl::tma_store_cluster. # The SIMT fallback (map_shared_rank + scalar stores) also copies data # correctly, so a pure numerical check would miss a regression where - # T.copy(dst_block=..., remote_barrier=...) stops emitting the bulk-async + # T.copy_cluster(dst_block=..., remote_barrier=...) stops emitting the bulk-async # cluster intrinsic. src = mod.get_kernel_source() assert "tl::tma_store_cluster" in src, ( "Expected tl::tma_store_cluster in generated kernel source; " - "T.copy(dst_block=..., remote_barrier=...) may have regressed to the " + "T.copy_cluster(dst_block=..., remote_barrier=...) may have regressed to the " f"SIMT fallback.\nKernel source:\n{src}" ) diff --git a/testing/python/cuda/test_tma_multicast_demo.py b/testing/python/cuda/test_tma_multicast_demo.py index d7f4486aa..8127533ab 100644 --- a/testing/python/cuda/test_tma_multicast_demo.py +++ b/testing/python/cuda/test_tma_multicast_demo.py @@ -43,7 +43,7 @@ def kernel( cluster_dims=(4, 1, 1), ) as (bx, by): A_shared = T.alloc_shared((block_M, block_N), "float16") - T.copy(A[by * block_M, bx * block_N], A_shared, cluster_mask=cluster_mask) + T.copy_cluster(A[by * block_M, bx * block_N], A_shared, cluster_mask=cluster_mask) T.copy(A_shared, B[by * block_M, bx * block_N]) return kernel diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 9fc05a62d..09f0a5381 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -53,7 +53,7 @@ empty, # noqa: F401 ) from tvm.script.parser.tir import allocate as allocate # noqa: F401 -from .copy_op import copy, async_copy, c2d_im2col # noqa: F401 +from .copy_op import copy, copy_cluster, async_copy, c2d_im2col # noqa: 
F401 from tilelang.tileop.base import GemmWarpPolicy # noqa: F401 from .gemm_op import gemm, gemm_v1, gemm_v2 # noqa: F401 from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401 diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py index 1a13b997d..c5ab21333 100644 --- a/tilelang/language/copy_op.py +++ b/tilelang/language/copy_op.py @@ -57,9 +57,6 @@ def copy( eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, annotations: dict | None = None, loop_layout: Any | None = None, - dst_block: int | tir.PrimExpr | None = None, - cluster_mask: int | None = None, - remote_barrier: tir.BufferLoad | None = None, ) -> tir.PrimExpr | tir.Stmt: """Copy data between memory regions. @@ -76,28 +73,6 @@ def copy( (only valid for normal SIMT copy; incompatible with TMA/LDSM/STSM/TMem). When provided, it is attached to the outermost parallel loop generated by this copy. - dst_block (Optional[Union[int, tir.PrimExpr]], optional): Destination block index for cluster copy. Defaults to None. - cluster_mask (Optional[int], keyword-only): Bitmask specifying which CTAs in the cluster - participate in a TMA multicast broadcast. The hardware delivers the data to every - masked CTA's shared memory in a single transfer. At runtime the kernel splits into - three cases based on each CTA's rank within the cluster: - - * **Leader** (rank == lowest set bit in mask): issues ``tma_load_multicast``, which - fills the shared memory of *all* masked CTAs simultaneously. - * **Masked peer** (rank is set in mask, but not the lowest): does *nothing* — its - shared memory is written passively by the leader's multicast. - * **Unmasked CTA** (rank is not set in mask): issues a regular ``tma_load`` for its - own shared memory independently. - - A value of ``None`` (default) disables multicast and every CTA issues its own - regular TMA load. 
- remote_barrier (Optional[tir.BufferLoad], keyword-only): Shared-memory mbarrier element used for - SM-to-SM cluster copy (``dst_block`` must also be set). When provided, the copy is - performed with a single bulk-async ``tl::tma_store_cluster`` call instead of the - default element-by-element SIMT loop. The barrier must be at the **same** shared - memory offset in every CTA of the cluster (which is automatically true when all CTAs - run the same kernel and declare the same shared-memory layout). The destination CTA - should wait on its local copy of this barrier after the copy completes. Raises: TypeError: If copy extents cannot be deduced from arguments @@ -142,14 +117,76 @@ def copy( if loop_layout is not None and "parallel_loop_layout" not in ann: ann["parallel_loop_layout"] = loop_layout - if "dst_block" not in ann and dst_block is not None: - ann["dst_block"] = dst_block + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, annotations=ann if ann else None) - if "cluster_mask" not in ann and cluster_mask is not None: - ann["cluster_mask"] = cluster_mask - if "barrier" not in ann and remote_barrier is not None: +def copy_cluster( + src: BufferLikeType, + dst: BufferLikeType, + *, + dst_block: int | tir.PrimExpr | None = None, + cluster_mask: int | None = None, + remote_barrier: tir.BufferLoad | None = None, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, + coalesced_width: int | None = None, + loop_layout: Any | None = None, +) -> tir.PrimExpr | tir.Stmt: + """Cluster-aware copy between shared memory regions or from global to shared memory. + + This is the entry point for two Hopper cluster features that require + ``cluster_dims`` to be set on the enclosing ``T.Kernel``: + + **TMA multicast** (global → shared, ``cluster_mask`` set): + The hardware delivers a single TMA load to the shared memory of every + masked CTA in the cluster simultaneously. 
At runtime the kernel splits + into three cases based on each CTA's rank: + + * **Leader** (rank == lowest set bit): issues ``tma_load_multicast``. + * **Masked peer** (other set bits): does nothing — receives passively. + * **Unmasked CTA**: issues a regular ``tma_load`` independently. + + **SM-to-SM copy** (shared → shared, ``dst_block`` set): + Copies from the local CTA's shared memory to a remote CTA's shared + memory within the same cluster. When ``remote_barrier`` is also + provided, a single bulk-async ``tl::tma_store_cluster`` instruction + is emitted; otherwise an element-by-element SIMT loop is used. + + Args: + src: Source memory region. + dst: Destination memory region. + dst_block: Destination CTA rank in the cluster for SM-to-SM copy. + cluster_mask: Bitmask of CTAs that participate in TMA multicast. + remote_barrier: Shared-memory mbarrier for asynchronous SM-to-SM copy + completion signalling. The destination CTA should wait on its + local copy of this barrier. + eviction_policy: Cache eviction hint passed to the TMA instruction. + Only relevant for the TMA multicast path (``cluster_mask`` set). + coalesced_width: Vectorization width (in elements) for the SIMT loop + used on the SM-to-SM fallback path (``dst_block`` set, no fast + bulk-async route available). + loop_layout: Parallel loop layout hint (Fragment) for the SIMT loop on + the SM-to-SM fallback path. Incompatible with the TMA multicast + path (``cluster_mask`` set). + + Returns: + tir.Call: A handle to the copy operation. 
+ """ + src, dst = _normalize_copy_regions(src, dst) + + ann: dict = {} + if dst_block is not None: + ann["dst_block"] = dst_block + if cluster_mask is not None: + ann["cluster_mask"] = cluster_mask + if remote_barrier is not None: ann["barrier"] = remote_barrier + if eviction_policy is not None: + eviction_policy_map = {"evict_normal": 0, "evict_first": 1, "evict_last": 2} + ann["eviction_policy"] = eviction_policy_map[eviction_policy] + if coalesced_width is not None: + ann["coalesced_width"] = coalesced_width + if loop_layout is not None: + ann["parallel_loop_layout"] = loop_layout return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, annotations=ann if ann else None) From c54bcece1e10543c6f96b0283e6a665d01474edc Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Tue, 17 Mar 2026 14:17:25 +0800 Subject: [PATCH 48/51] fix: renaming cluster mbarrier_arrive to ptx_arrive_cluster_barerier --- src/op/builtin.h | 8 -------- src/target/codegen_cuda.cc | 8 -------- 2 files changed, 16 deletions(-) diff --git a/src/op/builtin.h b/src/op/builtin.h index d6b766aeb..fb60c07b7 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -257,14 +257,6 @@ TVM_DLL const Op &create_list_of_mbarrier(); */ TVM_DLL const Op &get_mbarrier(); -/*! - * \brief Arrive at mbarrier with remote cta support - * - * mbarrier_arrive(mbarrier, cta_id, pred) - * - */ -TVM_DLL const Op &mbarrier_arrive(); - /*! 
* \brief tvm intrinsics for loading data from global tensor descriptor to * shared memory diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index cdd0a52fd..6a8ac2ee9 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1813,14 +1813,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ICHECK_EQ(op->args.size(), 1); std::string barrier_id = this->PrintExpr(op->args[0]); os << mbarrier_name_ + "[" + barrier_id + "]"; - } else if (op->op.same_as(tl::mbarrier_arrive())) { - ICHECK_EQ(op->args.size(), 3); - this->PrintIndent(); - auto mbarrier_obj = print_mbarrier_obj(op->args[0]); - auto cta_id = this->PrintExpr(op->args[1]); - auto pred = this->PrintExpr(op->args[2]); - this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred - << ");\n"; } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { ICHECK_EQ(op->args.size(), 1); this->PrintIndent(); From c6f1f6dd77f7a1174feb5bb15cec4dc392567e78 Mon Sep 17 00:00:00 2001 From: sgd <2012661711@qq.com> Date: Tue, 17 Mar 2026 20:36:15 +0800 Subject: [PATCH 49/51] fix: remove mbarrier_arrive in builtin.cc --- src/op/builtin.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index d4fdf12b9..ce44fe8d8 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -158,11 +158,6 @@ TIR_DEFINE_TL_BUILTIN(get_mbarrier) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_TL_BUILTIN(mbarrier_arrive) - .set_num_inputs(3) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - TIR_DEFINE_TL_BUILTIN(tma_load).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); From b49462f140b21f59a60b619a7f1f136498c63732 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Tue, 17 Mar 2026 21:23:35 +0800 Subject: [PATCH 50/51] test_tma_dsmem: fallback test --- testing/python/cuda/test_tma_dsmem.py | 224 ++++++++++++++++++++++---- 1 file changed, 192 insertions(+), 32 
deletions(-) diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index e19cf74c5..7dfced0b7 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -1,16 +1,26 @@ """ -Demo / regression test for SM-to-SM bulk async copy via tl::tma_store_cluster. - -T.copy with dst_block + barrier now lowers to a single -tl::tma_store_cluster call instead of a SIMT element-by-element loop. - -Expected generated producer code (block 0): - if (((int)threadIdx.x) == 0) { - tl::tma_store_cluster(&s_dst[0], &s_src[0], 1, - (uint32_t)(512), s_barrier[0]); - } - -Block 1 waits on its own s_barrier and then reads the result. +Regression tests for SM-to-SM cluster copy (T.copy_cluster with dst_block). + +Three lowering paths are covered: + + Fast path (test_tma_store_cluster): + T.copy_cluster(src, dst, dst_block=1, remote_barrier=bar) + → single tl::tma_store_cluster issued by one thread; mbarrier completion + is tracked by the TMA hardware via mbarrier.arrive.expect_tx. + + SIMT fallback, no barrier (test_store_cluster_simt_no_barrier): + T.copy_cluster(src, dst, dst_block=1) # no remote_barrier + → element-wise cooperative_groups::map_shared_rank stores by all threads; + caller uses T.cluster_sync() for ordering. + + SIMT fallback, with barrier (test_store_cluster_simt_barrier): + T.copy_cluster(src2d[0:M, 0:N_tile], dst2d[0:M, 0:N_tile], + dst_block=1, remote_barrier=bar) + where N_tile < N_full, so the inner-dim extent fails the contiguity check. + → element-wise map_shared_rank stores, followed by auto-injected + __syncthreads(); + if (threadIdx.x == 0) s_barrier[0].arrive(1u); + so the destination CTA can wait on the same mbarrier as in the fast path. 
""" import torch @@ -20,13 +30,17 @@ import numpy as np +# --------------------------------------------------------------------------- +# Fast path kernel +# --------------------------------------------------------------------------- + + def make_store_cluster_kernel(N: int): @T.prim_func def kernel( A: T.Tensor((N,), "float32"), B: T.Tensor((N,), "float32"), ): - # 2 CTAs in a cluster of size 2 with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: s_src = T.alloc_shared((N,), "float32") s_dst = T.alloc_shared((N,), "float32") @@ -34,57 +48,203 @@ def kernel( T.fill(s_src, 0.0) T.fill(s_dst, 0.0) - T.cluster_sync() if pid == 0: - # Load A into s_src for i in T.Parallel(N): s_src[i] = A[i] - - # Bulk-async copy s_src (local) → s_dst (remote, block 1) - # using tl::tma_store_cluster, signalling block 1's barrier. T.copy_cluster(s_src, s_dst, dst_block=1, remote_barrier=s_barrier[0]) if pid == 1: - # Wait until block 0 finishes writing to our s_dst. T.mbarrier_wait_parity(s_barrier[0], 0) + for i in T.Parallel(N): + B[i] = s_dst[i] + + return kernel + + +# --------------------------------------------------------------------------- +# SIMT fallback, no barrier +# --------------------------------------------------------------------------- + + +def make_store_cluster_simt_no_barrier_kernel(N: int): + """No remote_barrier → SIMT fallback always taken; cluster_sync() orders stores.""" + + @T.prim_func + def kernel( + A: T.Tensor((N,), "float32"), + B: T.Tensor((N,), "float32"), + ): + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: + s_src = T.alloc_shared((N,), "float32") + s_dst = T.alloc_shared((N,), "float32") - # Store result to global memory + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for i in T.Parallel(N): + s_src[i] = A[i] + # No remote_barrier: LowerClusterCopy always takes the SIMT path. + # All threads write into block 1's s_dst via map_shared_rank. 
+ T.copy_cluster(s_src, s_dst, dst_block=1) + + # Full cluster barrier: ensures all map_shared_rank stores from + # block 0 are visible in block 1's address space before block 1 + # reads s_dst. + T.cluster_sync() + + if pid == 1: for i in T.Parallel(N): B[i] = s_dst[i] return kernel +# --------------------------------------------------------------------------- +# SIMT fallback, with auto-injected ptx_arrive_cluster_barrier +# --------------------------------------------------------------------------- + + +def make_store_cluster_simt_barrier_kernel(M: int, N_full: int, N_tile: int): + """2-D slice copy that forces the SIMT fallback even though remote_barrier is set. + + s_src / s_dst are allocated with inner dimension N_full, but only the + first N_tile columns are copied. Because N_tile < N_full the + is_contiguous_region() check fails: the inner-dim extent of the copy + region (N_tile) does not equal the buffer shape (N_full). + + LowerClusterCopy falls back to map_shared_rank stores and, because + remote_barrier was supplied, automatically appends: + __syncthreads(); + if (threadIdx.x == 0) s_barrier[0].arrive(1u); + Block 1 therefore waits on the same mbarrier as in the fast-path API, + verifying that ptx_arrive_cluster_barrier is injected and functional. + """ + + @T.prim_func + def kernel( + A: T.Tensor((M, N_tile), "float32"), + B: T.Tensor((M, N_tile), "float32"), + ): + with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid: + # Deliberately wider buffer: N_full > N_tile so the slice + # [0:M, 0:N_tile] is non-contiguous in row-major storage. + s_src = T.alloc_shared((M, N_full), "float32") + s_dst = T.alloc_shared((M, N_full), "float32") + s_barrier = T.alloc_cluster_barrier([1]) + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for i, j in T.Parallel(M, N_tile): + s_src[i, j] = A[i, j] + + # [0:M, 0:N_tile] inner-dim extent N_tile != N_full + # → contiguity check fails → SIMT fallback. 
+ # Compiler auto-injects: __syncthreads() + + # if (t == 0) s_barrier[0].arrive(1u); + T.copy_cluster( + s_src[0:M, 0:N_tile], + s_dst[0:M, 0:N_tile], + dst_block=1, + remote_barrier=s_barrier[0], + ) + + if pid == 1: + # Block 1 waits on the auto-injected ptx_arrive_cluster_barrier. + T.mbarrier_wait_parity(s_barrier[0], 0) + for i, j in T.Parallel(M, N_tile): + B[i, j] = s_dst[i, j] + + return kernel + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_tma_store_cluster(): - + """Fast path: T.copy_cluster emits tl::tma_store_cluster.""" N = 128 prim_func = make_store_cluster_kernel(N) mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") - # Assert that the lowering actually produced tl::tma_store_cluster. - # The SIMT fallback (map_shared_rank + scalar stores) also copies data - # correctly, so a pure numerical check would miss a regression where - # T.copy_cluster(dst_block=..., remote_barrier=...) stops emitting the bulk-async - # cluster intrinsic. src = mod.get_kernel_source() assert "tl::tma_store_cluster" in src, ( "Expected tl::tma_store_cluster in generated kernel source; " - "T.copy_cluster(dst_block=..., remote_barrier=...) may have regressed to the " - f"SIMT fallback.\nKernel source:\n{src}" + "T.copy_cluster(dst_block=..., remote_barrier=...) 
may have regressed " + f"to the SIMT fallback.\nKernel source:\n{src}" + ) + + A = torch.arange(N, dtype=torch.float32, device="cuda") + B = mod(A) + np.testing.assert_allclose( + B.cpu().numpy(), + A.cpu().numpy(), + rtol=0, + atol=0, + err_msg="tma_store_cluster copy produced wrong result", ) + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_store_cluster_simt_no_barrier(): + """SIMT fallback (no remote_barrier): map_shared_rank + cluster_sync ordering.""" + N = 128 + prim_func = make_store_cluster_simt_no_barrier_kernel(N) + mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") + + src = mod.get_kernel_source() + assert "map_shared_rank" in src, f"Expected map_shared_rank in generated source for no-barrier SIMT fallback.\nKernel source:\n{src}" + assert "tl::tma_store_cluster" not in src, f"No-barrier path must NOT emit tl::tma_store_cluster.\nKernel source:\n{src}" + A = torch.arange(N, dtype=torch.float32, device="cuda") B = mod(A) + np.testing.assert_allclose( + B.cpu().numpy(), + A.cpu().numpy(), + rtol=0, + atol=0, + err_msg="SIMT no-barrier cluster copy produced wrong result", + ) - result = B.cpu().numpy() - expected = A.cpu().numpy() - diff = np.abs(result - expected).max() - assert np.allclose(result, expected), f"tma_store_cluster copy failed: max diff = {diff}" +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_store_cluster_simt_barrier(): + """SIMT fallback with auto-injected ptx_arrive_cluster_barrier. + + A non-full-span 2-D slice forces the fallback even though remote_barrier + is supplied. The auto-injected arrive lets block 1 wait on the same + mbarrier as in the fast-path API, verifying barrier correctness. 
+ """ + M, N_full, N_tile = 4, 64, 32 # M * N_tile == 128 == thread count + + prim_func = make_store_cluster_simt_barrier_kernel(M, N_full, N_tile) + mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") + + src = mod.get_kernel_source() + assert "map_shared_rank" in src, f"Expected map_shared_rank for SIMT+barrier fallback.\nKernel source:\n{src}" + assert "tl::tma_store_cluster" not in src, f"Non-contiguous 2-D slice must NOT emit tl::tma_store_cluster.\nKernel source:\n{src}" + + A = torch.arange(M * N_tile, dtype=torch.float32, device="cuda").reshape(M, N_tile) + B = mod(A) + np.testing.assert_allclose( + B.cpu().numpy(), + A.cpu().numpy(), + rtol=0, + atol=0, + err_msg="SIMT+auto-barrier cluster copy produced wrong result", + ) if __name__ == "__main__": From 5d045aabb2e3050d6ef2033714291a1f5eeefae6 Mon Sep 17 00:00:00 2001 From: jingkai he Date: Wed, 18 Mar 2026 15:28:46 +0800 Subject: [PATCH 51/51] =?UTF-8?q?[Feature]=20Multi-TMA=20fallback=20for=20?= =?UTF-8?q?non-contiguous=20T.copy=5Fcluster=20regions=20Previously,=20a?= =?UTF-8?q?=20non-contiguous=20copy=20region=20(e.g.=20a=202-D=20slice=20w?= =?UTF-8?q?here=20the=20inner=20dim=20does=20not=20span=20the=20full=20buf?= =?UTF-8?q?fer=20width)=20with=20a=20remote=5Fbarrier=20fell=20back=20to?= =?UTF-8?q?=20element-wise=20SIMT=20stores,=20losing=20TMA=20throughput.?= =?UTF-8?q?=20This=20commit=20introduces=20a=20"multi-TMA"=20middle=20path?= =?UTF-8?q?=20in=20LowerClusterCopy:=20the=20ND=20copy=20region=20is=20rec?= =?UTF-8?q?ursively=20decomposed=20into=20individual=20contiguous=20rows,?= =?UTF-8?q?=20each=20emitting=20one=20tma=5Fstore=5Fcluster=20call.=20=20T?= =?UTF-8?q?he=20mbarrier=20arrive=5Fcount=20is=20updated=20to=20N=20(the?= =?UTF-8?q?=20number=20of=20rows)=20so=20the=20destination=20CTA's=20wait(?= =?UTF-8?q?)=20still=20works=20correctly=20without=20any=20API=20change.?= =?UTF-8?q?=20Key=20changes:=20-=20src/op/copy.cc:=20add=20static=20MakeTM?= 
=?UTF-8?q?ARows()=20that=20recursively=20splits=20an=20=20=20ND=20region?= =?UTF-8?q?=20into=20contiguous=20rows;=20static=20extents=20are=20unrolle?= =?UTF-8?q?d=20at=20=20=20compile=20time,=20symbolic=20extents=20produce?= =?UTF-8?q?=20TIR=20For=20loops.=20-=20src/op/operator.h:=20add=20UpdateBa?= =?UTF-8?q?rrierArriveCallback=20(Var=20=E2=86=92=20PrimExpr)=20=20=20to?= =?UTF-8?q?=20LowerArgs=20so=20LowerClusterCopy=20can=20propagate=20the=20?= =?UTF-8?q?new=20arrive=5Fcount.=20-=20src/transform/lower=5Ftile=5Fop.cc:?= =?UTF-8?q?=20collect=20arrive=5Fcount=20overrides=20in=20=20=20barrier=5F?= =?UTF-8?q?arrive=5Fupdates=5F=20and=20apply=20them=20to=20the=20barrier?= =?UTF-8?q?=5Finit=20block=20=20=20annotation=20after=20visiting=20the=20b?= =?UTF-8?q?lock=20body,=20before=20LowerSharedBarrier=20=20=20consumes=20i?= =?UTF-8?q?t.=20-=20testing/python/cuda/test=5Ftma=5Fdsmem.py:=20replace?= =?UTF-8?q?=20the=20old=20SIMT-barrier=20=20=20test=20with=20test=5Fstore?= =?UTF-8?q?=5Fcluster=5Fmulti=5Ftma=5Fbarrier=20(2-D,=20M=20rows)=20and=20?= =?UTF-8?q?add=20=20=20test=5Fstore=5Fcluster=5F3d=5Fmulti=5Ftma=20(3-D,?= =?UTF-8?q?=20D=C3=97M=20rows)=20to=20cover=20the=20=20=20two-level=20recu?= =?UTF-8?q?rsive=20decomposition.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/op/copy.cc | 192 +++++++++++++++++++++++++- src/op/operator.h | 6 + src/transform/lower_tile_op.cc | 46 +++++- testing/python/cuda/test_tma_dsmem.py | 108 +++++++++++++-- 4 files changed, 334 insertions(+), 18 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 5eb83b448..58692f00a 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1499,6 +1499,141 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, return body; } +// --------------------------------------------------------------------------- +// Multi-TMA decomposition helper +// --------------------------------------------------------------------------- +// Recursively splits an N-D copy 
region into individual contiguous "rows" and +// emits one tma_store_cluster call per row. +// +// Returns (stmts, arrive_count_expr) where: +// • stmts – flat list of TIR statements to emit (may contain For +// loops when an extent is symbolic/dynamic) +// • arrive_count – PrimExpr for the total number of tma_store_cluster +// calls; used as the mbarrier arrive_count so the +// destination CTA's wait completes only after all rows +// have been transferred. +// +// Algorithm: +// Base case – both regions are already contiguous → emit one call, count=1. +// Static dim – extent is a compile-time IntImm → unroll and recurse. +// Dynamic dim – extent is symbolic → generate a TIR For loop; the loop +// body is the recursion result; count = extent * body_count. +// +// Preconditions: +// • src_ranges and dst_ranges have the same rank and matching per-dim +// extents (checked by the caller). +static std::pair, PrimExpr> +MakeTMARows(const Buffer &src, const Array &src_ranges, + const Buffer &dst, const Array &dst_ranges, + PrimExpr dst_block, PrimExpr barrier_load, + arith::Analyzer *analyzer) { + int n = static_cast(src_ranges.size()); + + // Check whether a buffer/range pair is contiguous in row-major storage. + auto is_contig = [&](const Buffer &buf, const Array &ranges) -> bool { + int pivot = -1; + for (int i = 0; i < n; ++i) { + if (!analyzer->CanProveEqual(ranges[i]->extent, 1)) { + pivot = i; + break; + } + } + if (pivot == -1) + return true; + for (int i = pivot + 1; i < n; ++i) { + if (!analyzer->CanProveEqual(ranges[i]->min, 0) || + !analyzer->CanProveEqual(ranges[i]->extent, buf->shape[i])) + return false; + } + return true; + }; + + // Linear element offset for the starting element of a region. 
+ auto linear_off = [](const Buffer &buf, + const Array &ranges) -> PrimExpr { + int r = static_cast(ranges.size()); + PrimExpr off = 0, stride = 1; + for (int i = r - 1; i >= 0; --i) { + off = off + ranges[i]->min * stride; + if (i > 0) + stride = stride * buf->shape[i]; + } + return off; + }; + + // Base case: both regions are contiguous → one tma_store_cluster. + if (is_contig(src, src_ranges) && is_contig(dst, dst_ranges)) { + PrimExpr total_elems = 1; + for (const auto &r : src_ranges) + total_elems = total_elems * r->extent; + PrimExpr size_bytes = + cast(DataType::UInt(32), total_elems * src->dtype.bytes()); + PrimExpr src_ptr = src.access_ptr(1, DataType::Handle(), 1, + linear_off(src, src_ranges), total_elems); + PrimExpr dst_ptr = dst.access_ptr(2, DataType::Handle(), 1, + linear_off(dst, dst_ranges), total_elems); + Stmt call = + Evaluate(Call(DataType::Handle(), tma_store_cluster(), + {dst_ptr, src_ptr, dst_block, size_bytes, barrier_load})); + return {{call}, IntImm(DataType::Int(32), 1)}; + } + + // Recursive case: find the outermost non-trivial dimension. 
+ int split_dim = -1; + for (int d = 0; d < n; ++d) { + if (!analyzer->CanProveEqual(src_ranges[d]->extent, 1)) { + split_dim = d; + break; + } + } + ICHECK(split_dim >= 0) + << "MakeTMARows: all dimensions are trivial yet region is not " + "contiguous – this should not happen"; + + PrimExpr extent = src_ranges[split_dim]->extent; + const auto *ext_imm = extent.as(); + + if (ext_imm) { + // ── Static extent: unroll at compile time ────────────────────────────── + Array all_stmts; + PrimExpr total = IntImm(DataType::Int(32), 0); + for (int64_t k = 0; k < ext_imm->value; ++k) { + Array new_src = src_ranges; + Array new_dst = dst_ranges; + PrimExpr kexpr = IntImm(DataType::Int(32), k); + new_src.Set(split_dim, + Range::FromMinExtent(src_ranges[split_dim]->min + kexpr, 1)); + new_dst.Set(split_dim, + Range::FromMinExtent(dst_ranges[split_dim]->min + kexpr, 1)); + auto [stmts, cnt] = MakeTMARows(src, new_src, dst, new_dst, dst_block, + barrier_load, analyzer); + for (const auto &s : stmts) + all_stmts.push_back(s); + total = total + cnt; + } + return {all_stmts, total}; + } else { + // ── Dynamic extent: emit a TIR For loop ──────────────────────────────── + // Build the loop body: fix split_dim to (range_min + loop_var). + Var k("k_tma_row", DataType::Int(32)); + Array body_src = src_ranges; + Array body_dst = dst_ranges; + body_src.Set(split_dim, + Range::FromMinExtent(src_ranges[split_dim]->min + k, 1)); + body_dst.Set(split_dim, + Range::FromMinExtent(dst_ranges[split_dim]->min + k, 1)); + auto [body_stmts, body_cnt] = MakeTMARows( + src, body_src, dst, body_dst, dst_block, barrier_load, analyzer); + Stmt body = body_stmts.size() == 1 ? 
body_stmts[0] + : static_cast(SeqStmt(body_stmts)); + Stmt for_loop = + For(k, IntImm(DataType::Int(32), 0), extent, ForKind::kSerial, body); + // arrive_count = extent * body_count (every iteration contributes the same) + PrimExpr total = extent * body_cnt; + return {{for_loop}, total}; + } +} + Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { // Check if the target supports cluster copy @@ -1560,13 +1695,10 @@ Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, dst_elements = dst_elements * r->extent; bool element_match = analyzer->CanProveEqual(src_elements, dst_elements); - if (!(src_contiguous && dst_contiguous && element_match)) { - LOG(WARNING) - << "Falling back to element-wise cluster copy: bulk cluster fast " - "path requires contiguous src/dst regions with matching element " - "counts. src=" - << src->name << ", dst=" << dst->name; - } else { + if (src_contiguous && dst_contiguous && element_match) { + // ----------------------------------------------------------------------- + // Fast path: single tma_store_cluster for the whole contiguous region. + // ----------------------------------------------------------------------- PrimExpr barrier_load = barrier_opt.value(); // Compute linear offsets from the copy ranges (one offset per buffer). @@ -1608,6 +1740,52 @@ Stmt CopyNode::LowerClusterCopy(const LowerArgs &T, // Single-thread guard: only thread_bounds->min issues the instruction. return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), bulk_copy); } + + // ------------------------------------------------------------------------- + // Multi-TMA path: non-contiguous region decomposed into N contiguous rows. + // Requires matching per-dim extents between src and dst (element_match + // alone is insufficient – we need the same iteration space on both sides). 
+ // Extents may be compile-time constants (→ unrolled calls) or symbolic + // (→ TIR For loops); in both cases MakeTMARows computes the arrive_count + // expression used to initialise the mbarrier. + // ------------------------------------------------------------------------- + bool same_shape = (src_range.size() == dst_range.size()); + for (size_t d = 0; d < src_range.size() && same_shape; ++d) { + if (!analyzer->CanProveEqual(src_range[d]->extent, + dst_range[d]->extent)) { + same_shape = false; + } + } + + if (element_match && same_shape) { + PrimExpr barrier_load = barrier_opt.value(); + const auto *barrier_buf_load = barrier_load.as(); + ICHECK(barrier_buf_load) + << "LowerClusterCopy: expected BufferLoad for barrier annotation"; + Var barrier_data_var = barrier_buf_load->buffer->data; + + auto [tma_stmts, n_rows] = + MakeTMARows(src, src_range, dst, dst_range, dst_block.value(), + barrier_load, analyzer); + + // Inform LowerTileOpPass so it can update the barrier's arrive_count in + // the barrier_init block annotation before LowerSharedBarrier runs. + if (T.UpdateBarrierArrive) { + T.UpdateBarrierArrive(barrier_data_var, n_rows); + } + + Stmt seq = (tma_stmts.size() == 1) + ? tma_stmts[0] + : static_cast(SeqStmt(tma_stmts)); + // Single-thread guard: only thread_bounds->min issues TMA instructions. + return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), seq); + } + + LOG(WARNING) + << "Falling back to element-wise cluster copy: bulk cluster " + "fast/multi-TMA paths require matching element counts and same " + "per-dim extents between src and dst. 
src=" + << src->name << ", dst=" << dst->name; } // --------------------------------------------------------------------------- diff --git a/src/op/operator.h b/src/op/operator.h index afa36bf38..32d9bf382 100644 --- a/src/op/operator.h +++ b/src/op/operator.h @@ -24,6 +24,11 @@ using namespace tir; using AddWorkspaceCallback = std::function; using AllocMBarrierCallback = std::function; +// Called by LowerClusterCopy when the multi-TMA path decomposes a +// non-contiguous region into N separate tma_store_cluster calls. The pass +// records the (barrier_data_var → N) mapping and updates the barrier_init +// block annotation before LowerSharedBarrier consumes it. +using UpdateBarrierArriveCallback = std::function; using LayoutMap = Map; using BufferMap = Map; @@ -53,6 +58,7 @@ struct LowerArgs { Var thread_var; AddWorkspaceCallback AddWorkspace; AllocMBarrierCallback AllocMBarrier; + UpdateBarrierArriveCallback UpdateBarrierArrive; LayoutMap layout_map; Map buffer_remap; // Map from LetStmt variable to its bound expression, for resolving diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index e683a7edb..1c1e79a63 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -291,6 +291,36 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { } workspace_stack_.pop_back(); } + + // Apply any barrier arrive-count overrides registered by LowerClusterCopy + // during the multi-TMA decomposition path. We update the barrier_init + // annotation here (before LowerSharedBarrier consumes it) so that the + // mbarrier is initialised with arrive_count = N (number of TMA rows). 
+ if (!barrier_arrive_updates_.empty() && + block->annotations.count("barrier_init")) { + auto barrier_init_map = Downcast>>( + block->annotations.Get("barrier_init").value()); + bool updated = false; + for (auto it = barrier_arrive_updates_.begin(); + it != barrier_arrive_updates_.end();) { + if (barrier_init_map.count(it->first)) { + auto old_counts = barrier_init_map.at(it->first); + Array new_counts; + for (size_t i = 0; i < old_counts.size(); i++) { + new_counts.push_back(it->second); + } + barrier_init_map.Set(it->first, new_counts); + updated = true; + it = barrier_arrive_updates_.erase(it); + } else { + ++it; + } + } + if (updated) { + block_ptr->annotations.Set("barrier_init", barrier_init_map); + } + } + return block; } @@ -1029,6 +1059,11 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return id; }; + UpdateBarrierArriveCallback barrier_arrive_callback = [this](Var data_var, + PrimExpr n) { + barrier_arrive_updates_[data_var] = n; + }; + // Compute mbarrier expressions from the enclosing loop and pipeline info. // pipeline_num_stages: number of pipeline stages (from T.Pipelined // annotation) mbar_stage_expr: ko % num_stages (cycles through multiple @@ -1050,8 +1085,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { auto lowered = tile_op->Lower( LowerArgs{target_, thread_bounds, thread_var_->var, callback, - mbarrier_callback, layout_map_, buffer_remap_, - let_var_to_expr, + mbarrier_callback, barrier_arrive_callback, layout_map_, + buffer_remap_, let_var_to_expr, /*in_pipeline=*/pipelined_depth_ > 0, mbar_phase_expr, pipeline_num_stages, mbar_stage_expr}, analyzer_); @@ -1350,6 +1385,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { // parameters rather than in memory indices. bool in_tma_context_{false}; int pipelined_depth_{0}; + // Pending barrier arrive-count overrides from multi-TMA cluster-copy + // decomposition. Maps barrier buffer data Var → new arrive count N. 
+ // Populated by LowerClusterCopy via UpdateBarrierArriveCallback and + // consumed (then cleared) in VisitStmt_(BlockNode) before LowerSharedBarrier + // processes the barrier_init annotation. + std::unordered_map + barrier_arrive_updates_; }; namespace transform { diff --git a/testing/python/cuda/test_tma_dsmem.py b/testing/python/cuda/test_tma_dsmem.py index 7dfced0b7..d2b6a423d 100644 --- a/testing/python/cuda/test_tma_dsmem.py +++ b/testing/python/cuda/test_tma_dsmem.py @@ -220,21 +220,29 @@ def test_store_cluster_simt_no_barrier(): @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_store_cluster_simt_barrier(): - """SIMT fallback with auto-injected ptx_arrive_cluster_barrier. +def test_store_cluster_multi_tma_barrier(): + """Multi-TMA path: non-contiguous 2-D slice decomposed into N row TMA calls. - A non-full-span 2-D slice forces the fallback even though remote_barrier - is supplied. The auto-injected arrive lets block 1 wait on the same - mbarrier as in the fast-path API, verifying barrier correctness. + The compiler decomposes a non-full-span 2-D slice (shape [M, N_tile] inside + a buffer of shape [M, N_full]) into M individual tma_store_cluster calls – + one per contiguous row. The mbarrier arrive_count is updated to M so the + destination CTA's wait(0) completes only after all rows are transferred. 
""" - M, N_full, N_tile = 4, 64, 32 # M * N_tile == 128 == thread count + M, N_full, N_tile = 4, 64, 32 # M rows, each N_tile elements prim_func = make_store_cluster_simt_barrier_kernel(M, N_full, N_tile) mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") src = mod.get_kernel_source() - assert "map_shared_rank" in src, f"Expected map_shared_rank for SIMT+barrier fallback.\nKernel source:\n{src}" - assert "tl::tma_store_cluster" not in src, f"Non-contiguous 2-D slice must NOT emit tl::tma_store_cluster.\nKernel source:\n{src}" + # Multi-TMA path must emit M separate tma_store_cluster calls, not SIMT stores. + assert "tl::tma_store_cluster" in src, f"Expected tl::tma_store_cluster for multi-TMA row decomposition.\nKernel source:\n{src}" + assert "map_shared_rank" not in src, f"Multi-TMA path must NOT fall back to map_shared_rank.\nKernel source:\n{src}" + # The barrier must be initialised with arrive_count == M (one per TMA call). + assert f"s_barrier[0].init({M})" in src, f"Expected barrier arrive_count={M} for {M}-row decomposition.\nKernel source:\n{src}" + # Exactly M tma_store_cluster calls should appear in the source. 
+ assert src.count("tl::tma_store_cluster") == M, ( + f"Expected exactly {M} tma_store_cluster calls, got {src.count('tl::tma_store_cluster')}.\nKernel source:\n{src}" + ) A = torch.arange(M * N_tile, dtype=torch.float32, device="cuda").reshape(M, N_tile) B = mod(A) @@ -243,7 +251,89 @@ def test_store_cluster_simt_barrier(): A.cpu().numpy(), rtol=0, atol=0, - err_msg="SIMT+auto-barrier cluster copy produced wrong result", + err_msg="Multi-TMA row-decomposed cluster copy produced wrong result", + ) + + +# --------------------------------------------------------------------------- +# 3-D multi-TMA: two levels of recursive decomposition +# --------------------------------------------------------------------------- + + +def make_store_cluster_3d_multi_tma_kernel(D: int, M: int, N_full: int, N_tile: int): + """3-D slice copy decomposed into D*M tma_store_cluster calls. + + Buffer shape is [D, M, N_full]; the copy region is [0:D, 0:M, 0:N_tile]. + Because N_tile < N_full the innermost dim is not full-span. + MakeTMARows recurses twice (once on dim 0, once on dim 1) producing D*M + contiguous-row TMA calls and sets barrier arrive_count = D*M. 
+ """ + + @T.prim_func + def kernel( + A: T.Tensor((D, M, N_tile), "float32"), + B: T.Tensor((D, M, N_tile), "float32"), + ): + with T.Kernel(2, threads=D * M * N_tile, cluster_dims=(2, 1, 1)) as pid: + s_src = T.alloc_shared((D, M, N_full), "float32") + s_dst = T.alloc_shared((D, M, N_full), "float32") + s_barrier = T.alloc_cluster_barrier([1]) + + T.fill(s_src, 0.0) + T.fill(s_dst, 0.0) + T.cluster_sync() + + if pid == 0: + for d, i, j in T.Parallel(D, M, N_tile): + s_src[d, i, j] = A[d, i, j] + + T.copy_cluster( + s_src[0:D, 0:M, 0:N_tile], + s_dst[0:D, 0:M, 0:N_tile], + dst_block=1, + remote_barrier=s_barrier[0], + ) + + if pid == 1: + T.mbarrier_wait_parity(s_barrier[0], 0) + for d, i, j in T.Parallel(D, M, N_tile): + B[d, i, j] = s_dst[d, i, j] + + return kernel + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_store_cluster_3d_multi_tma(): + """3-D multi-TMA: recursive decomposition produces D*M separate TMA calls. + + With D=2 and M=4, the two-level recursion over dims 0 and 1 yields 8 + contiguous-row tma_store_cluster calls and initialises the barrier with + arrive_count=8. 
+ """ + D, M, N_full, N_tile = 2, 4, 32, 16 # D*M*N_tile == 128 == thread count + + prim_func = make_store_cluster_3d_multi_tma_kernel(D, M, N_full, N_tile) + mod = tilelang.compile(prim_func, out_idx=[1], execution_backend="cython") + + src = mod.get_kernel_source() + n_expected = D * M + assert "tl::tma_store_cluster" in src, f"Expected tl::tma_store_cluster for 3-D multi-TMA.\nKernel source:\n{src}" + assert f"s_barrier[0].init({n_expected})" in src, ( + f"Expected barrier arrive_count={n_expected} for {D}x{M} decomposition.\nKernel source:\n{src}" + ) + assert src.count("tl::tma_store_cluster") == n_expected, ( + f"Expected exactly {n_expected} tma_store_cluster calls, got {src.count('tl::tma_store_cluster')}.\nKernel source:\n{src}" + ) + + A = torch.arange(D * M * N_tile, dtype=torch.float32, device="cuda").reshape(D, M, N_tile) + B = mod(A) + np.testing.assert_allclose( + B.cpu().numpy(), + A.cpu().numpy(), + rtol=0, + atol=0, + err_msg="3-D multi-TMA cluster copy produced wrong result", )