Commit 586a2b3: Update docs
Parent: 8730260
13 files changed: 198 additions & 72 deletions

_sources/autoapi/tilelang/intrinsics/tcgen05_macro_generator/index.rst.txt
21 additions & 2 deletions

@@ -120,6 +120,25 @@ Module Contents
 
 
 
+   .. py:method:: tcgen05mma_ss(A_buf, B_buf, C_local_buf, mbar, clear_accum = False)
+
+      Emit the SS (Shared-Shared) variant of TCGEN5MMA.
+
+      Reads operand A and B from shared memory via a descriptor.
+
+      :param A_buf: Operand A in shared memory.
+      :type A_buf: Buffer
+      :param B_buf: Operand B in shared memory.
+      :type B_buf: Buffer
+      :param C_local_buf: Accumulator buffer in tensor memory.
+      :type C_local_buf: Buffer
+      :param mbar: Memory barrier for MMA completion signalling.
+      :type mbar: PrimExpr
+      :param clear_accum: Whether to zero the accumulator before the first MMA.
+      :type clear_accum: PrimExpr
+
+
    .. py:method:: tcgen05mma_ts(A_buf, B_buf, C_local_buf, mbar, clear_accum = False)
 
       Emit the TS (TensorMemory-Shared) variant of TCGEN5MMA.

@@ -176,9 +195,9 @@ Module Contents
 
 
 
-   .. py:method:: get_tcgen5_mma_meta(m, n, k)
+   .. py:method:: get_tcgen5_mma_meta(m, n, k, disable_2cta)
 
-      Query the FFI for TCGEN5MMA atom metadata (atom_m, atom_n, atom_k, enable_ws, enable_2cta).
+      Query the FFI for TCGEN5MMA atom metadata (atom_m, atom_n, atom_k, enable_ws, enable_2cta), and record them in `self.meta`.
 
 
 
_sources/autoapi/tilelang/language/builtin/index.rst.txt
15 additions & 1 deletion

@@ -17,6 +17,7 @@ Functions
    tilelang.language.builtin.access_ptr
    tilelang.language.builtin.create_tma_descriptor
    tilelang.language.builtin.tma_load
+   tilelang.language.builtin.tma_load_2sm
    tilelang.language.builtin.fence_proxy_async
    tilelang.language.builtin.tma_store_arrive
    tilelang.language.builtin.tma_store_wait

@@ -135,6 +136,16 @@ Module Contents
    :rtype: tir.Call
 
 
+.. py:function:: tma_load_2sm(*args)
+
+   Perform a Tensor Memory Access (TMA) load operation with 2SM on Blackwell.
+
+   :param \*args: Variable arguments specifying the TMA load parameters
+
+   :returns: A handle to the TMA load operation
+   :rtype: tir.Call
+
+
 .. py:function:: fence_proxy_async(*args)
 
    Create a fence for asynchronous proxy operations.

@@ -556,12 +567,15 @@ Module Contents
    Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
 
 
-.. py:function:: tcgen05_mma_arrive(mbar)
+.. py:function:: tcgen05_mma_arrive(mbar, arrive_2cta = False)
 
    Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.
 
    :param mbar: The mbarrier object in shared memory (e.g., Barrier*) or its address.
    :type mbar: tir.Buffer | BufferLoad | PrimExpr
+   :param arrive_2cta: Whether to also arrive at the peer CTA's barrier.
+      If set, will be lowered to umma_arrive_multicast_2x1SM.
+   :type arrive_2cta: bool
 
 
 .. py:function:: ptx_mma_sm70(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index)
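The new `arrive_2cta` flag on `tcgen05_mma_arrive` selects the multicast lowering named in its docstring. A toy sketch of that selection; only `umma_arrive_multicast_2x1SM` is taken from the docs, while the single-CTA intrinsic name below is a hypothetical placeholder:

```python
# Toy dispatch sketch, not tilelang's lowering code. Per the docstring,
# arrive_2cta=True lowers to umma_arrive_multicast_2x1SM; "umma_arrive"
# for the single-CTA case is an assumed name for illustration only.
def arrive_intrinsic(arrive_2cta: bool = False) -> str:
    return "umma_arrive_multicast_2x1SM" if arrive_2cta else "umma_arrive"
```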

_sources/autoapi/tilelang/language/gemm_op/index.rst.txt
4 additions & 1 deletion

@@ -84,7 +84,7 @@ Module Contents
    compilation fails instead of silently falling back to MMA.
 
 
-.. py:function:: tcgen05_gemm(A, B, C, transpose_A = False, transpose_B = False, policy = GemmWarpPolicy.Square, clear_accum = False, *, mbar)
+.. py:function:: tcgen05_gemm(A, B, C, transpose_A = False, transpose_B = False, policy = GemmWarpPolicy.Square, clear_accum = False, *, mbar, use_2cta = False)
 
    Explicit Blackwell TCGEN05 GEMM without an implicit wait.
 

@@ -93,6 +93,9 @@ Module Contents
    - it always requests the TCGEN5MMA lowering path
    - it never auto-emits an inlined `mbarrier_wait_parity`
 
+   When ``use_2cta=True``, the instruction is lowered to the 2CTA variant
+   which requires ``cluster_dims`` to be ``(2,1,1)`` or ``(1,2,1)``.
+
    If the current target or operand pattern cannot use Blackwell TCGEN5MMA,
    compilation fails instead of silently falling back to another GEMM path.
 
_sources/autoapi/tilelang/language/tir/op/index.rst.txt
5 additions & 5 deletions

@@ -872,11 +872,11 @@ Module Contents
 
 .. py:function:: ptx_wgmma_rs(dtype, wgmma_prefix, b_is_k_major, a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, A_buf, A_offset, B_desc, B_offset, C_data, C_offset, scale_out, scale_in_a, scale_in_b)
 
-.. py:function:: ptx_tcgen05_mma_ss(kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset, desc_val, scale_out, mask0, mask1, mask2, mask3, enable_ws=False, ws=None, warp_specialized=None, variant=None)
+.. py:function:: ptx_tcgen05_mma_ss(kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset, desc_val, scale_out, mask0, mask1, mask2, mask3, enable_ws=False, enable_2cta=False, ws=None, warp_specialized=None, variant=None)
 
-   TVM intrinsic for tcgen05.mma shared-memory × shared-memory instructions.
+   TVM intrinsic for tcgen05.mma shared-memory x shared-memory instructions.
 
-   Expects 13 or 14 positional arguments:
+   Expects 14 or 15 positional arguments:
    (kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset,
    desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]).
    Aliases: you can also pass `ws` or `warp_specialized` (booleans) instead of `enable_ws`.

@@ -885,9 +885,9 @@ Module Contents
    "tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4).
 
 
-.. py:function:: ptx_tcgen05_mma_ts(kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset, desc_val, scale_out, mask0, mask1, mask2, mask3)
+.. py:function:: ptx_tcgen05_mma_ts(kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset, desc_val, scale_out, mask0, mask1, mask2, mask3, enable_2cta=False)
 
-   TVM intrinsic for tcgen05.mma tensor-memory × shared-memory instructions.
+   TVM intrinsic for tcgen05.mma tensor-memory x shared-memory instructions.
 
    Expects 13 positional arguments:
    (kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset,
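The arity change above is easy to misread, so here is a small restatement as code: per the updated docstring, `ptx_tcgen05_mma_ss` now expects 14 or 15 positional arguments (the 15th being `enable_2cta`), while `ptx_tcgen05_mma_ts` still documents 13 positionals with `enable_2cta` as a trailing default. The checker functions are hypothetical illustrations, not part of the API:

```python
# Illustrative arity checks derived from the docstrings above.
def ss_arity_ok(nargs: int) -> bool:
    # 13 base args + enable_ws (14th) + optional enable_2cta (15th).
    return nargs in (14, 15)

def ts_arity_ok(nargs: int) -> bool:
    # The ts docstring still states 13 positional arguments.
    return nargs == 13
```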

_sources/autoapi/tilelang/transform/index.rst.txt
12 additions & 0 deletions

@@ -84,6 +84,7 @@ Functions
    tilelang.transform.LayoutReducer
    tilelang.transform.UnrollLoop
    tilelang.transform.LowerLDGSTG
+   tilelang.transform.LowerBlackwell2SM
 
 
 Package Contents

@@ -547,3 +548,14 @@ Package Contents
    :rtype: tvm.transform.Pass
 
 
+.. py:function:: LowerBlackwell2SM()
+
+   Lower 2SM TCGEN5MMA and related operations on the Blackwell target.
+
+   :returns: **fpass** -- The result pass
+   :rtype: tvm.transform.Pass
+
+
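`LowerBlackwell2SM` is a module-level transform, so it slots into a pass pipeline like any other. A generic sketch of pass composition, mocked with plain functions so it runs without TVM (in real code this would go through `tvm.transform.Sequential`, and the stand-in pass below is purely hypothetical):

```python
# Generic sketch of sequential pass composition; not TVM's implementation.
def sequential(passes):
    def run(mod):
        # Each pass maps an IR module to a transformed IR module.
        for p in passes:
            mod = p(mod)
        return mod
    return run

# Hypothetical stand-in for LowerBlackwell2SM: tags the module as lowered.
def lower_blackwell_2sm_stub(mod):
    return {**mod, "lowered_2sm": True}

pipeline = sequential([lower_blackwell_2sm_stub])
```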

autoapi/tilelang/intrinsics/tcgen05_macro_generator/index.html
22 additions & 2 deletions

@@ -745,6 +745,24 @@ <h2>Module Contents<a class="headerlink" href="#module-contents" title="Link to
 </dl>
 </dd></dl>
 
+<dl class="py method">
+<dt class="sig sig-object py" id="tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.tcgen05mma_ss">
+<span class="sig-name descname"><span class="pre">tcgen05mma_ss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">A_buf</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">B_buf</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">C_local_buf</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mbar</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">clear_accum</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.tcgen05mma_ss" title="Link to this definition"></a></dt>
+<dd><p>Emit the SS (Shared-Shared) variant of TCGEN5MMA.</p>
+<p>Reads operand A and B from shared memory via a descriptor.</p>
+<dl class="field-list simple">
+<dt class="field-odd">Parameters<span class="colon">:</span></dt>
+<dd class="field-odd"><ul class="simple">
+<li><p><strong>A_buf</strong> (<em>Buffer</em>) – Operand A in shared memory.</p></li>
+<li><p><strong>B_buf</strong> (<em>Buffer</em>) – Operand B in shared memory.</p></li>
+<li><p><strong>C_local_buf</strong> (<em>Buffer</em>) – Accumulator buffer in tensor memory.</p></li>
+<li><p><strong>mbar</strong> (<em>PrimExpr</em>) – Memory barrier for MMA completion signalling.</p></li>
+<li><p><strong>clear_accum</strong> (<em>PrimExpr</em>) – Whether to zero the accumulator before the first MMA.</p></li>
+</ul>
+</dd>
+</dl>
+</dd></dl>
+
 <dl class="py method">
 <dt class="sig sig-object py" id="tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.tcgen05mma_ts">
 <span class="sig-name descname"><span class="pre">tcgen05mma_ts</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">A_buf</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">B_buf</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">C_local_buf</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mbar</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">clear_accum</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.tcgen05mma_ts" title="Link to this definition"></a></dt>

@@ -815,14 +833,15 @@ <h2>Module Contents<a class="headerlink" href="#module-contents" title="Link to
 
 <dl class="py method">
 <dt class="sig sig-object py" id="tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.get_tcgen5_mma_meta">
-<span class="sig-name descname"><span class="pre">get_tcgen5_mma_meta</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">k</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.get_tcgen5_mma_meta" title="Link to this definition"></a></dt>
-<dd><p>Query the FFI for TCGEN5MMA atom metadata (atom_m, atom_n, atom_k, enable_ws, enable_2cta).</p>
+<span class="sig-name descname"><span class="pre">get_tcgen5_mma_meta</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">m</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">k</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">disable_2cta</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.get_tcgen5_mma_meta" title="Link to this definition"></a></dt>
+<dd><p>Query the FFI for TCGEN5MMA atom metadata (atom_m, atom_n, atom_k, enable_ws, enable_2cta), and record them in <cite>self.meta</cite>.</p>
 <dl class="field-list simple">
 <dt class="field-odd">Parameters<span class="colon">:</span></dt>
 <dd class="field-odd"><ul class="simple">
 <li><p><strong>m</strong> (<em>int</em>)</p></li>
 <li><p><strong>n</strong> (<em>int</em>)</p></li>
 <li><p><strong>k</strong> (<em>int</em>)</p></li>
+<li><p><strong>disable_2cta</strong> (<a class="reference internal" href="../../language/dtypes/index.html#tilelang.language.dtypes.bool" title="tilelang.language.dtypes.bool"><em>bool</em></a>)</p></li>
 </ul>
 </dd>
 </dl>

@@ -933,6 +952,7 @@ <h2>Module Contents<a class="headerlink" href="#module-contents" title="Link to
 <li><a class="reference internal" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.a_shared_layout"><code class="docutils literal notranslate"><span class="pre">TensorCoreIntrinEmitter.a_shared_layout</span></code></a></li>
 <li><a class="reference internal" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.b_shared_layout"><code class="docutils literal notranslate"><span class="pre">TensorCoreIntrinEmitter.b_shared_layout</span></code></a></li>
 <li><a class="reference internal" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.tcgen05mma"><code class="docutils literal notranslate"><span class="pre">TensorCoreIntrinEmitter.tcgen05mma()</span></code></a></li>
+<li><a class="reference internal" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.tcgen05mma_ss"><code class="docutils literal notranslate"><span class="pre">TensorCoreIntrinEmitter.tcgen05mma_ss()</span></code></a></li>
 <li><a class="reference internal" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.tcgen05mma_ts"><code class="docutils literal notranslate"><span class="pre">TensorCoreIntrinEmitter.tcgen05mma_ts()</span></code></a></li>
 <li><a class="reference internal" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.make_mma_load_layout"><code class="docutils literal notranslate"><span class="pre">TensorCoreIntrinEmitter.make_mma_load_layout()</span></code></a></li>
 <li><a class="reference internal" href="#tilelang.intrinsics.tcgen05_macro_generator.TensorCoreIntrinEmitter.make_mma_store_layout"><code class="docutils literal notranslate"><span class="pre">TensorCoreIntrinEmitter.make_mma_store_layout()</span></code></a></li>
