64 commits
80e409b
[Feature] Extend t.copy to support TMA multicast and SM-to-SM cluster…
He-Jingkai Mar 6, 2026
01c73b3
[Docs] Add programming guide for Cluster TMA features
He-Jingkai Mar 6, 2026
9ab0119
Merge remote-tracking branch 'upstream/main' into cluster-launch
He-Jingkai Mar 6, 2026
e8b3bd0
change cluster_size -> cluster_dims
He-Jingkai Mar 6, 2026
deab402
fix merge conflict
He-Jingkai Mar 7, 2026
2ef8eba
fix TILELANG CHECK bug
He-Jingkai Mar 7, 2026
29a7971
unify cluster function
He-Jingkai Mar 7, 2026
5d57c11
fix pre-commit errors
He-Jingkai Mar 7, 2026
c2efb88
docs(cluster_tma): clarify multicast non-issuer behavior in lowering …
sigmoidsee Mar 9, 2026
258c948
fix(copy): gate dst_block cluster lowering on target capability
sigmoidsee Mar 9, 2026
5a61902
Merge branch 'tile-ai:main' into t_copy_extend
He-Jingkai Mar 9, 2026
d2d7f1e
fix(cuda/cluster): hard-trap when cluster intrinsics are unavailable …
sigmoidsee Mar 9, 2026
2c67281
fix inject_tma & lower_hopper & lower_tile bugs
sigmoidsee Mar 9, 2026
cbc15d7
Merge remote-tracking branch 'pr/t_copy_extend' into t_copy_extend
sigmoidsee Mar 9, 2026
db56e92
[warp specialize]Fix mbarrier retarget for tma_load_multicast in warp…
sigmoidsee Mar 9, 2026
8af6bd0
fix(tma-barrier): correct arrive thread count under equality-guarded …
sigmoidsee Mar 9, 2026
f2cb665
fix: overlapped TMA barrier IDs
He-Jingkai Mar 9, 2026
febfd78
fix: multicast checks
He-Jingkai Mar 9, 2026
bf38cfa
minor fix
He-Jingkai Mar 9, 2026
4008456
minor fix
He-Jingkai Mar 9, 2026
cb5305f
fix: tma_load_multicast() in the stage-local mbarrier rewrite
He-Jingkai Mar 9, 2026
310ccbb
minor fix
He-Jingkai Mar 9, 2026
c50f426
format
He-Jingkai Mar 9, 2026
add3de9
minor fix
He-Jingkai Mar 9, 2026
531c51f
fix: Track allocation from every hoisted barrier init.
He-Jingkai Mar 9, 2026
1086f24
fix: Scope the 128-byte alignment to kernels that actually use TMA.
He-Jingkai Mar 9, 2026
1996788
fix: Assert the tl::tma_store_cluster lowering, not just the output.
He-Jingkai Mar 9, 2026
57fd797
format fix
He-Jingkai Mar 9, 2026
a39604d
fix: Don't assume both if arms transfer the same number of bytes.
He-Jingkai Mar 9, 2026
d1c7bb8
fix: Don't wrap the kept pipelined loop in a SeqStmt before the For p…
He-Jingkai Mar 9, 2026
4244e07
rm mbarrier_init and use alloc_cluster_barrier
He-Jingkai Mar 9, 2026
e5b00c5
fix: Don't make cluster copy a global exemption for unrelated TMA loads.
He-Jingkai Mar 9, 2026
25f0aa1
fix: Don't drop dependency modeling for handle-based mbarrier_wait_pa…
He-Jingkai Mar 9, 2026
f0d4169
fix: Don't hoist extracted barrier-init if statements to the block root.
He-Jingkai Mar 9, 2026
ccd9c1f
Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t…
sigmoidsee Mar 10, 2026
8a947bb
fix(copy): honor remapped dst layout in cluster-copy slow path
sigmoidsee Mar 10, 2026
6ca3e73
fix(copy): gate cluster TMA fast path on provable contiguity
sigmoidsee Mar 10, 2026
989044e
fix(copy): add barrier completion for SIMT cluster-copy fallback
sigmoidsee Mar 10, 2026
2a26453
fix: remove dup cluster_sync
He-Jingkai Mar 11, 2026
4a982ab
Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t…
He-Jingkai Mar 11, 2026
c07a79c
fix: remove not used _get_mbarrier
He-Jingkai Mar 11, 2026
a2a1fc6
fix: remove mbarrier related code in codegen_cuda.cc and barrier.h
He-Jingkai Mar 11, 2026
85cebfb
fix: codegen_cuda.cc: reuse functions in cluster.h and remove depende…
He-Jingkai Mar 11, 2026
a5da2a3
fix testting
He-Jingkai Mar 12, 2026
99c800c
fix(transform): prevent ragged_prefix free var from breaking MakePack…
sigmoidsee Mar 12, 2026
831beb4
minor fix
He-Jingkai Mar 16, 2026
b30e926
format
He-Jingkai Mar 16, 2026
a8753c8
fix inject tma barrier: fix crash and correct mbarrier thread-count i…
sigmoidsee Mar 16, 2026
8a3f99d
Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t…
sigmoidsee Mar 16, 2026
a312dd1
Merge remote-tracking branch 'upstream/main' into t_copy_extend
He-Jingkai Mar 16, 2026
9abd5ec
revert inject tma
sigmoidsee Mar 16, 2026
b9ab3f3
Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t…
sigmoidsee Mar 16, 2026
0ec4374
fix(inject_tma_barrier): avoid SIGSEGV and correctly infer barrier ar…
sigmoidsee Mar 16, 2026
79cfee7
fix: add tma_load_multicast same as tma_load
sigmoidsee Mar 16, 2026
00ea64b
T.copy_cluster
He-Jingkai Mar 17, 2026
c54bcec
fix: renaming cluster mbarrier_arrive to ptx_arrive_cluster_barerier
sigmoidsee Mar 17, 2026
c6f1f6d
fix: remove mbarrier_arrive in builtin.cc
sigmoidsee Mar 17, 2026
b49462f
test_tma_dsmem: fallback test
He-Jingkai Mar 17, 2026
aeffe21
Merge branch 't_copy_extend' of github.com:He-Jingkai/tilelang into t…
He-Jingkai Mar 17, 2026
0828868
Merge branch 'main' of github.com:tile-ai/tilelang into t_copy_extend
He-Jingkai Mar 18, 2026
5d045aa
[Feature] Multi-TMA fallback for non-contiguous T.copy_cluster regions
He-Jingkai Mar 18, 2026
317e139
Merge remote-tracking branch 'upstream/main' into t_copy_extend
sigmoidsee Mar 19, 2026
3feb182
Merge branch 'main' of github.com:tile-ai/tilelang into t_copy_extend
sigmoidsee Mar 24, 2026
f760b53
Merge branch 'main' of github.com:tile-ai/tilelang into t_copy_extend
sigmoidsee Mar 24, 2026
303 changes: 303 additions & 0 deletions docs/programming_guides/cluster_tma.md
@@ -0,0 +1,303 @@
# Cluster TMA: Multicast and SM-to-SM Copy

This page describes two advanced data-movement features that are available on
NVIDIA Hopper (SM90) and later: **TMA multicast** and **SM-to-SM cluster
copy**. Both features are exposed through extensions to the existing `T.copy`
operator and require a kernel launched with a thread block cluster, i.e., with
`cluster_dims != (1, 1, 1)`.

Requirements:
- CUDA Compute Capability ≥ 9.0 (Hopper / Blackwell / RTX 5090)

---

## Background: Thread Block Clusters

A *thread block cluster* is a group of CTAs that share a common virtual address
space for their shared-memory regions and can communicate without going through
global memory. Within a cluster, each CTA has a *block rank* (0-indexed
position inside the cluster), and all CTAs can observe each other's shared
memory via the `shared::cluster` address space.

```python
with T.Kernel(grid_x, grid_y, threads=128, cluster_dims=(4, 1, 1)) as (bx, by):
    rank = T.block_rank_in_cluster()    # 0..3 within this cluster
    nctas = T.get_cluster_block_nums()  # total CTAs in this cluster
    T.cluster_sync()                    # barrier across all CTAs in cluster
```

---

## Feature 1 — TMA Multicast (`cluster_mask`)

### What it does

Normally each CTA issues its own TMA load, fetching a tile from global memory
into its private shared memory. With multicast, **a single TMA transaction
broadcasts one global tile to every participating CTA simultaneously**, saving
repeated DRAM traffic when multiple CTAs in a cluster need the same data (e.g.,
the same K-panel in a split-K GEMM).

```text
Global memory ──TMA multicast──▶ shared memory (rank 0)
└─▶ shared memory (rank 1) (same tile, no extra DRAM read)
TMA load ──▶ shared memory (rank 2) (independent tile)
TMA load ──▶ shared memory (rank 3) (independent tile)
```

### API

```python
T.copy_cluster(src_global, dst_shared, cluster_mask=<int>)
```

`cluster_mask` is a bitmask where each set bit identifies a CTA rank that
participates in the multicast. The CTA whose rank equals the lowest set bit
in the mask issues `cp.async.bulk.tensor … multicast::cluster`; every other
CTA in the mask receives the data passively (no instruction issued). CTAs
outside the mask perform a regular TMA load for their own tile.
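The issuer-election rule above can be modeled in plain Python. This is a sketch of the documented semantics (lowest set bit elects the issuer), not the compiler's actual lowering code; the function name is ours:

```python
def multicast_role(rank: int, cluster_mask: int) -> str:
    """Model the per-rank behavior implied by cluster_mask (semantic sketch)."""
    if cluster_mask & (1 << rank):
        # The lowest set bit elects the issuer: mask & -mask isolates that bit.
        issuer = (cluster_mask & -cluster_mask).bit_length() - 1
        return "issue_multicast" if rank == issuer else "receive_passively"
    return "regular_tma_load"

# cluster_mask = 0b0011: rank 0 issues, rank 1 receives, ranks 2-3 load normally.
roles = [multicast_role(r, 0b0011) for r in range(4)]
```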

### Example

```python
import tilelang
import tilelang.language as T

def make_tma_multicast_kernel(M, N, block_M, block_N, cluster_mask):
    @T.prim_func
    def kernel(
        A: T.Tensor((M, N), "float16"),
        B: T.Tensor((M, N), "float16"),
    ):
        # 4 CTAs per cluster; ranks 0 and 1 share the same tile via multicast.
        with T.Kernel(
            T.ceildiv(N, block_N),
            T.ceildiv(M, block_M),
            threads=128,
            cluster_dims=(4, 1, 1),
        ) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), "float16")

            # cluster_mask=0b0011: ranks 0 and 1 participate.
            # Rank 0 issues tma_load_multicast; rank 1 receives passively.
            # Ranks 2 and 3 each issue a regular tma_load.
            T.copy_cluster(A[by * block_M, bx * block_N], A_shared,
                           cluster_mask=cluster_mask)

            T.copy(A_shared, B[by * block_M, bx * block_N])

    return kernel
```

Running the kernel above with `cluster_mask = 0b0011`:

| Rank | Action | `B` slice receives |
|------|--------|--------------------|
| 0 | issues multicast load | A tile at rank-0 address |
| 1 | passively receives | **same** A tile as rank 0 |
| 2 | regular TMA load | A tile at rank-2 address |
| 3 | regular TMA load | A tile at rank-3 address |

### Notes

- The compiler lowers `cluster_mask != 0` to
`cp.async.bulk.tensor.Nd.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster`
for the issuing CTA; CTAs in the mask but not elected as issuer receive
passively, and only CTAs outside the mask issue a standard
`cp.async.bulk.tensor`.
- Software pipelining (`T.Pipelined`) is fully supported; the warp-specialized
  rewriter recognises `tma_load_multicast` as a producer operation.
- `cluster_mask` is a compile-time constant; dynamic masks are not supported.

---

## Feature 2 — SM-to-SM Cluster Copy (`dst_block`)

### What it does

SM-to-SM copy lets one CTA **push data directly from its own shared memory
into another CTA's shared memory** within the same cluster, without a round
trip through global memory. This is useful for patterns such as:

- Partial result exchange (e.g., split-K partial sums across SM boundaries)
- Producer–consumer pipelines where the producer fills a neighbor's buffer
- All-to-all collective communication within a cluster

Two sub-variants are provided depending on whether an mbarrier is supplied:

| Variant | Parameter | Hardware instruction | Threads used |
|---------|-----------|---------------------|--------------|
| **Fast path** | `dst_block` + `remote_barrier` | `cp.async.bulk.shared::cluster` | 1 (async DMA) |
| **Slow path** | `dst_block` only | `map_shared_rank` + scalar stores | all (SIMT loop) |

### Fast path — bulk async copy with mbarrier

```python
T.copy_cluster(src_shared, dst_shared, dst_block=<rank>, remote_barrier=<mbarrier>)
```

A single elected thread issues one `cp.async.bulk.shared::cluster` instruction.
The hardware DMA engine transfers the entire tile asynchronously and signals
the destination CTA's mbarrier on completion. The destination CTA waits with
`T.mbarrier_wait_parity`.

Steps:
1. Both CTAs allocate the **same** shared memory layout so their mbarriers live
at the same offset.
2. Every CTA initialises its own barrier for 1 arrival.
3. The source CTA (`pid == 0` below) calls `T.copy_cluster(... dst_block=1, remote_barrier=...)`.
4. The destination CTA (`pid == 1`) waits on its local barrier copy.

```python
import tilelang
import tilelang.language as T

@tilelang.jit(verbose=True, execution_backend="cython")
def make_cluster_copy_kernel(N: int):
    @T.prim_func
    def kernel(
        A: T.Tensor((N,), "float32"),
        B: T.Tensor((N,), "float32"),
    ):
        with T.Kernel(2, threads=128, cluster_dims=(2, 1, 1)) as pid:
            s_src = T.alloc_shared((N,), "float32")
            s_dst = T.alloc_shared((N,), "float32")
            s_barrier = T.alloc_cluster_barrier([1])

            T.fill(s_src, 0.0)
            T.fill(s_dst, 0.0)

            T.cluster_sync()

            if pid == 0:
                # Load A into local shared memory.
                for i in T.Parallel(N):
                    s_src[i] = A[i]

                # Async-push s_src → s_dst in CTA 1, signal CTA 1's barrier.
                T.copy_cluster(s_src, s_dst, dst_block=1,
                               remote_barrier=s_barrier[0])

            if pid == 1:
                # Wait until CTA 0 finishes writing.
                T.mbarrier_wait_parity(s_barrier[0], 0)

                for i in T.Parallel(N):
                    B[i] = s_dst[i]

    return kernel
```

Generated producer code (single-thread guard, one PTX instruction):

```cuda
if (((int)threadIdx.x) == 0) {
  tl::tma_store_cluster(&s_dst[0], &s_src[0], 1,
                        (uint32_t)(N * 4), s_barrier[0]);
}
```

### Slow path — element-by-element SIMT fallback

Omit `remote_barrier` to use the slow path:

```python
T.copy_cluster(s_src, s_dst, dst_block=1)
```

This lowers to a SIMT parallel loop where every thread writes one (or a few)
elements into the remote CTA's shared memory via
`cooperative_groups::map_shared_rank`. Because `map_shared_rank` returns a
scalar pointer, vectorised writes are not possible. Use this path only when an
mbarrier is unavailable or when the tile is too small to justify barrier
overhead.
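The element distribution of that SIMT loop can be sketched as a grid-strided schedule. This models only the work partitioning, under our own naming; the real lowering obtains the remote pointer via `map_shared_rank` and the scalar stores happen on the GPU:

```python
def simt_copy_schedule(num_elems: int, num_threads: int) -> dict[int, list[int]]:
    """Which elements each thread writes in a grid-strided SIMT copy (sketch)."""
    assignment = {t: [] for t in range(num_threads)}
    for t in range(num_threads):
        # Thread t writes elements t, t + num_threads, t + 2*num_threads, ...
        for i in range(t, num_elems, num_threads):
            assignment[t].append(i)
    return assignment

# 256 elements over 128 threads: each thread writes exactly two scalars,
# one per stride — no vectorisation, matching the scalar-pointer constraint.
sched = simt_copy_schedule(256, 128)
```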

### Synchronisation contract

| | Fast path | Slow path |
|-|-----------|-----------|
| Source CTA | no wait needed; copy is async | effectively sync after the loop |
| Destination CTA | `T.mbarrier_wait_parity(barrier, parity)` | external `T.cluster_sync()` or equivalent |

### Notes

- Both paths require `src` and `dst` to be in `shared` or `shared.dyn` scope.
- The mbarrier must be allocated with `T.alloc_cluster_barrier([arrive_count])`.
- A `T.cluster_sync()` after allocation but before the copy is required to
  ensure every CTA has initialised its barrier before any data is pushed.
- `dst_block` may be a compile-time integer or a runtime `tir.PrimExpr`.
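The parity handshake behind `T.mbarrier_wait_parity` can be modeled as a phase bit that flips each time the arrival count is reached. This is a behavioral sketch of the semantics assumed by the examples above (wait on parity `p` returns once the phase that started with parity `p` has completed), not the hardware mbarrier:

```python
class MBarrierModel:
    """Behavioral model of an mbarrier: the phase bit flips on completion."""

    def __init__(self, arrive_count: int):
        self.arrive_count = arrive_count
        self.pending = arrive_count  # arrivals still needed this phase
        self.phase = 0               # parity of the current phase

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1                    # completion flips parity
            self.pending = self.arrive_count   # re-arm for the next phase

    def wait_parity(self, parity: int) -> bool:
        # True once the phase that began with this parity has completed.
        return self.phase != parity

bar = MBarrierModel(arrive_count=1)
ready_before = bar.wait_parity(0)  # producer has not arrived yet
bar.arrive()
ready_after = bar.wait_parity(0)   # phase flipped, waiter proceeds
```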

---

## Cluster Helper Builtins

| Builtin | Return | Description |
|---------|--------|-------------|
| `T.block_rank_in_cluster()` | `int32` | Block rank (0-indexed) within the cluster |
| `T.get_cluster_block_nums()` | `int32` | Total number of CTAs in the cluster |
| `T.cluster_sync()` | — | Barrier synchronisation across all cluster CTAs |
| `T.alloc_cluster_barrier([count])` | `Buffer` | Allocate and initialise an mbarrier for `count` arrivals |
| `T.mbarrier_arrive(bar)` | — | Signal one arrival on an mbarrier |
| `T.mbarrier_wait_parity(bar, parity)` | — | Wait until `bar` flips to `parity` |

---

## Putting It Together: Split-K Sketch

A common pattern combining both features: multicast the shared K-panel to
all cluster CTAs (saving DRAM bandwidth), then reduce partial sums with
SM-to-SM copy (saving global-memory round trips).

```python
@T.prim_func
def split_k_gemm(A, B, C):
    with T.Kernel(grid_x, grid_y, threads=256, cluster_dims=(4, 1, 1)) as (bx, by):
        rank = T.block_rank_in_cluster()
        A_s = T.alloc_shared((BM, BK), "float16")
        B_s = T.alloc_shared((BK, BN), "float16")
        C_f = T.alloc_fragment((BM, BN), "float32")
        C_s = T.alloc_shared((BM, BN), "float32")
        barrier = T.alloc_cluster_barrier([3])
        T.clear(C_f)

        # Phase 1: each CTA loads its K-slice; A is multicast to ranks 0 and 1.
        for ko in T.Pipelined(T.ceildiv(K, BK * 4), num_stages=3):
            k_off = (rank + ko * 4) * BK
            T.copy_cluster(A[by * BM, k_off], A_s, cluster_mask=0b0011)
            T.copy(B[k_off, bx * BN], B_s)
            T.gemm(A_s, B_s, C_f)

        # Phase 2: push each rank's partial sums to rank 0 for accumulation.
        #
        # Use a per-rank staging slot so every non-zero rank writes to a
        # distinct destination region — avoiding both a destination race and
        # an arrival-count mismatch. Each CTA stores its own partial into
        # C_parts[rank]; non-zero ranks then push that slot to the matching
        # slot in rank 0's shared memory.
        #
        # Arrival count must equal the number of producers: cluster_size - 1.
        C_parts = T.alloc_shared((4, BM, BN), "float32")  # one slot per rank
        T.copy(C_f, C_parts[rank])

        T.cluster_sync()

        if rank != 0:
            # Push this rank's slot to the *same* slot index in rank 0's
            # C_parts — different offsets, so no destination race.
            T.copy_cluster(C_parts[rank], C_parts[rank],
                           dst_block=0, remote_barrier=barrier[0])

        if rank == 0:
            T.mbarrier_wait_parity(barrier[0], 0)  # wakes after all 3 arrivals
            # C_parts[0..3] in rank 0's smem now hold all four partial sums.
            # accumulate and store ...
            T.copy(C_parts[0], C[by * BM, bx * BN])
```
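The arrival-count sizing rule used in the sketch can be checked with a small model (a hypothetical helper of ours, assuming one arrival per producing rank):

```python
def reduction_arrive_count(cluster_size: int, consumer_rank: int = 0) -> int:
    """Producers are every rank except the consumer, so the barrier must be
    initialised for cluster_size - 1 arrivals (sketch of the sizing rule)."""
    producers = [r for r in range(cluster_size) if r != consumer_rank]
    return len(producers)

# A (4, 1, 1) cluster reducing onto rank 0 needs a 3-arrival barrier,
# matching T.alloc_cluster_barrier([3]) in the sketch above.
count = reduction_arrive_count(4)
```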

---

## See Also

- `testing/python/cuda/test_tma_multicast_demo.py` — multicast validation
- `testing/python/cuda/test_tma_dsmem.py` — SM-to-SM copy validation
- Programming Guides → Instructions — complete `T.copy` parameter reference
- Programming Guides → Control Flow — `T.Pipelined` and warp-specialized pipelines
20 changes: 20 additions & 0 deletions src/op/builtin.cc
@@ -156,6 +156,11 @@ TIR_DEFINE_TL_BUILTIN(tma_load_im2col)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tma_load_multicast)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

@@ -278,6 +283,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(get_cluster_block_nums)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(get_lane_idx)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
@@ -531,5 +541,15 @@ TIR_DEFINE_TL_BUILTIN(stg128).set_num_inputs(-1).set_attr<TCallEffectKind>(
TIR_DEFINE_TL_BUILTIN(stg256).set_num_inputs(-1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_cluster_store)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tma_store_cluster)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm