diff --git a/cula/ops/chunk_delta_h_sm100.py b/cula/ops/chunk_delta_h_sm100.py index c4c84af1..61d4b0f0 100644 --- a/cula/ops/chunk_delta_h_sm100.py +++ b/cula/ops/chunk_delta_h_sm100.py @@ -29,7 +29,7 @@ import torch.nn.functional as F import triton from cutlass._mlir.dialects import llvm as _llvm -from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05 from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream from cutlass.cute.typing import Float32, Int32, Int64 from cutlass.cutlass_dsl import T as _T @@ -308,8 +308,9 @@ def __call__( # WH MMA: A=state(TMEM, K-major), B=W(SMEM, K-major) wh_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, # A: state, K-major (required for TMEM source) - tcgen05.OperandMajorMode.K, # B: W, K-major (BK contiguous) + self.io_dtype, + OperandMajorMode.K, # A: state, K-major (required for TMEM source) + OperandMajorMode.K, # B: W, K-major (BK contiguous) self.acc_dtype, self.cta_group, self.wh_mma_tiler[:2], @@ -319,8 +320,9 @@ def __call__( # KV MMA: A=v_new^T(TMEM, K-major required), B=K^T(SMEM, MN-major) kv_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, # A: v_new, K-major (required for TMEM source) - tcgen05.OperandMajorMode.MN, # B: K^T, MN-major (BK contiguous) + self.io_dtype, + OperandMajorMode.K, # A: v_new, K-major (required for TMEM source) + OperandMajorMode.MN, # B: K^T, MN-major (BK contiguous) self.acc_dtype, self.cta_group, self.kv_mma_tiler[:2], diff --git a/cula/ops/chunk_wy_dqkg_sm100.py b/cula/ops/chunk_wy_dqkg_sm100.py index cafbb247..2d49cf05 100644 --- a/cula/ops/chunk_wy_dqkg_sm100.py +++ b/cula/ops/chunk_wy_dqkg_sm100.py @@ -14,7 +14,7 @@ mbarrier_init_fence, mbarrier_wait, ) -from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05 from cutlass.cute.nvgpu.tcgen05 import ( make_umma_smem_desc, smem_descriptor_to_int, @@ -605,8 +605,9 @@ def __call__( # dq += do @ h, dk += vnew @ dh, dw += dv @ h vloop_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, # A: K-major - tcgen05.OperandMajorMode.K, # B: K-major + self.io_dtype, + OperandMajorMode.K, # A: K-major + OperandMajorMode.K, # B: K-major self.acc_dtype, self.cta_group, self.vloop_gemm_tiler[:2], # (64, 128) @@ -617,8 +618,9 @@ def __call__( # dA += dv @ v^T, dA += dw @ kg^T dA_vloop_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, - tcgen05.OperandMajorMode.K, + self.io_dtype, + OperandMajorMode.K, + OperandMajorMode.K, self.acc_dtype, self.cta_group, self.dA_vloop_tiler[:2], # (64, 64) @@ -629,8 +631,9 @@ def __call__( # dvb = A @ dv, dkgb = A @ dw dvb_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.MN, - tcgen05.OperandMajorMode.MN, + self.io_dtype, + OperandMajorMode.MN, + OperandMajorMode.MN, self.acc_dtype, self.cta_group, self.dvb_tiler[:2], # (64, 64) @@ -639,8 +642,9 @@ def __call__( # dkgb_tiled_mma: SS MN,MN (64,128) - dkgb dkgb_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.MN, - tcgen05.OperandMajorMode.MN, + self.io_dtype, + OperandMajorMode.MN, + OperandMajorMode.MN, self.acc_dtype, self.cta_group, self.kloop_dkgb_tiler[:2], # (64, 128) @@ -650,8 +654,9 @@ def __call__( # dA += dw @ kg^T dA_kloop_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, - tcgen05.OperandMajorMode.K, + self.io_dtype, + OperandMajorMode.K, + OperandMajorMode.K, self.acc_dtype, self.cta_group, self.kloop_dA_tiler[:2], # (64, 64) @@ -661,8 +666,9 @@ def __call__( # dA = dA @ A dA2post_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, - tcgen05.OperandMajorMode.K, + self.io_dtype, + OperandMajorMode.K, + OperandMajorMode.K, self.acc_dtype, self.cta_group, self.dApost_tiler[:2], # (64, 64) @@ -673,8 +679,9 @@ def __call__( # dA = A @ dA dA3post_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.MN, - tcgen05.OperandMajorMode.MN, + self.io_dtype, + OperandMajorMode.MN, + OperandMajorMode.MN, self.acc_dtype, self.cta_group, self.dApost_tiler[:2], # (64, 64) diff --git a/cula/ops/cp/pre_scan.py b/cula/ops/cp/pre_scan.py index befef74d..b5a363fa 100644 --- a/cula/ops/cp/pre_scan.py +++ b/cula/ops/cp/pre_scan.py @@ -20,7 +20,7 @@ import cutlass.utils as utils import cutlass.utils.blackwell_helpers as sm100_utils import torch -from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05 from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream from cutlass.cute.typing import Float32, Int32, Int64 @@ -272,8 +272,9 @@ def __call__( # ===================== MMA setup (same as fwd_h) ===================== wh_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, - tcgen05.OperandMajorMode.K, + self.io_dtype, + OperandMajorMode.K, + OperandMajorMode.K, self.acc_dtype, self.cta_group, self.wh_mma_tiler[:2], @@ -281,8 +282,9 @@ def __call__( ) kv_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, - tcgen05.OperandMajorMode.MN, + self.io_dtype, + OperandMajorMode.K, + OperandMajorMode.MN, self.acc_dtype, self.cta_group, self.kv_mma_tiler[:2], diff --git a/cula/ops/fwd_o_sm100.py b/cula/ops/fwd_o_sm100.py index 9d4bcb46..a4dc85f4 100644 --- a/cula/ops/fwd_o_sm100.py +++ b/cula/ops/fwd_o_sm100.py @@ -74,7 +74,7 @@ import cutlass.utils as utils import cutlass.utils.blackwell_helpers as sm100_utils import torch -from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05 from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream from cutlass.cute.typing import Float32, Int32, Int64 from fla.ops.utils import prepare_chunk_indices @@ -354,8 +354,9 @@ def __call__( # B is MN-major because h_T GMEM has V(=N) contiguous qh_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, # A: K-major (TMEM requires K-major) - tcgen05.OperandMajorMode.MN, # B: MN-major (V contiguous in GMEM) + self.io_dtype, + OperandMajorMode.K, # A: K-major (TMEM requires K-major) + OperandMajorMode.MN, # B: MN-major (V contiguous in GMEM) self.acc_dtype, self.cta_group, self.qh_mma_tiler[:2], @@ -366,8 +367,9 @@ def __call__( # B is MN-major because v_T GMEM has V(=N) contiguous av_tiled_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, # A: K-major (TMEM requires K-major) - tcgen05.OperandMajorMode.MN, # B: MN-major (V contiguous in GMEM) + self.io_dtype, + OperandMajorMode.K, # A: K-major (TMEM requires K-major) + OperandMajorMode.MN, # B: MN-major (V contiguous in GMEM) self.acc_dtype, self.cta_group, self.av_mma_tiler[:2], @@ -561,8 +563,9 @@ class SharedStorage: # B operand majorness must match av_tiled_mma for C layout compatibility. am_coord_mma = sm100_utils.make_trivial_tiled_mma( self.io_dtype, - tcgen05.OperandMajorMode.K, - tcgen05.OperandMajorMode.MN, + self.io_dtype, + OperandMajorMode.K, + OperandMajorMode.MN, self.acc_dtype, self.cta_group, (self.BT, self.BT), diff --git a/cula/ops/intrinsics_sm100.py b/cula/ops/intrinsics_sm100.py index e7c1ce5b..0d319e41 100644 --- a/cula/ops/intrinsics_sm100.py +++ b/cula/ops/intrinsics_sm100.py @@ -68,7 +68,6 @@ import cutlass.cute as cute from cutlass._mlir import ir as _ir_mod -from cutlass._mlir.dialects import arith as _arith from cutlass._mlir.dialects import llvm from cutlass._mlir.dialects import nvvm as _nvvm from cutlass._mlir.dialects import vector as _vector @@ -117,7 +116,6 @@ def _do(addr_val, *, loc=None, ip=None): return _nvvm.tcgen05_ld( res=vec_i32_ty, shape=_nvvm.Tcgen05LdStShape.SHAPE_32X32B, - num=num, tmem_addr=tmem_ptr, loc=loc, ip=ip, @@ -154,9 +152,8 @@ def _do(addr_val, vec_val, *, loc=None, ip=None): tmem_ptr = llvm.inttoptr(ptr6_ty, _to_ir(addr_val, loc, ip), loc=loc, ip=ip) _nvvm.tcgen05_st( shape=_nvvm.Tcgen05LdStShape.SHAPE_32X32B, - num=num, tmem_addr=tmem_ptr, - r=_to_ir(vec_val, loc, ip), + val=_to_ir(vec_val, loc, ip), loc=loc, ip=ip, ) @@ -279,12 +276,12 @@ def store_256b(gmem_ptr, vec): @dsl_user_op def _do(addr, v, *, loc=None, ip=None): - i32_ty = _ir_mod.IntegerType.get_signless(32) ir_v = _to_ir(v, loc, ip) elems = [ - _vector.extractelement( + _vector.extract( ir_v, - position=_arith.constant(i32_ty, i, loc=loc, ip=ip), + dynamic_position=[], + static_position=[i], loc=loc, ip=ip, ) diff --git a/cula/ops/kda_fully_fused_sm100_wip.py b/cula/ops/kda_fully_fused_sm100_wip.py index ab09f0b2..766b3f07 100644 --- a/cula/ops/kda_fully_fused_sm100_wip.py +++ b/cula/ops/kda_fully_fused_sm100_wip.py @@ -65,7 +65,7 @@ import cutlass.utils as utils import cutlass.utils.blackwell_helpers as sm100_utils import torch -from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05 from cutlass.cute.runtime import from_dlpack from cutlass.cute.typing import Int32, Int64 from fla.modules.l2norm import l2norm_fwd @@ -478,13 +478,13 @@ def __call__( self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() self.g_major_mode = utils.LayoutEnum.from_tensor(g).mma_major_mode() # NEW for KDA - self.k_major_mode_kv = tcgen05.OperandMajorMode.MN # For V^T*K, S dimension coalesced + self.k_major_mode_kv = OperandMajorMode.MN # For V^T*K, S dimension coalesced # TMEM register output results as (D, C) self.o_layout = utils.LayoutEnum.from_tensor(o) - if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + if cutlass.const_expr(self.q_major_mode != OperandMajorMode.K): raise RuntimeError("The layout of q is not supported") - if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + if cutlass.const_expr(self.k_major_mode != OperandMajorMode.K): raise RuntimeError("The layout of k is not supported") if cutlass.const_expr(self.o_layout != utils.LayoutEnum.COL_MAJOR): raise RuntimeError("The layout of o is not supported") @@ -492,6 +492,7 @@ def __call__( raise RuntimeError("The layout of k & k^t should be different") qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, self.q_major_mode, self.k_major_mode, @@ -500,6 +501,7 @@ def __call__( self.qk_mma_tiler[:2], ) kk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.k_dtype, self.k_dtype, # SHOULE BE both K-major self.k_major_mode, @@ -519,9 +521,10 @@ def __call__( ) # State^T Q^T sq_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, self.io_dtype, # State is in TMEM, always K major, TODO - tcgen05.OperandMajorMode.K, + OperandMajorMode.K, self.q_major_mode, self.acc_dtype, self.cta_group, @@ -529,9 +532,10 @@ def __call__( a_source=tcgen05.OperandSource.TMEM, ) ks_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, self.io_dtype, # State is in TMEM, always K major, TODO - tcgen05.OperandMajorMode.K, + OperandMajorMode.K, # State is in TMEM, always K major, TODO self.k_major_mode, self.acc_dtype, @@ -540,8 +544,9 @@ def __call__( a_source=tcgen05.OperandSource.TMEM, ) - m_major_mode = tcgen05.OperandMajorMode.K + m_major_mode = OperandMajorMode.K mv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, self.v_dtype, self.v_major_mode, m_major_mode, @@ -550,8 +555,9 @@ def __call__( self.mv_mma_tiler[:2], ) - p_major_mode = tcgen05.OperandMajorMode.K + p_major_mode = OperandMajorMode.K vp_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, self.v_dtype, self.v_major_mode, p_major_mode, @@ -1459,6 +1465,7 @@ def kernel( ############################################ kv_mma_tiler2 = (self.kv_mma_tiler[0], self.kv_mma_tiler[1] // 2, self.kv_mma_tiler[2]) fake_kv_tiled_mma_acc32 = sm100_utils.make_trivial_tiled_mma( + self.k_dtype, self.k_dtype, self.v_major_mode, self.k_major_mode_kv, diff --git a/cula/ops/lightning_attn_sm100.py b/cula/ops/lightning_attn_sm100.py index 8a6b204e..a499f8b4 100644 --- a/cula/ops/lightning_attn_sm100.py +++ b/cula/ops/lightning_attn_sm100.py @@ -59,7 +59,7 @@ import cutlass.utils.blackwell_helpers as sm100_utils import torch from cutlass._mlir.dialects import llvm as _llvm -from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05 from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream from cutlass.cute.typing import Float32, Int32, Int64 from cutlass.cutlass_dsl import T as _T @@ -435,13 +435,13 @@ def __call__( self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() - self.k_major_mode_kv = tcgen05.OperandMajorMode.MN # For V^T*K, S dimension coalesced + self.k_major_mode_kv = OperandMajorMode.MN # For V^T*K, S dimension coalesced # TMEM register output results as (D, C) self.o_layout = utils.LayoutEnum.from_tensor(o) - if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + if cutlass.const_expr(self.q_major_mode != OperandMajorMode.K): raise RuntimeError("The layout of q is not supported") - if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + if cutlass.const_expr(self.k_major_mode != OperandMajorMode.K): raise RuntimeError("The layout of k is not supported") if cutlass.const_expr(self.o_layout != utils.LayoutEnum.COL_MAJOR): raise RuntimeError("The layout of o is not supported") @@ -449,6 +449,7 @@ def __call__( raise RuntimeError("The layout of k & k^t should be different") qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, self.q_major_mode, self.k_major_mode, @@ -458,6 +459,7 @@ def __call__( ) # V^T*K, majorness kv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.k_dtype, self.k_dtype, self.v_major_mode, self.k_major_mode_kv, @@ -467,17 +469,19 @@ def __call__( ) # State^T Q^T sq_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, self.io_dtype, # State is in TMEM, always K major, TODO - tcgen05.OperandMajorMode.K, + OperandMajorMode.K, self.q_major_mode, self.acc_dtype, self.cta_group, self.sq_mma_tiler[:2], a_source=tcgen05.OperandSource.TMEM, ) - p_major_mode = tcgen05.OperandMajorMode.K + p_major_mode = OperandMajorMode.K vp_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, self.v_dtype, self.v_major_mode, p_major_mode, @@ -1100,6 +1104,7 @@ def kernel( ############################################ kv_mma_tiler2 = (self.kv_mma_tiler[0], self.kv_mma_tiler[1] // 2, self.kv_mma_tiler[2]) fake_kv_tiled_mma_acc32 = sm100_utils.make_trivial_tiled_mma( + self.k_dtype, self.k_dtype, self.v_major_mode, self.k_major_mode_kv, diff --git a/cula/ops/linear_attn_sm100.py b/cula/ops/linear_attn_sm100.py index 64a11755..3ffe99f9 100644 --- a/cula/ops/linear_attn_sm100.py +++ b/cula/ops/linear_attn_sm100.py @@ -67,7 +67,7 @@ import cutlass.utils as utils import cutlass.utils.blackwell_helpers as sm100_utils import torch -from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu import OperandMajorMode, cpasync, tcgen05 from cutlass.cute.runtime import from_dlpack from cutlass.cute.typing import Int32, Int64 @@ -350,13 +350,13 @@ def __call__( self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() - self.k_major_mode_kv = tcgen05.OperandMajorMode.MN # For V^T*K, S dimension coalesced + self.k_major_mode_kv = OperandMajorMode.MN # For V^T*K, S dimension coalesced # TMEM register output results as (D, C) self.o_layout = utils.LayoutEnum.from_tensor(o) - if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + if cutlass.const_expr(self.q_major_mode != OperandMajorMode.K): raise RuntimeError("The layout of q is not supported") - if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + if cutlass.const_expr(self.k_major_mode != OperandMajorMode.K): raise RuntimeError("The layout of k is not supported") if cutlass.const_expr(self.o_layout != utils.LayoutEnum.COL_MAJOR): raise RuntimeError("The layout of o is not supported") @@ -364,6 +364,7 @@ def __call__( raise RuntimeError("The layout of k & k^t should be different") qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, self.q_major_mode, self.k_major_mode, @@ -373,6 +374,7 @@ def __call__( ) # V^T*K, majorness kv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.k_dtype, self.k_dtype, self.v_major_mode, self.k_major_mode_kv, @@ -382,17 +384,19 @@ def __call__( ) # State^T Q^T sq_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, self.io_dtype, # State is in TMEM, always K major, TODO - tcgen05.OperandMajorMode.K, + OperandMajorMode.K, self.q_major_mode, self.acc_dtype, self.cta_group, self.sq_mma_tiler[:2], a_source=tcgen05.OperandSource.TMEM, ) - p_major_mode = tcgen05.OperandMajorMode.K + p_major_mode = OperandMajorMode.K vp_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, self.v_dtype, self.v_major_mode, p_major_mode, @@ -927,6 +931,7 @@ def kernel( ############################################ kv_mma_tiler2 = (self.kv_mma_tiler[0], self.kv_mma_tiler[1] // 2, self.kv_mma_tiler[2]) fake_kv_tiled_mma_acc32 = sm100_utils.make_trivial_tiled_mma( + self.k_dtype, self.k_dtype, self.v_major_mode, self.k_major_mode_kv, diff --git a/pyproject.toml b/pyproject.toml index ef1a531b..ed835b4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ readme = "README.md" authors = [ { name = "cula contributors" } ] requires-python = ">=3.10" dependencies = [ - "nvidia-cutlass-dsl>=4.4.2", + "nvidia-cutlass-dsl>=4.6.0.dev0", "apache-tvm-ffi>=0.1.9", ] license = { text = "Apache-2.0" } diff --git a/tests/conftest.py b/tests/conftest.py index f144c10b..a9338aca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import re + import pytest import torch @@ -56,9 +57,5 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_slow) continue callspec = getattr(item, "callspec", None) - if ( - callspec is not None - and callspec.params.get("disable_recompute") - and "kda_fast_norecomp" not in item.keywords - ): + if callspec is not None and callspec.params.get("disable_recompute") and "kda_fast_norecomp" not in item.keywords: item.add_marker(skip_fast_norecomp)