From e5a0d0877c1358ae68e1edc9a416c2f000165183 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Mon, 13 Apr 2026 07:33:22 +0000 Subject: [PATCH 1/3] [Perf] Port aiter mixed_moe kernel optimizations for stage1/stage2 Port performance-critical optimizations from aiter's mixed_moe_gemm_2stage kernel body (both stage1 and stage2) into FlyDSL, along with supporting infrastructure changes. Key changes: - mixed_moe_gemm_2stage.py: Full kernel body replacement with aiter version featuring dual SmemAllocator (ping-pong), unified MFMA pipeline schedule, _barrier() for fine-grained waitcnt control, and new parameters (persist_m, fuse_fp4_quant, fuse_sort_scale, use_async_copy, sort_block_m, etc.) - layout_utils.py: New file ported from aiter for layout index arithmetic (crd2idx, idx2crd, _div_pow2, _mod_pow2) - silu_and_mul_fq.py: New file ported from aiter for split-K + fp4 quant after silu fusion - mfma_preshuffle_pipeline.py: Added k_major support, cache_modifier param, bitwise-AND optimization in swizzle_xor16, PreshuffleScaleLayout additions - kernels_common.py: Extracted shared _if_then context manager and validate_moe_dtypes helper - mfma_epilogues.py: Replaced local _if_then with shared import Performance (DeepSeek TP8 FP4, 7168x256, E=256, K=8): - Stage1 Decode t=1: 37.3 -> 26.2 us (-29.8%) - Stage1 Decode t=8: 45.0 -> 31.0 us (-31.1%) - Stage1 Prefill 8K: 561.8 -> 348.8 us (-37.9%) - Stage2 Prefill 8K reduce: 569.1 -> 534.8 us (-6.0%) - FP8 stage2 unchanged (within noise) Made-with: Cursor --- kernels/kernels_common.py | 37 +- kernels/layout_utils.py | 181 ++ kernels/mfma_epilogues.py | 13 +- kernels/mfma_preshuffle_pipeline.py | 227 +- kernels/mixed_moe_gemm_2stage.py | 3759 ++++++++++++++++++--------- kernels/silu_and_mul_fq.py | 368 +++ 6 files changed, 3203 insertions(+), 1382 deletions(-) create mode 100644 kernels/layout_utils.py create mode 100644 kernels/silu_and_mul_fq.py diff --git a/kernels/kernels_common.py b/kernels/kernels_common.py index 
3af725a6..3d73bdae 100644 --- a/kernels/kernels_common.py +++ b/kernels/kernels_common.py @@ -7,13 +7,48 @@ but this module is intentionally small and MLIR-dialect facing. """ +from contextlib import contextmanager + from flydsl._mlir import ir from flydsl.expr.typing import T -from flydsl._mlir.dialects import arith as _std_arith, builtin, gpu as _gpu, llvm as _llvm +from flydsl._mlir.dialects import arith as _std_arith, builtin, gpu as _gpu, llvm as _llvm, scf as _scf from flydsl.expr import buffer_ops from flydsl.runtime.device import get_rocm_arch, is_rdna_arch +@contextmanager +def _if_then(if_op, scf=None): + """Context manager for SCF IfOp then-region across old/new Python APIs. + + Ensures the then block always ends with a YieldOp. + The optional *scf* parameter is accepted for backward compatibility + but ignored — the module-level import is used. + """ + with ir.InsertionPoint(if_op.then_block): + try: + yield if_op.then_block + finally: + blk = if_op.then_block + if (not blk.operations) or not isinstance(blk.operations[-1], _scf.YieldOp): + _scf.YieldOp([]) + + +_VALID_A_DTYPES = frozenset(("fp8", "fp16", "int8", "fp4")) +_VALID_B_DTYPES = frozenset(("fp8", "fp16", "int8", "int4", "fp4")) + + +def validate_moe_dtypes(a_dtype: str, b_dtype: str) -> None: + """Validate a_dtype/b_dtype strings for mixed MoE kernels.""" + if a_dtype not in _VALID_A_DTYPES: + raise ValueError( + f"a_dtype must be one of {tuple(sorted(_VALID_A_DTYPES))}, got {a_dtype!r}" + ) + if b_dtype not in _VALID_B_DTYPES: + raise ValueError( + f"b_dtype must be one of {tuple(sorted(_VALID_B_DTYPES))}, got {b_dtype!r}" + ) + + def dtype_to_elem_type(dtype_str: str): """Map a dtype string to its MLIR scalar type. 
diff --git a/kernels/layout_utils.py b/kernels/layout_utils.py new file mode 100644 index 00000000..350c9c48 --- /dev/null +++ b/kernels/layout_utils.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Layout helpers for GEMM kernels. + +Parses fly layout type strings (e.g. '(4,64):(64,1)') and computes +idx2crd / crd2idx with plain arith ops for static layouts. +Falls back to fly dialect ops for dynamic layouts. + +Optimisation: power-of-2 strides/shapes emit ``shrui`` / ``andi`` instead of +``divui`` / ``remui``, avoiding 10-15-cycle V_DIV sequences on CDNA GPUs. +""" + +import math as _math +import re +import builtins as _builtins + +import flydsl.expr as fx +from flydsl._mlir import ir +from flydsl.expr import arith +from flydsl.expr.arith import ArithValue +from flydsl.expr.typing import T + + +def _wrap(v): + """Wrap raw ir.Value in ArithValue for operator overloading compatibility.""" + if isinstance(v, ArithValue): + return v + if isinstance(v, ir.Value): + return ArithValue(v) + return v + + +def _is_pow2(n): + """Return True when *n* is a positive power of two.""" + return n > 0 and (n & (n - 1)) == 0 + + +def _div_pow2(val, divisor): + """Unsigned divide index *val* by a **compile-time** power-of-2 *divisor*. + + Emits ``arith.shrui`` (1 VALU cycle) instead of ``arith.divui`` + (10-15 VALU cycles on CDNA). + """ + shift = _math.log2(divisor) + assert shift == int(shift), f"{divisor} is not a power of 2" + return arith.shrui(val, arith.index(int(shift))) + + +def _mod_pow2(val, modulus): + """Unsigned remainder of index *val* by a **compile-time** power-of-2 *modulus*. + + Emits ``arith.andi`` (1 VALU cycle) instead of ``arith.remui``. + """ + return arith.andi(val, arith.index(modulus - 1)) + + +def _parse_dim(tok): + """Parse a single dimension token: '?' -> None, otherwise int.""" + tok = tok.strip() + return None if tok == "?" 
else int(tok) + + +def _parse_layout(ly): + """Parse '(s0,s1,...):(d0,d1,...)' -> (shapes, strides) as lists (None for '?').""" + ly_str = str(ly.type) if hasattr(ly, "type") else str(ly) + m = re.search(r"\(([^)]+)\):\(([^)]+)\)", ly_str) + if not m: + return None + shapes = [_parse_dim(s) for s in m.group(1).split(",")] + strides = [_parse_dim(s) for s in m.group(2).split(",")] + return shapes, strides + + +def _has_dynamic_strides(strides): + """Check if any stride is dynamic (None).""" + return any(s is None for s in strides) + + +def idx2crd(idx, layout): + """Decompose flat index into a list of coordinate values. + + For static layouts, computes coordinates with plain arith ops. + Power-of-2 strides/shapes use shift/mask instead of div/rem. + For dynamic layouts, falls back to fx.idx2crd + fx.get. + """ + parsed = _parse_layout(layout) + + if parsed is None or _has_dynamic_strides(parsed[1]): + result = fx.idx2crd(idx, layout) + ndims = len(parsed[1]) if parsed else 1 + return [_wrap(fx.get(result, i)) for i in range(ndims)] + + if hasattr(idx, "type") and str(idx.type) != "index": + idx = arith.index_cast(T.index, idx) + shapes, strides = parsed + ndims = len(strides) + + ordered = sorted( + [ + (i, s, sz) + for i, s, sz in _builtins.zip(range(ndims), strides, shapes) + if s != 0 + ], + key=lambda x: x[1], + reverse=True, + ) + coords = [None] * ndims + remaining = idx + for i, stride_val, size_val in ordered: + if stride_val == 1: + c = remaining + elif _is_pow2(stride_val): + c = _div_pow2(remaining, stride_val) + else: + c = remaining / arith.index(stride_val) + if size_val is not None: + if _is_pow2(size_val): + c = _mod_pow2(c, size_val) + else: + c = c % arith.index(size_val) + coords[i] = c + for i in range(ndims): + if coords[i] is None: + coords[i] = remaining + return coords + + +def crd2idx(crd, layout): + """Compute flat index from a coordinate tuple/list. + + For static layouts, computes with plain arith ops. 
+ For dynamic layouts, falls back to fx.crd2idx with fx.make_coord. + """ + if not isinstance(crd, (list, tuple)): + crd = [crd] + parsed = _parse_layout(layout) + + if parsed is None or _has_dynamic_strides(parsed[1]): + crd_i32 = [] + for c in crd: + cv = c + if isinstance(cv, int): + cv = arith.constant(cv, T.i32) + crd_i32.append(cv) + continue + if isinstance(cv, ArithValue): + raw = cv.ir_value() if hasattr(cv, "ir_value") else cv + if isinstance(raw, ir.Value) and isinstance(raw.type, ir.IndexType): + cv = arith.index_cast(T.i32, raw) + else: + cv = raw + elif isinstance(cv, ir.Value) and isinstance(cv.type, ir.IndexType): + cv = arith.index_cast(T.i32, cv) + elif hasattr(cv, "ir_value"): + raw = cv.ir_value() + if isinstance(raw, ir.Value) and isinstance(raw.type, ir.IndexType): + cv = arith.index_cast(T.i32, raw) + else: + cv = raw + crd_i32.append(cv) + coord_val = fx.make_coord(*crd_i32) + result = fx.crd2idx(coord_val, layout) + scalar = fx.get_scalar(result) + if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): + scalar = arith.index_cast(T.index, scalar) + return _wrap(scalar) + + _, strides = parsed + result = None + for coord_v, stride_v in _builtins.zip(crd, strides): + if stride_v == 0: + continue + term = coord_v if stride_v == 1 else coord_v * arith.index(stride_v) + result = term if result is None else result + term + return result if result is not None else arith.index(0) + + +def get(int_tuple, mode): + """Extract element at `mode` from a Python list/tuple.""" + return int_tuple[mode] diff --git a/kernels/mfma_epilogues.py b/kernels/mfma_epilogues.py index 89174fd0..5df39ea0 100644 --- a/kernels/mfma_epilogues.py +++ b/kernels/mfma_epilogues.py @@ -29,24 +29,13 @@ from __future__ import annotations -from contextlib import contextmanager from typing import Callable from flydsl._mlir import ir import flydsl.expr as fx from flydsl.expr.typing import T - -@contextmanager -def _if_then(if_op, scf): - """Compat helper for 
SCF IfOp then-region across old/new Python APIs.""" - with ir.InsertionPoint(if_op.then_block): - try: - yield if_op.then_block - finally: - blk = if_op.then_block - if (not blk.operations) or not isinstance(blk.operations[-1], scf.YieldOp): - scf.YieldOp([]) +from kernels.kernels_common import _if_then def default_epilog( diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 1de69d06..308ba1ae 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -28,9 +28,15 @@ def crd2idx(crd, layout): def swizzle_xor16(row, col, k_blocks16): """XOR-with-row swizzle on the K dimension at 16B granularity. - Computes: col XOR ((row % k_blocks16) * 16) + Computes: col XOR ((row & (k_blocks16 - 1)) * 16) + + k_blocks16 is always a power of 2 (tile_k_bytes / 16), so use + bitwise AND instead of remui to save ~10 VALU cycles on CDNA. """ - rem = row % k_blocks16 + from flydsl.expr import arith as _swz_arith + + mask = k_blocks16 - _swz_arith.index(1) + rem = _swz_arith.andi(row, mask) return col ^ (rem * 16) @@ -45,20 +51,39 @@ def split_row_major_2d(index, minor_extent): return index // minor_extent, index % minor_extent -def _buffer_load_vec(buffer_ops, vector, rsrc, idx, *, elem_type, vec_elems, elem_bytes, offset_in_bytes): +def _buffer_load_vec( + buffer_ops, + vector, + rsrc, + idx, + *, + elem_type, + vec_elems, + elem_bytes, + offset_in_bytes, + cache_modifier=0, +): """Load vec_elems elements via buffer_load dwordx[1,2,4] + bitcast.""" + from flydsl.expr import arith as _ld_arith + elem_size = int(elem_bytes) load_bytes = int(vec_elems) * elem_size vec_width = load_bytes // 4 if offset_in_bytes: - idx_i32 = idx // 4 + idx_i32 = _ld_arith.shrui(idx, _ld_arith.index(2)) elif elem_bytes == 2: - idx_i32 = (idx * 2) // 4 + idx_i32 = _ld_arith.shrui(idx, _ld_arith.index(1)) else: idx_i32 = idx - i32_val = buffer_ops.buffer_load(rsrc, idx_i32, vec_width=vec_width, dtype=T.i32) + i32_val = 
buffer_ops.buffer_load( + rsrc, + idx_i32, + vec_width=vec_width, + dtype=T.i32, + cache_modifier=cache_modifier, + ) if vec_width == 1: i32_vec = vector.from_elements(T.vec(1, T.i32), [i32_val]) else: @@ -66,59 +91,6 @@ def _buffer_load_vec(buffer_ops, vector, rsrc, idx, *, elem_type, vec_elems, ele return vector.bitcast(T.vec(int(vec_elems), elem_type), i32_vec) -@dataclass(frozen=True) -class PreshuffleBLayout: - """Container returned by `make_preshuffle_b_layout`.""" - - layout_b: object - kpack_bytes: int - - -def make_preshuffle_b_layout( - arith, - *, - c_n: ir.Value, - c_k: ir.Value, - kpack_bytes: int = 16, - elem_bytes: int = 1, -) -> PreshuffleBLayout: - """Build B layout matching aiter/CK preshuffle for A8 MFMA kernels.""" - if kpack_bytes not in (8, 16): - raise ValueError(f"kpack_bytes must be 8 or 16, got {kpack_bytes!r}") - - c16 = fx.Index(16) - c64 = fx.Index(64) - c4 = fx.Index(4) - c_kpack = fx.Index(kpack_bytes) - - if elem_bytes not in (1, 2): - raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}") - c_k_bytes = c_k * arith.constant(int(elem_bytes), index=True) - c_k0 = c_k_bytes // c64 - n0 = c_n // c16 - - c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // arith.constant(int(elem_bytes), index=True)) - - stride_nlane = c_kpack_elems - stride_klane = c16 * stride_nlane - stride_k0 = c4 * stride_klane - stride_n0 = c_k0 * stride_k0 - - # fly.make_shape requires i32/i64 for dynamic operands (not index). - # Convert dynamic index values to i32; use Python ints for static constants. 
- kpack_elems_static = kpack_bytes if elem_bytes == 1 else kpack_bytes // elem_bytes - n0_i32 = arith.index_cast(T.i32, n0) - c_k0_i32 = arith.index_cast(T.i32, c_k0) - stride_n0_i32 = arith.index_cast(T.i32, stride_n0) - stride_k0_i32 = arith.index_cast(T.i32, stride_k0) - stride_klane_i32 = arith.index_cast(T.i32, stride_klane) - stride_nlane_i32 = arith.index_cast(T.i32, stride_nlane) - - stride_b = (stride_n0_i32, stride_k0_i32, stride_klane_i32, stride_nlane_i32, 1) - layout_b = fx.make_layout((n0_i32, c_k0_i32, 4, 16, kpack_elems_static), stride_b) - return PreshuffleBLayout(layout_b=layout_b, kpack_bytes=kpack_bytes) - - @dataclass(frozen=True) class PreshuffleScaleLayout: """Container returned by `make_preshuffle_scale_layout`. @@ -129,10 +101,10 @@ class PreshuffleScaleLayout: idx = mni * stride_n0 + ku * stride_k0 + k_lane * stride_klane + n_lane """ - layout_scale: object # fly layout value (same as PreshuffleBLayout.layout_b) - stride_n0: object # index-typed MLIR value (dynamic) - stride_k0: object # index-typed MLIR value (= 64) - stride_klane: object # index-typed MLIR value (= 16) + layout_scale: object + stride_n0: object + stride_k0: object + stride_klane: object def make_preshuffle_scale_layout( @@ -150,14 +122,12 @@ def make_preshuffle_scale_layout( Layout shape: ``(c_mn1, c_k1, 4, 16)`` where ``c_mn1 = c_mn / 16 / mn_pack`` and ``c_k1 = (c_k / scale_block_size) / 4 / k_pack``. 
""" - c16 = arith.constant(16, index=True) - c4 = arith.constant(4, index=True) - c_mn_pack = arith.constant(mn_pack, index=True) - c_k_pack = arith.constant(k_pack, index=True) - c_k_scale = c_k / scale_block_size - - c_mn1 = c_mn / c16 / c_mn_pack - c_k1 = c_k_scale / c4 / c_k_pack + c16 = fx.Index(16) + c4 = fx.Index(4) + c_k_scale = c_k // fx.Index(scale_block_size) + + c_mn1 = (c_mn // c16) // fx.Index(mn_pack) + c_k1 = (c_k_scale // c4) // fx.Index(k_pack) if elem_bytes != mn_pack * k_pack: raise ValueError( f"elem_bytes of scale must be {mn_pack} * {k_pack}, got {elem_bytes!r}" @@ -167,7 +137,6 @@ def make_preshuffle_scale_layout( stride_k0 = c4 * stride_klane stride_n0 = c_k1 * stride_k0 - # Build fly layout (i32 strides for fx.make_layout). c_mn1_i32 = arith.index_cast(T.i32, c_mn1) c_k1_i32 = arith.index_cast(T.i32, c_k1) stride_n0_i32 = arith.index_cast(T.i32, stride_n0) @@ -187,6 +156,76 @@ def make_preshuffle_scale_layout( ) +@dataclass(frozen=True) +class PreshuffleBLayout: + """Container returned by `make_preshuffle_b_layout`.""" + + layout_b: object + kpack_bytes: int + + +def make_preshuffle_b_layout( + arith, + *, + c_n: ir.Value, + c_k: ir.Value, + kpack_bytes: int = 16, + elem_bytes: int = 1, + k_major: bool = False, +) -> PreshuffleBLayout: + """Build B layout matching aiter/CK preshuffle for A8 MFMA kernels. + + When *k_major* is True the block-level order is K-major (``k_blk`` outermost), + matching the ``(0,3,1,4,2,5)`` shuffle permutation. The default N-major + order (``k_major=False``) matches the legacy ``(0,1,3,4,2,5)`` permutation. 
+ """ + if kpack_bytes not in (8, 16): + raise ValueError(f"kpack_bytes must be 8 or 16, got {kpack_bytes!r}") + + c16 = fx.Index(16) + c_kpack = fx.Index(kpack_bytes) + + if elem_bytes not in (1, 2): + raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}") + c_k_bytes = c_k * arith.constant(int(elem_bytes), index=True) + n0 = c_n // c16 + + c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // arith.constant(int(elem_bytes), index=True)) + + stride_nlane = c_kpack_elems + + if k_major: + c32 = fx.Index(32) + c2 = fx.Index(2) + c_k0 = c_k_bytes // c32 + klane_dim = 2 + stride_klane = c16 * stride_nlane + stride_n0 = c2 * stride_klane + stride_k0 = n0 * stride_n0 + else: + c64 = fx.Index(64) + c4 = fx.Index(4) + c_k0 = c_k_bytes // c64 + klane_dim = 4 + stride_klane = c16 * stride_nlane + stride_k0 = c4 * stride_klane + stride_n0 = c_k0 * stride_k0 + + kpack_elems_static = kpack_bytes if elem_bytes == 1 else kpack_bytes // elem_bytes + n0_i32 = arith.index_cast(T.i32, n0) + c_k0_i32 = arith.index_cast(T.i32, c_k0) + stride_n0_i32 = arith.index_cast(T.i32, stride_n0) + stride_k0_i32 = arith.index_cast(T.i32, stride_k0) + stride_klane_i32 = arith.index_cast(T.i32, stride_klane) + stride_nlane_i32 = arith.index_cast(T.i32, stride_nlane) + + stride_b = (stride_n0_i32, stride_k0_i32, stride_klane_i32, stride_nlane_i32, 1) + layout_b = fx.make_layout( + (n0_i32, c_k0_i32, klane_dim, 16, kpack_elems_static), stride_b + ) + return PreshuffleBLayout(layout_b=layout_b, kpack_bytes=kpack_bytes) + + def _i8x4_in_i32_to_bf16x4_i64(val_i32, arith, vector, scale_val=None): """Convert one i32 (4 signed int8 bytes) to 4 bf16 packed as i64. 
@@ -267,8 +306,14 @@ def load_b_raw_w4a16( idx_bytes = idx_pack + k2_base b4 = _buffer_load_vec( - buffer_ops, vector, b_rsrc, idx_bytes, - elem_type=elem_type, vec_elems=4, elem_bytes=1, offset_in_bytes=True, + buffer_ops, + vector, + b_rsrc, + idx_bytes, + elem_type=elem_type, + vec_elems=4, + elem_bytes=1, + offset_in_bytes=True, ) packed32 = vector.extract( vector.bitcast(T.vec(1, T.i32), b4), @@ -344,8 +389,14 @@ def load_b_pack_k32( if unpack_int4: idx_bytes = idx_pack + k2_base b4 = _buffer_load_vec( - buffer_ops, vector, b_rsrc, idx_bytes, - elem_type=elem_type, vec_elems=4, elem_bytes=1, offset_in_bytes=True, + buffer_ops, + vector, + b_rsrc, + idx_bytes, + elem_type=elem_type, + vec_elems=4, + elem_bytes=1, + offset_in_bytes=True, ) packed32 = vector.extract( vector.bitcast(T.vec(1, T.i32), b4), @@ -371,8 +422,13 @@ def load_b_pack_k32( vec_elems = kpack_bytes // int(elem_bytes) b16 = _buffer_load_vec( - buffer_ops, vector, b_rsrc, idx_pack, - elem_type=elem_type, vec_elems=vec_elems, elem_bytes=elem_bytes, + buffer_ops, + vector, + b_rsrc, + idx_pack, + elem_type=elem_type, + vec_elems=vec_elems, + elem_bytes=elem_bytes, offset_in_bytes=(elem_bytes == 1), ) @@ -425,8 +481,13 @@ def buffer_copy_gmem16_dwordx4( if int(vec_elems) <= 0: raise ValueError(f"vec_elems must be > 0, got {vec_elems!r}") return _buffer_load_vec( - buffer_ops, vector, rsrc, idx_i32, - elem_type=elem_type, vec_elems=vec_elems, elem_bytes=elem_bytes, + buffer_ops, + vector, + rsrc, + idx_i32, + elem_type=elem_type, + vec_elems=vec_elems, + elem_bytes=elem_bytes, offset_in_bytes=False, ) @@ -548,15 +609,19 @@ def lds_load_pack_k32( __all__ = [ "PreshuffleBLayout", + "PreshuffleScaleLayout", "buffer_copy_gmem16_dwordx4", - "lds_row_major_idx", "lds_load_pack_k32", + "lds_row_major_idx", "lds_store_4b_xor16", "lds_store_8b_xor16", "lds_store_16b_xor16", "make_preshuffle_b_layout", + "make_preshuffle_scale_layout", "load_b_pack_k32", + "load_b_raw_w4a16", "split_row_major_2d", 
"swizzle_xor16", "tile_chunk_coord_i32", + "unpack_b_w4a16", ] diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index eb23631a..318982a0 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -2,12 +2,10 @@ # Copyright (c) 2025 FlyDSL Project Contributors """MoE GEMM stage1/stage2 kernel implementations (FlyDSL MFMA FP8/FP16/FP4). - """ import functools import os -from contextlib import contextmanager import flydsl.compiler as flyc import flydsl.expr as fx @@ -21,41 +19,56 @@ try: from flydsl.runtime.device import supports_bf16_global_atomics except ImportError: - # Backward compatibility for runtime.device versions that only expose get_rocm_arch. def supports_bf16_global_atomics(arch: str) -> bool: return str(arch).startswith(("gfx94", "gfx95", "gfx12")) from flydsl._mlir import ir -from flydsl._mlir.dialects import scf, memref, llvm +from flydsl._mlir.dialects import llvm, scf, memref +from flydsl._mlir.dialects.arith import CmpIPredicate from kernels.mfma_preshuffle_pipeline import ( _buffer_load_vec, buffer_copy_gmem16_dwordx4, - crd2idx, lds_row_major_idx, lds_store_16b_xor16, lds_store_8b_xor16, lds_store_4b_xor16, + load_b_pack_k32, make_preshuffle_b_layout, make_preshuffle_scale_layout, - load_b_pack_k32, split_row_major_2d, tile_chunk_coord_i32, swizzle_xor16, ) -from kernels.mfma_epilogues import c_shuffle_epilog, mfma_epilog +from kernels.mfma_epilogues import c_shuffle_epilog +from kernels.layout_utils import crd2idx, idx2crd, get as layout_get +from kernels.kernels_common import _if_then, validate_moe_dtypes -@contextmanager -def _if_then(if_op): - """Compat helper for SCF IfOp then-region across old/new Python APIs.""" - with ir.InsertionPoint(if_op.then_block): - try: - yield if_op.then_block - finally: - blk = if_op.then_block - if (not blk.operations) or not isinstance(blk.operations[-1], scf.YieldOp): - scf.YieldOp([]) +def _barrier(vmcnt=63, lgkmcnt=63): + """Emit s_waitcnt + 
s_barrier via inline asm. + + Bypasses LLVM SIInsertWaitcnts which would insert a conservative + s_waitcnt vmcnt(0) lgkmcnt(0) before every S_BARRIER MI. + """ + parts = [] + needs_waitcnt = vmcnt < 63 or lgkmcnt < 63 + if needs_waitcnt: + wc = [] + if vmcnt < 63: + wc.append(f"vmcnt({vmcnt})") + if lgkmcnt < 63: + wc.append(f"lgkmcnt({lgkmcnt})") + parts.append("s_waitcnt " + " ".join(wc)) + parts.append("s_barrier") + llvm.InlineAsmOp( + res=None, + operands_=[], + asm_string="\n".join(parts), + constraints="", + has_side_effects=True, + is_align_stack=False, + ) def _w_elem_type(*, is_f4_b: bool, is_f16_b: bool): @@ -65,7 +78,7 @@ def _w_elem_type(*, is_f4_b: bool, is_f16_b: bool): return T.f16 if is_f16_b else T.f8 -@functools.lru_cache(maxsize=1024) +@functools.lru_cache(maxsize=None) def compile_mixed_moe_gemm1( *, model_dim: int, @@ -75,7 +88,6 @@ def compile_mixed_moe_gemm1( tile_m: int, tile_n: int, tile_k: int, - # NOTE: aiter swap passes these for API symmetry; stage1 uses dynamic memrefs so they are ignored. doweight_stage1: bool, a_dtype: str = "fp8", b_dtype: str = "fp4", @@ -85,24 +97,35 @@ def compile_mixed_moe_gemm1( enable_bias: bool = False, model_dim_pad: int = 0, inter_dim_pad: int = 0, + persist_m: int = 1, + fuse_fp4_quant: bool = False, + fuse_sort_scale: bool = False, + use_async_copy: bool = False, + waves_per_eu: int = 3, + k_batch: int = 1, + b_nt: int = 0, + gate_only: bool = False, ): - """Compile stage1 kernel (`moe_gemm1`) and return the compiled executable. - - a_dtype: - - "fp8": X is fp8 - - "fp16": X is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) - - "int8": X is int8 - - "fp4": X is fp4 - - b_dtype: - - "fp8": W is fp8 - - "fp16": W is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) - - "int8": W is int8 - - "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel - - "fp4": W is fp4 + """Compile stage1 kernel (gate+up with silu) based on stage2 structure. 
+ + GEMM: silu(X @ W_gate.T) * (X @ W_up.T) -> [tokens*topk, inter_dim] + Direct store (no atomic). When k_batch>1 (split-K), each CTA + computes a K-slice and atomically adds gate/up partials. + Note: persist_m=1 (no persistence) is optimal for stage1 because K=model_dim + is large, so each CTA is already compute-heavy. persist_m>1 serializes M blocks + that the GPU can process in parallel. + + When gate_only=True (requires k_batch>1), each workgroup computes + only one B-tile stream instead of interleaving gate and up. + The grid X dimension doubles (inter_in / tile_n instead of + inter_in / 2 / tile_n) so that by_n covers the full [0, 2*inter_dim) + range, naturally selecting gate or up rows by position. + This halves per-WG B-VMEM traffic and MFMA count, and the + doubled block count compensates. """ gpu_arch = get_hip_arch() - allocator = SmemAllocator(None, arch=gpu_arch) + allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") + allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") if a_dtype not in ("fp8", "fp16", "int8", "fp4"): raise ValueError( @@ -115,271 +138,402 @@ def compile_mixed_moe_gemm1( is_f16_a = a_dtype == "fp16" is_f16_b = b_dtype == "fp16" - is_f16 = is_f16_a or is_f16_b - is_f8_a = a_dtype == "fp8" is_f4_a = a_dtype == "fp4" is_f4_b = b_dtype == "fp4" - pack_M = 2 - pack_N = 2 + sort_block_m = max(32, tile_m) + num_waves = tile_n // 32 + total_threads = num_waves * 64 + pack_M = 1 if tile_m < 32 else 2 + n_per_wave = tile_n // num_waves + pack_N = min(2, n_per_wave // 16) pack_K = 2 - + scale_mn_pack = 2 elem_bytes = 1 - a_elem_bytes = 2 if is_f16_a else 1 b_elem_bytes = 1 tile_k_bytes = int(tile_k) * int(a_elem_bytes) - a_elem_vec_pack = 2 if is_f4_a else 1 cbsz = 0 if is_f8_a else 4 blgp = 4 - # K64-byte micro-step: always 64 bytes per `ku`. For fp16, this is 32 elements (2xK16 MFMA). 
if (tile_k_bytes % 64) != 0: - raise ValueError( - f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={a_elem_bytes})" - ) + raise ValueError(f"tile_k_bytes must be divisible by 64, got {tile_k_bytes}") + + out_s = str(out_dtype).strip().lower() + out_is_f32 = out_s in ("f32", "fp32", "float") + out_is_bf16 = out_s in ("bf16", "bfloat16") is_int4 = b_dtype == "int4" + is_int8 = False + def _x_elem_type(): + if is_f4_b: + return T.f8 if is_f8_a else T.i8 + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) - def _x_lds_elem_type(): - return T.f16 if is_f16_a else T.f8 + def _w_elem_type(): + if is_f4_b: + return T.i8 + return T.f16 if is_f16_b else (T.i8 if is_int8 else T.f8) - def _out_elem_type(): - return T.bf16 if out_dtype == "bf16" else T.f16 + def out_elem(): + return T.f32 if out_is_f32 else (T.bf16 if out_is_bf16 else T.f16) - def _out_lds_elem_type(): - return T.f32 + # Split-K validation + _is_splitk = k_batch > 1 + if gate_only and not _is_splitk: + raise ValueError("gate_only requires k_batch > 1 (split-K)") + if _is_splitk: + _k_per_batch = model_dim // k_batch + assert ( + model_dim % k_batch == 0 + ), f"model_dim={model_dim} not divisible by k_batch={k_batch}" + assert ( + _k_per_batch % tile_k == 0 + ), f"K_per_batch={_k_per_batch} not divisible by tile_k={tile_k}" + + fuse_fp4_quant = False + else: + _k_per_batch = model_dim + _k_dim = _k_per_batch + # Stage1 gate-only: output = [tokens*topk, inter_dim], direct store (accumulate=False) + # Weight layout: [E * 2*inter_dim, model_dim] pre-shuffled; gate = first inter_dim rows per expert + # GEMM: X[tokens, model_dim] @ W_gate[inter_dim, model_dim].T -> [tokens*topk, inter_dim] - total_threads = 256 bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes) if bytes_x_per_tile % total_threads != 0: raise ValueError( - "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, 
elem_bytes={a_elem_bytes}" + f"tile_m*tile_k*elem_bytes must be divisible by {total_threads}" ) bytes_per_thread_x = bytes_x_per_tile // total_threads - pad_k = 0 + + _use_lds128 = os.environ.get("FLIR_CK_LDS128", "1") in ( + "1", + "true", + "True", + "YES", + "yes", + ) + pad_k = 0 if _use_lds128 else 8 lds_stride = tile_k + pad_k + if use_cshuffle_epilog is None: - use_cshuffle_epilog = os.environ.get("FLYDSL_MOE_STAGE1_CSHUFFLE", "0") in ( + _use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE1_CSHUFFLE", "1") in ( "1", "true", "True", "YES", "yes", ) - use_cshuffle_epilog = bool(use_cshuffle_epilog) + else: + _use_cshuffle_epilog = bool(use_cshuffle_epilog) + + _need_quant = fuse_fp4_quant + _need_sort = _need_quant and fuse_sort_scale - epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" + if _need_quant: + _use_cshuffle_epilog = True + + _fp4q_tag = "_fp4q" if _need_quant else "" + _sort_tag = "_sort" if _need_sort else "" + _async_tag = "_async" if use_async_copy else "" + _sk_tag = f"_sk{k_batch}" if _is_splitk else "" + _go_tag = "_go" if gate_only else "" module_name = ( - f"mfma_moe1_a{a_dtype}_w{b_dtype}_{epilog_tag}" - f"_t{tile_m}x{tile_n}x{tile_k}" - f"_abi35_ckstage1" + f"mfma_moe1_silu_mul_a{a_dtype}_w{b_dtype}_{out_s}" + f"_t{tile_m}x{tile_n}x{tile_k}_pm{persist_m}{_fp4q_tag}{_sort_tag}{_async_tag}{_sk_tag}{_go_tag}_v32" ).replace("-", "_") - # -- LDS sizing (pure Python; no MLIR Context needed) --------------------- - # Reuse the same LDS bytes for both: - # - ping-pong X tiles (2 * tile_m * lds_stride * elem_bytes bytes) - # - optional CShuffle tile (stage1 uses 2xf16 vector store, sized in 4B pairs) - _use_cshuffle_epilog = bool(use_cshuffle_epilog) - lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(a_elem_bytes) - lds_out_bytes = 4 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 - lds_tid_bytes = int(tile_m) * 4 - lds_total_bytes = max(lds_x_bytes, lds_out_bytes) + lds_tid_bytes - lds_total_elems = lds_total_bytes if 
a_elem_bytes == 1 else (lds_total_bytes // 2) - lds_alloc_bytes = int(lds_total_elems) * int(a_elem_bytes) - lds_alloc_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = lds_alloc_offset + lds_alloc_bytes + # -- LDS sizing (split ping/pong allocators) -- + _cshuffle_elem_bytes = 4 if _need_quant else (4 if out_is_f32 else 2) + _single_x_bytes = int(tile_m) * int(lds_stride) * int(a_elem_bytes) + lds_out_bytes = ( + _cshuffle_elem_bytes * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 + ) + lds_tid_bytes = int(tile_m) * 4 + _buffer_bytes = max(_single_x_bytes, lds_out_bytes) + _buffer_elems = _buffer_bytes if a_elem_bytes == 1 else (_buffer_bytes // 2) - @flyc.kernel - def moe_gemm1( - arg_out: fx.Tensor, - arg_x: fx.Tensor, - arg_w: fx.Tensor, - arg_scale_x: fx.Tensor, - arg_scale_w: fx.Tensor, - arg_sorted_token_ids: fx.Tensor, - arg_expert_ids: fx.Tensor, - arg_sorted_weights: fx.Tensor, - arg_max_token_ids: fx.Tensor, - arg_bias: fx.Tensor, - i32_tokens_in: fx.Int32, - i32_inter_in: fx.Int32, - i32_k_in: fx.Int32, - i32_size_expert_ids_in: fx.Int32, - ): - tokens_in = arith.index_cast(T.index, i32_tokens_in) - k_in = arith.index_cast(T.index, i32_k_in) - size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) - tokens_i32_v = i32_tokens_in - k_i32_v = i32_k_in - x_elem = T.f16 if is_f16_a else T.f8 - vec4_f32 = T.vec(4, T.f32) - vec4_i32 = T.vec(4, T.i32) - vec1_f32 = T.vec(1, T.f32) - vec16_elems = 16 if a_elem_bytes == 1 else 8 - vec16_x = T.vec(vec16_elems, x_elem) - vec2_i64 = T.vec(2, T.i64) - - def silu(x): - # Align with CK's device fast path: - # emu = exp(-x) ~= exp2(log2e * (-x)) -> v_exp_f32 - # sig = rcp(1 + emu) -> v_rcp_f32 - # y = x * sig - t = x * (-1.4426950408889634) # -log2(e) - emu = rocdl.exp2(T.f32, t) - den = 1.0 + emu - sig = rocdl.rcp(T.f32, den) - return x * sig - - _arith_min = getattr(arith, "minimum", None) or getattr(arith, "minimumf") - _arith_max = getattr(arith, "maximum", None) or 
getattr(arith, "maximumf") - - def swiglu(gate, up, alpha=1.702, limit=7.0): - gate = _arith_min(gate, limit) - up = _arith_min(up, limit) - up = _arith_max(up, -limit) - - t = gate * alpha * (-1.4426950408889634) # -log2(e) - emu = rocdl.exp2(T.f32, t) - den = 1.0 + emu - sig = rocdl.rcp(T.f32, den) - return gate * sig * (up + 1.0) - - acc_init = arith.constant_vector(0.0, vec4_f32) - - # B preshuffle layout: match GEMM test helper exactly. - c_n_total = fx.Index(experts * (2 * inter_dim)) - kpack_bytes = 8 if is_int4 else 16 - b_layout = make_preshuffle_b_layout( - arith, - c_n=c_n_total, - c_k=k_in // pack_K, - kpack_bytes=kpack_bytes, - elem_bytes=b_elem_bytes, - ) - layout_b = b_layout.layout_b + def x_lds_elem(): + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + + lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) + allocator_pong.ptr = lds_pong_offset + _buffer_bytes + _lds_tid_offset_pong = allocator_pong._align(allocator_pong.ptr, 4) + allocator_pong.ptr = _lds_tid_offset_pong + lds_tid_bytes + + lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) + allocator_ping.ptr = lds_ping_offset + _buffer_bytes + + # if tile_m == 16: + # waves_per_eu = 1 + + if waves_per_eu is not None and waves_per_eu >= 1: + _total_cu_lds = 160 * 1024 + _min_lds = _total_cu_lds // (waves_per_eu + 1) + 1 + _pong_sz = allocator_pong._align(allocator_pong.ptr, 128) + _ping_sz = allocator_ping._align(allocator_ping.ptr, 128) + _cur_lds = _pong_sz + _ping_sz + if _cur_lds < _min_lds: + allocator_ping.ptr += _min_lds - _cur_lds + + kpack_bytes = 8 if is_int4 else 16 + out_elem_bytes = 4 if out_is_f32 else 2 + + _e_vec_s1 = min(tile_n // 32, 8) + if _need_quant: + _e_vec_s1 = max(2, _e_vec_s1) + _num_threads_per_quant_blk_s1 = 32 // _e_vec_s1 + _shuffle_dists_s1 = [] + _sh_val = 1 + while _sh_val < _num_threads_per_quant_blk_s1: + _shuffle_dists_s1.append(_sh_val) + _sh_val *= 2 + _num_shuffle_steps_s1 = len(_shuffle_dists_s1) + + # ---- Unified 
pipeline schedule (outside @flyc.kernel) ---- + # Each scheduling phase is a dict: + # mfma: [(k_idx, mi_idx, ikxdl, imxdl, asv_idx), ...] + # a_reads: [(k, mi), ...] # A ds_read subtiles + # b_loads: [('gate'/'up', ku, ni), ...] # B VMEM loads + # has_scale: bool # A/B scale VMEM loads + _pipe_m_repeat = tile_m // 16 + _pipe_k_unroll = tile_k_bytes // 128 + _pipe_k_unroll_packed = _pipe_k_unroll // pack_K + _pipe_m_repeat_packed = _pipe_m_repeat // pack_M + _pipe_num_acc_n = n_per_wave // 16 + + # A ds_read groups: group by mi (same mi, all k values together) + _pipe_a_groups = [] + for _mi in range(_pipe_m_repeat): + _grp = [] + for _k in range(_pipe_k_unroll): + _grp.append((_k, _mi)) + if len(_grp) == 2: + _pipe_a_groups.append(_grp) + _grp = [] + if _grp: + _pipe_a_groups.append(_grp) + + # B VMEM loads: individual gate/up loads + _pipe_b_loads = [] + for ku in range(_pipe_k_unroll): + for ni in range(_pipe_num_acc_n): + _pipe_b_loads.append(("gate", ku, ni)) + if not gate_only: + _pipe_b_loads.append(("up", ku, ni)) + + # MFMA order: B-major (fix B, cycle all A tiles before next B) + # Each entry: one (k, ni) pair; the compute function loops over all mi. + # This keeps B operands (from VMEM) fixed while cycling A (from LDS, no wait). 
+ _pipe_all_mfma = [] + for _ku128 in range(_pipe_k_unroll_packed): + for _ikxdl in range(pack_K): + for _inxdl in range(pack_N): + _k_idx = _ku128 * pack_K + _ikxdl + _ni_idx = _inxdl + _pipe_all_mfma.append((_k_idx, _ni_idx, _ikxdl, _inxdl, _ku128)) + + # Group MFMAs per scheduling phase (wider M -> more MFMAs per phase) + _pipe_mfma_per_phase = max(1, len(_pipe_all_mfma) // 4) + _pipe_n_phases = len(_pipe_all_mfma) // _pipe_mfma_per_phase + + # Build unified phase descriptors + _a_groups_per_phase = (len(_pipe_a_groups) + _pipe_n_phases - 1) // _pipe_n_phases + _pipe_phases = [] + _mfma_i = 0 + _a_i = 0 + for _p in range(_pipe_n_phases): + _a_reads = [] + for _ in range(_a_groups_per_phase): + if _a_i < len(_pipe_a_groups): + _a_reads.extend(_pipe_a_groups[_a_i]) + _a_i += 1 + _phase = { + "mfma": _pipe_all_mfma[_mfma_i : _mfma_i + _pipe_mfma_per_phase], + "a_reads": _a_reads, + "b_loads": [], + "has_scale": (_p == 0), + } + _mfma_i += _pipe_mfma_per_phase + _pipe_phases.append(_phase) + + # Distribute B loads evenly across phases 1..n-1 (phase 0 has scales) + _bi = 0 + for _p in range(1, _pipe_n_phases): + _rem_b = len(_pipe_b_loads) - _bi + _rem_p = _pipe_n_phases - _p + _n_b = (_rem_b + _rem_p - 1) // _rem_p if _rem_p > 0 else 0 + for _ in range(_n_b): + if _bi < len(_pipe_b_loads): + _pipe_phases[_p]["b_loads"].append(_pipe_b_loads[_bi]) + _bi += 1 + + # Extract flat lists for kernel access (avoids dict access in AST rewriter) + _pp_mfma = [p["mfma"] for p in _pipe_phases] + _pp_a_reads = [p["a_reads"] for p in _pipe_phases] + _pp_b_loads = [p["b_loads"] for p in _pipe_phases] + _pp_has_scale = [p["has_scale"] for p in _pipe_phases] - m_repeat = tile_m // 16 - k_unroll = tile_k_bytes // 128 # K64-byte micro-step + if True: - # A scale is sorted/padded by MoE routing, so its M dimension follows - # the sorted row buffer (`blocks * tile_m`), not the raw token count. 
- sorted_rows = size_expert_ids_in * fx.Index(tile_m) - layout_a_scale = make_preshuffle_scale_layout( - arith, c_mn=sorted_rows, c_k=k_in - ) - layout_b_scale = make_preshuffle_scale_layout( - arith, c_mn=c_n_total, c_k=k_in - ) + @flyc.kernel + def moe_gemm1( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + arg_bias: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_n_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + ): - shape_lds = fx.make_shape(tile_m, tile_k) - stride_lds = fx.make_stride(lds_stride, 1) - layout_lds = fx.make_layout(shape_lds, stride_lds) + tokens_in = arith.index_cast(ir.IndexType.get(), i32_tokens_in.ir_value()) + n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) + k_in = arith.index_cast(ir.IndexType.get(), i32_k_in.ir_value()) + size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) - tx = gpu.thread_id("x") - # Align with Aiter launch mapping (NSwizzle==false): - # - blockIdx.x -> N dimension (tile along inter_dim) - # - blockIdx.y -> expert-block id / M dimension (tile along sorted M) - by = gpu.block_id("x") # tile along inter_dim - bx = gpu.block_id("y") # tile along sorted M + x_elem = T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + f32 = T.f32 + i32 = T.i32 + i64 = T.i64 + vec4_f32 = T.vec(4, f32) + vec16_elems = 16 if a_elem_bytes == 1 else 8 + vec16_x = T.vec(vec16_elems, x_elem) + vec2_i64 = T.vec(2, i64) - # Block validity: compute as early as possible so invalid blocks skip all buffer-resource - # setup, LDS pointer math, and gmem prefetch work. 
- bx_m = bx * fx.Index(tile_m) - by_n = by * fx.Index(tile_n) + acc_init = arith.constant_vector(0.0, vec4_f32) - maxids_rsrc = buffer_ops.create_buffer_resource( - arg_max_token_ids, max_size=False, num_records_bytes=fx.Int32(4) - ) - max_token_id_i32 = buffer_ops.buffer_load( - maxids_rsrc, fx.Index(0), vec_width=1, dtype=T.i32 - ) + # --- Stage1 dimension mapping --- + # X: [tokens, model_dim] -- M = sorted tokens, K = model_dim + # W: [E*2*inter_dim, model_dim] gate portion -- N = inter_dim - bias_rsrc = ( - buffer_ops.create_buffer_resource(arg_bias, max_size=False) - if enable_bias - else None - ) + # B preshuffle layout: [E*2*inter_dim, model_dim] + # Gate rows for expert e: [e*2*inter_dim, e*2*inter_dim + inter_dim) + c_n_total = arith.constant(experts * (2 * inter_dim), index=True) + b_layout = make_preshuffle_b_layout( + arith, + c_n=c_n_total, + c_k=k_in // pack_K, + kpack_bytes=kpack_bytes, + elem_bytes=b_elem_bytes, + # k_major=True, + ) + layout_b = b_layout.layout_b - bx_m_i32 = arith.index_cast(T.i32, bx_m) - blk_valid = arith.cmpi(arith.CmpIPredicate.ult, bx_m_i32, max_token_id_i32) - # Common constants/atoms (hoisted): keep IR small like GEMM. - # CK-style XOR16 swizzle parameter (constant, power-of-two in our configs). - k_blocks16 = fx.Index(tile_k_bytes // 16) - _if_blk = scf.IfOp(blk_valid) - with _if_then(_if_blk): - x_lds_elem = _x_lds_elem_type() - base_ptr = allocator.get_base() - lds_x_ptr = SmemPtr( - base_ptr, - lds_alloc_offset, - x_lds_elem, - shape=(lds_total_elems,), + # A-scale: [sorted_size, K/32] -- pre-scattered by caller into sorted layout + # Same as stage2: indexed by sorted_row position, not by token_id. + sorted_m = size_expert_ids_in * arith.constant(sort_block_m, index=True) + layout_a_scale = make_preshuffle_scale_layout( + arith, c_mn=sorted_m, c_k=arith.constant(model_dim, index=True) ) - lds_x = lds_x_ptr.get() - # Alias LDS bytes as fp16 for optional CShuffle epilogue. 
- _use_cshuffle_epilog = bool(use_cshuffle_epilog) + # B-scale: [E*2*inter_dim, K/32] + layout_b_scale = make_preshuffle_scale_layout( + arith, c_mn=c_n_total, c_k=arith.constant(model_dim, index=True) + ) + + if use_async_copy and a_elem_vec_pack > 1: + _eff_lds_stride = lds_stride // a_elem_vec_pack + _eff_tile_k_bytes = tile_k_bytes // a_elem_vec_pack + else: + _eff_lds_stride = lds_stride + _eff_tile_k_bytes = tile_k_bytes + + shape_lds = fx.make_shape(tile_m, _eff_lds_stride) + stride_lds = fx.make_stride(_eff_lds_stride, 1) + layout_lds = fx.make_layout(shape_lds, stride_lds) + + tx = gpu.thread_id("x") + by = gpu.block_id("x") # tile along inter_dim (N) + bx_persist = gpu.block_id("y") # persistent WG index + by_n = by * arith.constant(tile_n, index=True) - _lds_out_elems = tile_m * tile_n if _use_cshuffle_epilog else 0 + if _is_splitk: + bz = gpu.block_id("z") # K-batch id + k_base_idx = bz * arith.constant(_k_dim, index=True) + else: + k_base_idx = arith.index(0) + + k_blocks16 = arith.constant(_eff_tile_k_bytes // 16, index=True) + layout_tx_wave_lane = fx.make_layout((num_waves, 64), stride=(64, 1)) + layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) + + base_ptr_pong = allocator_pong.get_base() + base_ptr_ping = allocator_ping.get_base() + lds_x_pong = SmemPtr( + base_ptr_pong, lds_pong_offset, x_lds_elem(), shape=(_buffer_elems,) + ).get() + lds_x_ping = SmemPtr( + base_ptr_ping, lds_ping_offset, x_lds_elem(), shape=(_buffer_elems,) + ).get() + _lds_out_elem_type = ( + T.f32 if _need_quant else (T.bf16 if out_is_bf16 else T.f16) + ) lds_out = ( SmemPtr( - base_ptr, - lds_x_ptr.byte_offset, - _out_lds_elem_type(), - shape=(_lds_out_elems,), + base_ptr_pong, + lds_pong_offset, + _lds_out_elem_type, + shape=(tile_m * tile_n,), ).get() if _use_cshuffle_epilog else None ) + lds_tid = SmemPtr( + base_ptr_pong, _lds_tid_offset_pong, T.i32, shape=(tile_m,) + ).get() - # Use logical buffer sizes (descriptor num_records) so hardware OOB checking can be - # 
used directly (CK-style). This allows us to avoid `select`-based masking for - # invalid lanes and rely on the buffer instruction's built-in bounds behavior. - x_nbytes = ( - tokens_in - * (k_in // fx.Index(int(a_elem_vec_pack))) - * fx.Index(int(elem_bytes)) - ) + # Buffer resources + c_a_pack = arith.constant(int(a_elem_vec_pack), index=True) + c_elem_bytes = arith.constant(int(a_elem_bytes), index=True) + + # X: [tokens, model_dim] + x_nbytes_idx = (tokens_in * k_in * c_elem_bytes) / c_a_pack + x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource( - arg_x, max_size=False, num_records_bytes=x_nbytes + arg_x, max_size=False, num_records_bytes=x_nbytes_i32 ) - _w_n = fx.Index(experts * (2 * inter_dim)) - _w_nbytes = ( - _w_n - * (k_in // fx.Index(int(a_elem_vec_pack))) - * fx.Index(int(elem_bytes)) + + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + + # Out: [tokens*topk, inter_dim] + out_nbytes_idx = ( + tokens_in + * arith.index(topk) + * n_in + * arith.constant(out_elem_bytes, index=True) ) - _w_nbytes_i32 = arith.index_cast(T.i32, _w_nbytes) - w_rsrc = buffer_ops.create_buffer_resource( - arg_w, max_size=False, num_records_bytes=_w_nbytes_i32 + out_nbytes_i32 = arith.index_cast(T.i32, out_nbytes_idx) + buffer_ops.create_buffer_resource( + arg_out, max_size=False, num_records_bytes=out_nbytes_i32 ) - # OUT: [tokens * topk * inter_dim] in f16/bf16 (2B each) or fp8 (1B each). 
- _out_elem_bytes = 1 if out_dtype == "fp8" else 2 - out_nbytes = tokens_in * arith.constant( - topk * inter_dim * _out_elem_bytes, index=True + numids_rsrc = buffer_ops.create_buffer_resource( + arg_num_valid_ids, + max_size=False, + num_records_bytes=arith.constant(4, type=T.i32), ) - out_nbytes_i32 = arith.index_cast(T.i32, out_nbytes) - out_rsrc = buffer_ops.create_buffer_resource( - arg_out, max_size=False, num_records_bytes=out_nbytes_i32 + num_valid_i32 = buffer_ops.buffer_load( + numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=T.i32 ) if is_f16_a: sx_rsrc = None else: - # A1 microscale: [sorted_rows, K/32] e8m0 bytes, packed as i32. - _c32 = fx.Index(32) - _kblk = k_in / _c32 - _sorted_rows = size_expert_ids_in * arith.constant( - tile_m, index=True - ) - sx_nbytes = _sorted_rows * _kblk - sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes) + # A scale: [sorted_size, model_dim/32] pre-scattered by caller + c32 = arith.constant(32, index=True) + kblk = k_in / c32 + sx_nbytes_idx = sorted_m * kblk + sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) @@ -387,86 +541,85 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): if is_f16_b: sw_rsrc = None else: - # W1 microscale: [experts * 2 * inter_dim, K/32] e8m0 bytes. 
- _c32_w = fx.Index(32) - _kblk_w = k_in / _c32_w - _mn_w = fx.Index(experts * (2 * inter_dim)) - sw_nbytes = _mn_w * _kblk_w - sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes) + c32 = arith.constant(32, index=True) + kblk_w = k_in / c32 + mn_w = arith.constant(experts * (2 * inter_dim), index=True) + sw_nbytes_idx = mn_w * kblk_w + sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) sw_rsrc = buffer_ops.create_buffer_resource( arg_scale_w, max_size=False, num_records_bytes=sw_nbytes_i32 ) - # sorted_token_ids / sorted_weights: [blocks*tile_m] (CK-style padded length) - sorted_nbytes = size_expert_ids_in * arith.constant( - tile_m * 4, index=True + sorted_nbytes_idx = size_expert_ids_in * arith.constant( + sort_block_m * 4, index=True ) - sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes) + sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes_idx) sorted_rsrc = buffer_ops.create_buffer_resource( arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes_i32, ) - sorted_w_rsrc = ( - buffer_ops.create_buffer_resource( - arg_sorted_weights, - max_size=False, - num_records_bytes=sorted_nbytes_i32, - ) - if doweight_stage1 - else None + sorted_w_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_i32 ) - # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 - eid_nbytes_i32 = arith.index_cast( - T.i32, size_expert_ids_in * fx.Index(4) - ) + eid_nbytes_idx = size_expert_ids_in * arith.constant(4, index=True) + eid_nbytes_i32 = arith.index_cast(T.i32, eid_nbytes_idx) expert_rsrc = buffer_ops.create_buffer_resource( arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 ) - # Expert id for this M tile (keep address math in `index`) + # Sorted-scale buffer resource for fused mxfp4 quantization + _sorted_scale_cols = inter_dim // 32 + _sorted_scale_cols_i32 = arith.constant(_sorted_scale_cols, type=T.i32) + sorted_scale_rsrc = None + + # ---- persist_m loop (same pattern 
as stage2) ---- + _PERSIST_M = persist_m + _c0_p = arith.constant(0, index=True) + _c1_p = arith.constant(1, index=True) + _c_pm = arith.constant(_PERSIST_M, index=True) + _for_persist = scf.ForOp(_c0_p, _c_pm, _c1_p) + _for_ip = ir.InsertionPoint(_for_persist.body) + _for_ip.__enter__() + _mi_p = _for_persist.induction_variable + bx = bx_persist * _c_pm + _mi_p + bx_m = bx * arith.constant(sort_block_m, index=True) + + # Block validity + bx_m_i32 = arith.index_cast(T.i32, bx_m) + blk_valid = arith.cmpi(CmpIPredicate.ult, bx_m_i32, num_valid_i32) expert_i32 = buffer_ops.buffer_load( expert_rsrc, bx, vec_width=1, dtype=T.i32 ) + expert_idx = arith.index_cast(ir.IndexType.get(), expert_i32) exp_valid = arith.cmpi( - arith.CmpIPredicate.ult, expert_i32, fx.Int32(experts) - ) # todo fix - _ifexpert_of = scf.IfOp(exp_valid) - with _if_then(_ifexpert_of): - expert_idx = arith.index_cast(T.index, expert_i32) - inter2_idx = fx.Index(2 * inter_dim) - expert_off_idx = expert_idx * inter2_idx # index + CmpIPredicate.ult, expert_i32, arith.constant(experts, type=T.i32) + ) - bx_m = bx * fx.Index(tile_m) + def _moe_gemm1_body(): + # Gate expert offset: first inter_dim rows of each expert's 2*inter_dim block + expert_off_idx = expert_idx * arith.constant(2 * inter_dim, index=True) - # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- - # Keep a fixed 16B gmem->reg schedule (dwordx4) to match preshuffle_gemm_flyc.py. - if bytes_per_thread_x % 16 != 0: - raise ValueError( - f"bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" - ) + # X loading -- KEY DIFFERENCE from stage2: X row = token_id only x_load_bytes = 16 num_x_loads = bytes_per_thread_x // x_load_bytes - chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) + chunk_i32 = x_load_bytes // 4 - # Work in dword units along K: K_dwords = (K_packed_bytes)/4. - # For fp4, 2 elements per byte, so divide by a_elem_vec_pack. 
- c_a_pack = fx.Index(int(a_elem_vec_pack)) c_k_div4 = ( - (k_in // c_a_pack) * fx.Index(int(elem_bytes)) - ) // arith.index(4) - c_k_div4_i32 = arith.index_cast(T.i32, c_k_div4) - tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 + (k_in / c_a_pack) * arith.constant(int(a_elem_bytes), index=True) + ) / arith.index(4) + tile_k_dwords = (int(tile_k) * int(a_elem_bytes)) // ( + 4 * int(a_elem_vec_pack) + ) layout_x_tile_div4 = fx.make_layout( (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) ) - c_chunk_i32 = fx.Index(chunk_i32) + c_chunk_i32 = arith.constant(chunk_i32, index=True) tx_i32_base = tx * c_chunk_i32 - mask24 = fx.Int32(0xFFFFFF) - # Keep i32 constants available for epilogue index math. - topk_i32 = fx.Int32(topk) + topk_i32 = arith.constant(topk) + mask24 = arith.constant(0xFFFFFF) tokens_i32 = arith.index_cast(T.i32, tokens_in) def x_tile_chunk_coord_i32(i: int): @@ -479,46 +632,9 @@ def x_tile_chunk_coord_i32(i: int): chunk_i32=chunk_i32, ) - # CK-aligned: decode token once (per thread's M-slice) and build a base row offset. 
- x_row_base_div4 = [] - x_row_valid = [] - x_col_local_i32 = [] - x_row_local = [] - for i in range_constexpr(num_x_loads): - row_local, col_local_i32 = x_tile_chunk_coord_i32(i) - x_row_local.append(row_local) - x_col_local_i32.append(col_local_i32) - - sorted_row_i = bx_m + row_local - sorted_row_i32 = arith.index_cast(T.i32, sorted_row_i) - row_valid = arith.cmpi( - arith.CmpIPredicate.ult, sorted_row_i32, max_token_id_i32 - ) - sorted_row_safe = arith.select( - row_valid, sorted_row_i, fx.Index(0) - ) - fused_i = buffer_ops.buffer_load( - sorted_rsrc, sorted_row_safe, vec_width=1, dtype=T.i32 - ) - t_i32 = fused_i & mask24 - s_i32 = fused_i >> fx.Int32(24) - t_valid = arith.cmpi(arith.CmpIPredicate.ult, t_i32, tokens_i32) - s_valid = arith.cmpi(arith.CmpIPredicate.ult, s_i32, topk_i32) - ts_valid = row_valid & (t_valid & s_valid) - x_row_valid.append(ts_valid) - t_safe = arith.select(ts_valid, t_i32, fx.Int32(0)) - t_idx = arith.index_cast(T.index, t_safe) - x_row_base_div4.append(t_idx * c_k_div4) - - vec4_i32 = T.vec(4, T.i32) - def load_x(idx_i32): - """Load `x_load_bytes` bytes from X (gmem) into regs. - - For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. 
- """ idx_elem = ( - idx_i32 if elem_bytes == 1 else (idx_i32 * arith.index(2)) + idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) ) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -529,240 +645,279 @@ def load_x(idx_i32): vec_elems=vec16_elems, ) - _zero_row_idx = fx.Index(0) + # Decode sorted token ids -- stage1: X row = token_id (not t*topk+s) + x_row_base_div4 = [] + x_col_local_i32 = [] + x_row_local = [] + # Also store token_id and slot_id for output indexing + + for i in range_constexpr(num_x_loads): + row_local, col_local_i32 = x_tile_chunk_coord_i32(i) + x_row_local.append(row_local) + x_col_local_i32.append(col_local_i32) + + sorted_row_i = bx_m + row_local + fused_i = buffer_ops.buffer_load( + sorted_rsrc, sorted_row_i, vec_width=1, dtype=T.i32 + ) + t_i32 = arith.andi(fused_i, mask24) + s_i32 = arith.shrui(fused_i, arith.constant(24)) + t_valid = arith.cmpi(CmpIPredicate.ult, t_i32, tokens_i32) + s_valid = arith.cmpi(CmpIPredicate.ult, s_i32, topk_i32) + ts_valid = arith.andi(t_valid, s_valid) + t_safe = arith.select(ts_valid, t_i32, arith.constant(0)) + + # KEY: X row base uses token_id only (not t*topk+s) + t_idx = arith.index_cast(ir.IndexType.get(), t_safe) + x_row_base_div4.append(t_idx * c_k_div4) def load_x_tile(base_k): - """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" base_k_div4 = ( - (base_k // c_a_pack) - * fx.Index(int(elem_bytes)) - ) // arith.index(4) - zero_x_i32 = arith.constant_vector(0, vec4_i32) + (base_k / c_a_pack) + * arith.constant(int(a_elem_bytes), index=True) + ) / arith.index(4) parts = [] for i in range_constexpr(num_x_loads): - safe_base = arith.select( - x_row_valid[i], x_row_base_div4[i], _zero_row_idx - ) - idx_i32 = safe_base + base_k_div4 + x_col_local_i32[i] + idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32) - x_i32 = vector.bitcast(vec4_i32, x_vec) - x_i32 = arith.select(x_row_valid[i], x_i32, zero_x_i32) - 
parts.append(x_i32) + parts.append(vector.bitcast(T.vec(4, i32), x_vec)) return parts - # tx -> wave/lane (GEMM-style decomposition). - wave_id, lane_id = split_row_major_2d(tx, fx.Index(64)) - lane_div_16, lane_mod_16 = split_row_major_2d( - lane_id, fx.Index(16) - ) - - # Match GEMM naming/pattern: row in LDS is lane_mod_16, and col base is lane_div_16*16B (KPackBytes=16). + # Wave/lane decomposition (identical to stage2) + coord_wl = idx2crd(tx, layout_tx_wave_lane) + wave_id = layout_get(coord_wl, 0) + lane_id = layout_get(coord_wl, 1) + coord_l16 = idx2crd(lane_id, layout_lane16) + lane_div_16 = layout_get(coord_l16, 0) + lane_mod_16 = layout_get(coord_l16, 1) row_a_lds = lane_mod_16 - col_offset_base = lane_div_16 * fx.Index(16) + col_offset_base = lane_div_16 * arith.constant(16, index=True) - # Dynamic N tiling within block (same as existing kernels) - num_waves = 4 - n_per_wave = tile_n // num_waves num_acc_n = n_per_wave // 16 - c_n_per_wave = fx.Index(n_per_wave) - wave_mod_4 = wave_id % arith.index(4) - n_tile_base = wave_mod_4 * c_n_per_wave - - # fp4 pack - k_unroll_packed = k_unroll // pack_K - m_repeat_packed = m_repeat // pack_M - num_acc_n_packed = num_acc_n // pack_N + c_n_per_wave = arith.constant(n_per_wave, index=True) + wave_n_id = wave_id % arith.constant(num_waves, index=True) + n_tile_base = wave_n_id * c_n_per_wave - # Precompute gate/up B coordinates and output columns for CK-style stage1: - # weights/scales are laid out as [gate rows][up rows], so each output column - # pairs one gate row with the matching up row at +inter_dim. 
- col_g_list = [] - inter_idx = fx.Index(inter_dim) - out_block_base = by_n - out_wave_base = n_tile_base + # N-tile precompute for gate AND up weights gate_n_intra_list = [] gate_n_blk_list = [] - up_n_intra_list = [] - up_n_blk_list = [] + if not gate_only: + up_n_intra_list = [] + up_n_blk_list = [] + c_n0_static = experts * (2 * inter_dim) // 16 + layout_n_blk_intra = fx.make_layout((c_n0_static, 16), stride=(16, 1)) + inter_idx = arith.constant(inter_dim, index=True) + for i in range_constexpr(num_acc_n): offset = i * 16 - c_offset = fx.Index(offset) - out_col = ( - out_block_base + out_wave_base + c_offset + lane_mod_16 - ) - col_g_list.append(out_col) + c_offset = arith.constant(offset, index=True) - gate_row_w = expert_off_idx + out_col - gate_n_blk, gate_n_intra = split_row_major_2d( - gate_row_w, fx.Index(16) - ) - gate_n_blk_list.append(gate_n_blk) - gate_n_intra_list.append(gate_n_intra) + global_n = by_n + n_tile_base + c_offset + lane_mod_16 + # Gate: rows [expert_off, expert_off + inter_dim) + # For gate_only, by_n covers [0, 2*inter_dim) so this + # indexes into both gate and up regions naturally. 
+ gate_row_w = expert_off_idx + global_n + gate_coord = idx2crd(gate_row_w, layout_n_blk_intra) + gate_n_blk_list.append(layout_get(gate_coord, 0)) + gate_n_intra_list.append(layout_get(gate_coord, 1)) + if not gate_only: + # Up: rows [expert_off + inter_dim, expert_off + 2*inter_dim) + up_row_w = gate_row_w + inter_idx + up_coord = idx2crd(up_row_w, layout_n_blk_intra) + up_n_blk_list.append(layout_get(up_coord, 0)) + up_n_intra_list.append(layout_get(up_coord, 1)) - up_row_w = gate_row_w + inter_idx - up_n_blk, up_n_intra = split_row_major_2d( - up_row_w, fx.Index(16) - ) - up_n_blk_list.append(up_n_blk) - up_n_intra_list.append(up_n_intra) + m_repeat = tile_m // 16 + k_unroll = tile_k_bytes // 128 + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc_n // pack_N - # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- + # B load for gate and up separately def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): - # K64 micro-step = 2x K32 MFMA steps. Reuse the shared helper. 
- b0 = load_b_pack_k32( - buffer_ops, - arith, - vector, - arg_b=arg_w, - b_rsrc=w_rsrc, - layout_b=layout_b, - base_k=base_k, - ki_step=ku * 2, - n_blk=n_blk, - n_intra=n_intra, - lane_div_16=lane_div_16, - elem_type=_w_elem_type(is_f4_b=is_f4_b, is_f16_b=is_f16), - kpack_bytes=kpack_bytes, - elem_bytes=b_elem_bytes, - unpack_int4=bool(is_int4), + c64 = arith.constant(64, index=True) + base_k_bytes = base_k * arith.constant( + int(b_elem_bytes), index=True ) - b1 = load_b_pack_k32( + k0 = base_k_bytes // c64 + arith.constant(ku, index=True) + k1 = lane_div_16 + coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True)) + idx_pack = crd2idx(coord_pack, layout_b) + vec_elems = kpack_bytes // int(b_elem_bytes) + b16 = _buffer_load_vec( buffer_ops, - arith, vector, - arg_b=arg_w, - b_rsrc=w_rsrc, - layout_b=layout_b, - base_k=base_k, - ki_step=ku * 2 + 1, - n_blk=n_blk, - n_intra=n_intra, - lane_div_16=lane_div_16, - elem_type=_w_elem_type(is_f4_b=is_f4_b, is_f16_b=is_f16), - kpack_bytes=kpack_bytes, + w_rsrc, + idx_pack, + elem_type=_w_elem_type(), + vec_elems=vec_elems, elem_bytes=b_elem_bytes, - unpack_int4=bool(is_int4), + offset_in_bytes=(b_elem_bytes == 1), + cache_modifier=b_nt, + ) + b_i64x2 = vector.bitcast(vec2_i64, b16) + b0 = vector.extract( + b_i64x2, static_position=[0], dynamic_position=[] + ) + b1 = vector.extract( + b_i64x2, static_position=[1], dynamic_position=[] ) return b0, b1 def load_b_tile(base_k): + """Load B tiles. Returns (gate_b_tile, up_b_tile). 
+ When gate_only, up_b_tile is None.""" gate_b_tile = [] - up_b_tile = [] + up_b_tile = [] if not gate_only else None for ku in range_constexpr(k_unroll): - gate_packs0 = [] - gate_packs1 = [] - up_packs0 = [] - up_packs1 = [] + g_packs0, g_packs1 = [], [] + u_packs0, u_packs1 = [], [] for ni in range_constexpr(num_acc_n): - gate_b0, gate_b1 = load_b_packs_k64( - base_k, - ku, - gate_n_blk_list[ni], - gate_n_intra_list[ni], + gb0, gb1 = load_b_packs_k64( + base_k, ku, gate_n_blk_list[ni], gate_n_intra_list[ni] ) - up_b0, up_b1 = load_b_packs_k64( - base_k, - ku, - up_n_blk_list[ni], - up_n_intra_list[ni], - ) - gate_packs0.append(gate_b0) - gate_packs1.append(gate_b1) - up_packs0.append(up_b0) - up_packs1.append(up_b1) - gate_b_tile.append((gate_packs0, gate_packs1)) - up_b_tile.append((up_packs0, up_packs1)) + g_packs0.append(gb0) + g_packs1.append(gb1) + if not gate_only: + ub0, ub1 = load_b_packs_k64( + base_k, ku, up_n_blk_list[ni], up_n_intra_list[ni] + ) + u_packs0.append(ub0) + u_packs1.append(ub1) + gate_b_tile.append((g_packs0, g_packs1)) + if not gate_only: + up_b_tile.append((u_packs0, u_packs1)) return gate_b_tile, up_b_tile - def load_scale(arg_scale, rsrc, scale_info, ku, mni): - k_lane = lane_div_16 - n_lane = lane_mod_16 - # Direct arith crd2idx: idx = mni*stride_n0 + ku*stride_k0 + k_lane*stride_klane + n_lane - idx_pack = ( - mni * scale_info.stride_n0 - + ku * scale_info.stride_k0 - + k_lane * scale_info.stride_klane - + n_lane + # Pre-compute scale base element indices (K-loop invariant). 
+ # idx = mni * stride_n0 + ku * stride_k0 + k_lane * stride_klane + n_lane + # Split into: base_elem = mni * stride_n0 + lane_elem (invariant) + # k_elem = ku * stride_k0 (per-iteration) + _scale_lane_elem = ( + lane_div_16 * layout_b_scale.stride_klane + lane_mod_16 + ) + + _gate_scale_bases = [] + if not gate_only: + _up_scale_bases = [] + for _ni in range_constexpr(num_acc_n_packed): + _col_base = ( + by_n + + n_tile_base + + arith.constant(_ni * 16 * pack_N, index=True) ) - s = buffer_ops.buffer_load( - rsrc, idx_pack, vec_width=1, dtype=T.i32 + _gate_mni = (expert_off_idx + _col_base) // arith.constant( + 32, index=True ) - return vector.from_elements(T.vec(1, T.i32), [s]) + _gate_scale_bases.append( + _gate_mni * layout_b_scale.stride_n0 + _scale_lane_elem + ) + if not gate_only: + _up_mni = ( + expert_off_idx + inter_idx + _col_base + ) // arith.constant(32, index=True) + _up_scale_bases.append( + _up_mni * layout_b_scale.stride_n0 + _scale_lane_elem + ) - def load_scale_masked(arg_scale, rsrc, scale_info, ku, mni, valid): - safe_mni = arith.select( - valid, mni, fx.Index(0) + _a_scale_bases = [] + for _mi in range_constexpr(m_repeat_packed): + _a_mni = _mi + bx_m // scale_mn_pack // 16 + _a_scale_bases.append( + _a_mni * layout_a_scale.stride_n0 + _scale_lane_elem ) - scale_i32 = vector.extract( - load_scale(arg_scale, rsrc, scale_info, ku, safe_mni), - static_position=[0], - dynamic_position=[], + + _c16_idx = arith.constant(16, index=True) + _c2_idx = arith.constant(2, index=True) + _scale_mask_lo = arith.constant(0xFF, type=T.i32) + + if pack_M < scale_mn_pack: + _m_half_idx = (bx_m // _c16_idx) % _c2_idx + _m_half_i32 = arith.index_cast(T.i32, _m_half_idx) + _scale_shift = _m_half_i32 * arith.constant(8, type=T.i32) + _scale_shift_hi = _scale_shift + arith.constant(16, type=T.i32) + + if pack_N < scale_mn_pack: + _n_half_idx = (n_tile_base // _c16_idx) % _c2_idx + _n_half_i32 = arith.index_cast(T.i32, _n_half_idx) + _bscale_shift = _n_half_i32 * 
arith.constant(8, type=T.i32) + _bscale_shift_hi = _bscale_shift + arith.constant(16, type=T.i32) + + def _rearrange_a_scale(raw_i32): + """Rearrange scale bytes for pack_M=1: extract m_half's k0,k1 bytes.""" + if pack_M >= scale_mn_pack: + return raw_i32 + b_k0 = arith.andi( + arith.shrui(raw_i32, _scale_shift), _scale_mask_lo + ) + b_k1 = arith.andi( + arith.shrui(raw_i32, _scale_shift_hi), _scale_mask_lo + ) + return arith.ori( + b_k0, arith.shli(b_k1, arith.constant(8, type=T.i32)) ) - scale_i32 = arith.select(valid, scale_i32, fx.Int32(0)) - return vector.from_elements(T.vec(1, T.i32), [scale_i32]) - def load_b_scale_tile(base_k): - gate_b_scale_tile = [] - up_b_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): - for ni in range_constexpr(num_acc_n_packed): - col_offset = ni * 16 * pack_N - col_offset_idx = fx.Index(col_offset) - col_base = ( - out_block_base + out_wave_base + col_offset_idx - ) - col_valid = arith.cmpi( - arith.CmpIPredicate.ult, col_base, inter_idx - ) - gate_mni = ( - expert_off_idx + col_base - ) // fx.Index(32) - up_mni = ( - expert_off_idx + inter_idx + col_base - ) // fx.Index(32) - gate_scale_i32 = load_scale_masked( - arg_scale_w, - sw_rsrc, - layout_b_scale, - ku + base_k, - gate_mni, - col_valid, - ) - up_scale_i32 = load_scale_masked( - arg_scale_w, - sw_rsrc, - layout_b_scale, - ku + base_k, - up_mni, - col_valid, - ) - gate_b_scale_tile.append(gate_scale_i32) - up_b_scale_tile.append(up_scale_i32) - return gate_b_scale_tile, up_b_scale_tile + def _rearrange_b_scale(raw_i32): + """Rearrange scale bytes for pack_N=1: extract n_half's k0,k1 bytes.""" + if pack_N >= scale_mn_pack: + return raw_i32 + b_k0 = arith.andi( + arith.shrui(raw_i32, _bscale_shift), _scale_mask_lo + ) + b_k1 = arith.andi( + arith.shrui(raw_i32, _bscale_shift_hi), _scale_mask_lo + ) + return arith.ori( + b_k0, arith.shli(b_k1, arith.constant(8, type=T.i32)) + ) - def load_a_scale_tile(base_k): + def prefetch_ab_scale_tile(base_k): a_scale_tile = [] 
+ gate_b_scale = [] + up_b_scale = [] if not gate_only else None for ku in range_constexpr(k_unroll_packed): + k_off = (ku + base_k) * layout_b_scale.stride_k0 for mi in range_constexpr(m_repeat_packed): - scale = load_scale( - arg_scale_x, + s = buffer_ops.buffer_load( sx_rsrc, - layout_a_scale, - ku + base_k, - mi + bx_m // pack_M // 16, + _a_scale_bases[mi] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, ) - a_scale_tile.append(scale) - return a_scale_tile - - def prefetch_ab_scale_tile(base_k): - gate_bs, up_bs = load_b_scale_tile(base_k) - return [load_a_scale_tile(base_k), gate_bs, up_bs] + s = _rearrange_a_scale(s) + a_scale_tile.append( + vector.from_elements(T.vec(1, T.i32), [s]) + ) + for ni in range_constexpr(num_acc_n_packed): + gs = buffer_ops.buffer_load( + sw_rsrc, + _gate_scale_bases[ni] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + gs = _rearrange_b_scale(gs) + gate_b_scale.append( + vector.from_elements(T.vec(1, T.i32), [gs]) + ) + if not gate_only: + us = buffer_ops.buffer_load( + sw_rsrc, + _up_scale_bases[ni] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + us = _rearrange_b_scale(us) + up_b_scale.append( + vector.from_elements(T.vec(1, T.i32), [us]) + ) + return [a_scale_tile, gate_b_scale, up_b_scale] - acc_gate = [acc_init] * (num_acc_n * m_repeat) - acc_up = [acc_init] * (num_acc_n * m_repeat) + _lds_base_zero = arith.index(0) - # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] @@ -770,36 +925,87 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): lds_store_16b_xor16( arith, vector, - lds_memref=lds_x, + lds_memref=lds_buffer, vec16_ty=vec16_x, layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, tx_c4=arith.index(4), k_blocks16=k_blocks16, - 
lds_base=lds_base, + lds_base=_lds_base_zero, vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - # --- A LDS load helper for K64 (load 16B once, extract 2x i64 halves) --- - def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): - # Swizzle in bytes, then convert to element offset for memref indexing. + if use_async_copy: + _dma_bytes = 16 + _wave_size = 64 + _eff_bytes_per_buffer = ( + int(tile_m) * int(_eff_lds_stride) * int(a_elem_bytes) + ) + _num_dma_loads = max( + 1, _eff_bytes_per_buffer // (total_threads * _dma_bytes) + ) + + def dma_x_tile_to_lds(base_k, lds_buffer): + c4_idx = arith.index(4) + base_k_div4 = ( + (base_k / c_a_pack) + * arith.constant(int(elem_bytes), index=True) + ) / arith.index(4) + + lds_ptr_i64 = None + for i in range_constexpr(_num_dma_loads): + row_local_i = x_row_local[i] + col_local_i32_i = x_col_local_i32[i] + col_local_sw = swizzle_xor16( + row_local_i, col_local_i32_i * c4_idx, k_blocks16 + ) + row_k_dw = x_row_base_div4[i] + base_k_div4 + global_byte_idx = row_k_dw * c4_idx + col_local_sw + global_offset = arith.index_cast(T.i32, global_byte_idx) + + if i == 0: + lds_addr = memref.extract_aligned_pointer_as_index( + lds_buffer + ) + wave_id * arith.constant( + _wave_size * _dma_bytes, index=True + ) + lds_ptr_i64 = rocdl.readfirstlane( + T.i64, arith.index_cast(T.i64, lds_addr) + ) + else: + lds_ptr_i64 = lds_ptr_i64 + arith.constant( + total_threads * _dma_bytes, type=T.i64 + ) + + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_ptr_i64) + + rocdl.raw_ptr_buffer_load_lds( + x_rsrc, + lds_ptr, + arith.constant(_dma_bytes, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + ) + + def prefetch_x_to_lds(base_k, lds_buffer): + dma_x_tile_to_lds(base_k, lds_buffer) + + def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16( curr_row_a_lds, col_base, 
k_blocks16 ) col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 - else (col_base_swz_bytes // arith.index(2)) - ) - idx_a16 = lds_row_major_idx( - curr_row_a_lds, - col_base_swz, - fx.Index(lds_stride), - lds_base, + else (col_base_swz_bytes / arith.index(2)) ) - loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) + idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) + loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) a0 = vector.extract( a_i64x2, static_position=[0], dynamic_position=[] @@ -809,136 +1015,127 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): ) return a0, a1 - def compute_f8f6f4_tile( + def prefetch_full_a_from_lds(lds_buffer): + """Load entire A tile from LDS into registers before compute.""" + a_regs = [] + for k_idx in range_constexpr(k_unroll): + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack + for mi_idx in range_constexpr(m_repeat): + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row = row_a_lds + mi_val + a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) + if is_f8_a: + a2, a3 = lds_load_packs_k64( + curr_row, col_base + 64, lds_buffer + ) + a_regs.append((a0, a1, a2, a3)) + else: + a_regs.append((a0, a1)) + return a_regs + + # Compute tile: gate + up MFMA interleaved, same A data, different B data. + # Two accumulator sets; after all K tiles, acc = acc_gate + acc_up (f32 add). + def compute_tile( acc_gate_in, acc_up_in, gate_b_tile_in, up_b_tile_in, - lds_base, - *, - a0_prefetch=None, + a_tile_regs, a_scale=None, gate_b_scale=None, up_b_scale=None, - prefetch_epilogue: bool = False, + *, + prefetch_epilogue=False, ): gate_list = list(acc_gate_in) - up_list = list(acc_up_in) - - # Re-sync all threads before consuming the current LDS tile. 
- gpu.barrier() - rocdl.sched_barrier(0) - + up_list = list(acc_up_in) if not gate_only else None + mfma_res_ty = vec4_f32 epilogue_pf = None - if enable_bias and prefetch_epilogue: - gate_bias = [] - up_bias = [] - for ni in range_constexpr(num_acc_n): - global_n = by_n + n_tile_base + ni * 16 + lane_mod_16 - gate_offset = expert_off_idx + global_n - up_offset = expert_off_idx + global_n + inter_dim - gate_bias.append( - buffer_ops.buffer_load( - bias_rsrc, gate_offset, vec_width=1, dtype=T.f32 - ) - ) - up_bias.append( - buffer_ops.buffer_load( - bias_rsrc, up_offset, vec_width=1, dtype=T.f32 + if prefetch_epilogue and doweight_stage1: + tw_pf = [] + lane_div_16_mul4_pf = lane_div_16 * arith.index(4) + ii_idx_list_pf = [ + arith.constant(ii, index=True) for ii in range(4) + ] + for mi in range_constexpr(m_repeat): + mi_base_pf = arith.constant(mi * 16, index=True) + for ii in range_constexpr(4): + row_off_pf = lane_div_16_mul4_pf + ii_idx_list_pf[ii] + sorted_row_pf = bx_m + mi_base_pf + row_off_pf + tw_pf.append( + buffer_ops.buffer_load( + sorted_w_rsrc, + sorted_row_pf, + vec_width=1, + dtype=f32, + ) ) - ) - epilogue_pf = (gate_bias, up_bias) - - if (int(tile_k) % 128) != 0: - raise ValueError( - f"tile_k must be divisible by 128 for mfma_scale_x128, got tile_k={tile_k}" - ) + epilogue_pf = (None, tw_pf, None) - mfma_res_ty = T.f32x4 + c0_i64 = arith.constant(0, type=T.i64) vec4_i64 = T.vec(4, T.i64) vec8_i32 = T.vec(8, T.i32) - c0_i64 = arith.constant(0, type=T.i64) def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) return vector.bitcast(vec8_i32, v4) - # Gate and Up MFMA interleaved in the same inner loop. + # B-major: fix B (ni), cycle A (mi) -- B from VMEM stays + # in registers while A from LDS is repacked per mi. 
for ku128 in range_constexpr(k_unroll_packed): - for mi in range_constexpr(m_repeat_packed): - a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] - a_scale_val = vector.extract( - a_scale_i32, + for ni in range_constexpr(num_acc_n_packed): + gate_bs_i32 = gate_b_scale[ku128 * num_acc_n_packed + ni] + gate_bs_val = vector.extract( + gate_bs_i32, static_position=[0], dynamic_position=[], ) - for ni in range_constexpr(num_acc_n_packed): - gate_bs_i32 = gate_b_scale[ - ku128 * num_acc_n_packed + ni - ] - gate_bs_val = vector.extract( - gate_bs_i32, - static_position=[0], - dynamic_position=[], - ) - up_bs_i32 = up_b_scale[ - ku128 * num_acc_n_packed + ni - ] + if not gate_only: + up_bs_i32 = up_b_scale[ku128 * num_acc_n_packed + ni] up_bs_val = vector.extract( - up_bs_i32, - static_position=[0], - dynamic_position=[], + up_bs_i32, static_position=[0], dynamic_position=[] ) - for ikxdl in range_constexpr(pack_K): - k_idx = ku128 * pack_K + ikxdl - gate_bp0, gate_bp1 = gate_b_tile_in[k_idx] + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl + gate_bp0, gate_bp1 = gate_b_tile_in[k_idx] + if not gate_only: up_bp0, up_bp1 = up_b_tile_in[k_idx] - col_base = ( - col_offset_base - + (k_idx * 128) // a_elem_vec_pack + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl + gb0 = gate_bp0[ni_idx] + gb1 = gate_bp1[ni_idx] + gb128 = pack_i64x4_to_i32x8( + gb0, gb1, c0_i64, c0_i64 ) - for imxdl in range_constexpr(pack_M): - col_base0 = col_base - mi_idx = mi * pack_M + imxdl - mi_val = arith.constant( - mi_idx * 16, index=True + if not gate_only: + ub0 = up_bp0[ni_idx] + ub1 = up_bp1[ni_idx] + ub128 = pack_i64x4_to_i32x8( + ub0, ub1, c0_i64, c0_i64 ) - curr_row_a_lds = row_a_lds + mi_val - - a0, a1 = lds_load_packs_k64( - curr_row_a_lds, col_base0, lds_base + for mi in range_constexpr(m_repeat_packed): + a_scale_i32 = a_scale[ + ku128 * m_repeat_packed + mi + ] + a_scale_val = vector.extract( + a_scale_i32, + static_position=[0], + 
dynamic_position=[], ) - - if is_f8_a: - col_base1 = col_base + 64 - a2, a3 = lds_load_packs_k64( - curr_row_a_lds, col_base1, lds_base - ) - a128 = pack_i64x4_to_i32x8( - a0, a1, a2, a3 - ) - else: - a128 = pack_i64x4_to_i32x8( - a0, a1, c0_i64, c0_i64 - ) - - for inxdl in range_constexpr(pack_N): - ni_idx = ni * pack_N + inxdl + for imxdl in range_constexpr(pack_M): + mi_idx = mi * pack_M + imxdl + _a_reg_idx = k_idx * m_repeat + mi_idx + if is_f8_a: + a0, a1, a2, a3 = a_tile_regs[_a_reg_idx] + a128 = pack_i64x4_to_i32x8( + a0, a1, a2, a3 + ) + else: + a0, a1 = a_tile_regs[_a_reg_idx] + a128 = pack_i64x4_to_i32x8( + a0, a1, c0_i64, c0_i64 + ) acc_idx = mi_idx * num_acc_n + ni_idx - - gb0 = gate_bp0[ni_idx] - gb1 = gate_bp1[ni_idx] - gb128 = pack_i64x4_to_i32x8( - gb0, gb1, c0_i64, c0_i64 - ) - - ub0 = up_bp0[ni_idx] - ub1 = up_bp1[ni_idx] - ub128 = pack_i64x4_to_i32x8( - ub0, ub1, c0_i64, c0_i64 - ) - - rocdl.sched_barrier(0) gate_list[acc_idx] = ( rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, @@ -955,376 +1152,992 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): ], ) ) - rocdl.sched_barrier(0) - up_list[acc_idx] = ( - rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - ub128, - up_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - up_bs_val, - ], + if not gate_only: + up_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + ub128, + up_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + up_bs_val, + ], + ) ) - ) return gate_list, up_list, epilogue_pf - # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- - lds_tile_elems = fx.Index(tile_m * lds_stride) - lds_base_cur = arith.index(0) - lds_base_nxt = lds_tile_elems + def load_a_subtile(k_idx, mi_idx, lds_buffer): + """Load a single A sub-tile from LDS (one ds_read).""" + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack + 
mi_val = arith.constant(mi_idx * 16, index=True) + curr_row = row_a_lds + mi_val + a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) + if is_f8_a: + a2, a3 = lds_load_packs_k64(curr_row, col_base + 64, lds_buffer) + return (a0, a1, a2, a3) + else: + return (a0, a1) + + def compute_bmajor_mfma_phase( + all_a_tiles, + gate_b_single, + up_b_single, + a_scale_vals, + gate_bs_val, + up_bs_val, + gate_list, + up_list, + k_idx, + ni_idx, + ikxdl, + inxdl, + ): + """B-major MFMA: fix one B (ni), cycle all A tiles (mi). - # Optional scheduler hints (copied from tuned GEMM); can be disabled via env. - rocdl.sched_barrier(0) + Packs B once and reuses across all mi iterations. + A tiles come from LDS (already available, no VMEM wait). - def hot_loop_scheduler(): - mfma_group = num_acc_n * 2 - # K64 micro-step: 2x K32 MFMA per gemm. - mfma_total = (k_unroll * 2) * m_repeat * mfma_group - mfma_per_iter = 2 * mfma_group - sche_iters = ( - 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - ) + all_a_tiles: flat list indexed by [k*m_repeat + mi]. + gate_b_single/up_b_single: (b0, b1) for one specific ni. + When gate_only, up_b_single is None. + a_scale_vals: list of A scale scalars indexed by mi_packed. + """ + c0_i64 = arith.constant(0, type=T.i64) + vec4_i64 = T.vec(4, T.i64) + vec8_i32 = T.vec(8, T.i32) - # DS-read preload (CK default is 2); clamp to non-negative. - rocdl.sched_dsrd(2) - rocdl.sched_mfma(2) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) + def _pack(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) - # DS-write hints near the end: match total X LDS-store micro-ops per thread. 
- dswr_tail = num_x_loads - if dswr_tail > sche_iters: - dswr_tail = sche_iters - dswr_start = sche_iters - dswr_tail - for sche_i in range_constexpr(sche_iters): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(mfma_group) - if sche_i >= dswr_start - 1: - rocdl.sched_dswr(1) + mfma_res_ty = vec4_f32 + gb128 = _pack(gate_b_single[0], gate_b_single[1], c0_i64, c0_i64) + if not gate_only: + ub128 = _pack(up_b_single[0], up_b_single[1], c0_i64, c0_i64) + + for mi_p in range_constexpr(m_repeat_packed): + a_scale_val = a_scale_vals[mi_p] + for imxdl in range_constexpr(pack_M): + mi_idx = mi_p * pack_M + imxdl + a_reg = all_a_tiles[k_idx * m_repeat + mi_idx] + + if is_f8_a: + a128 = _pack(a_reg[0], a_reg[1], a_reg[2], a_reg[3]) + else: + a128 = _pack(a_reg[0], a_reg[1], c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n + ni_idx + gate_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + gb128, + gate_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + gate_bs_val, + ], + ) + if not gate_only: + up_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + ub128, + up_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + up_bs_val, + ], + ) + ) + + def _interleaved_half( + lds_read, + lds_write, + next_k_dma_py, + next_k_load, + prev_a_tile, + prev_gate_w, + prev_up_w, + prev_a_scale, + prev_gate_bs, + prev_up_bs, + acc_gate, + acc_up, + ): + """One flatmm-style interleaved half-iteration (deep pipeline). + + Generalized for arbitrary m_repeat (block_m=32, 64, ...). + DMA targets lds_write (OTHER buffer) while ds_read uses + lds_read (already DMA'd in previous half). 
+ + Interleaving schedule (per half): + Phase 0: scale VMEM + 2 ds_read(A) -> 4 MFMA(prev) + Phase 1..N: B VMEM(distributed) + 2 ds_read(A, if avail) -> 4 MFMA(prev) + Phase N+1..: remaining B VMEM -> 4 MFMA(prev) + """ + _abs_k = k_base_idx + arith.constant(next_k_load, index=True) + _bk = _abs_k // arith.constant(2, index=True) + _sk = _abs_k // arith.constant(pack_K * 128, index=True) + _k_off = _sk * layout_b_scale.stride_k0 + + rocdl.sched_barrier(0) + rocdl.s_waitcnt(3) + _barrier() rocdl.sched_barrier(0) - # Prologue: prefetch tile0, store to LDS(cur), sync. - k0 = arith.index(0) - x_regs0 = load_x_tile(k0) - gate_w0, up_w0 = load_b_tile(k0) + # DMA A to OTHER buffer (for next half), non-blocking + _abs_k_dma = k_base_idx + arith.constant(next_k_dma_py, index=True) + if use_async_copy and next_k_dma_py < int(_k_dim): + prefetch_x_to_lds(_abs_k_dma, lds_write) + if not use_async_copy: + _x_regs = load_x_tile(_abs_k_dma) + + # ---- Extract previous scale values ---- + _prev_asvs = [] + for _mi_p in range_constexpr(m_repeat_packed): + _prev_asvs.append( + vector.extract( + prev_a_scale[_mi_p], + static_position=[0], + dynamic_position=[], + ) + ) + _prev_gsv = vector.extract( + prev_gate_bs[0], + static_position=[0], + dynamic_position=[], + ) + if not gate_only: + _prev_usv = vector.extract( + prev_up_bs[0], + static_position=[0], + dynamic_position=[], + ) + # ---- Execute phases from unified schedule ---- + _a_all = {} + _b_gate_all = {} + _b_up_all = {} + + for _p in range_constexpr(_pipe_n_phases): + # Scale VMEM loads (phase 0 only) + if _pp_has_scale[_p]: + _new_as_list = [] + for _mi_p in range_constexpr(m_repeat_packed): + _raw_as = buffer_ops.buffer_load( + sx_rsrc, + _a_scale_bases[_mi_p] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + _new_as_list.append(_rearrange_a_scale(_raw_as)) + _new_gs = buffer_ops.buffer_load( + sw_rsrc, + _gate_scale_bases[0] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + _new_gs = 
_rearrange_b_scale(_new_gs) + if not gate_only: + _new_us = buffer_ops.buffer_load( + sw_rsrc, + _up_scale_bases[0] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + _new_us = _rearrange_b_scale(_new_us) + + # B VMEM loads + for _b_j in range_constexpr(len(_pp_b_loads[_p])): + _b_type, _b_ku, _b_ni = _pp_b_loads[_p][_b_j] + if _b_type == "gate": + _b_gate_all[(_b_ku, _b_ni)] = load_b_packs_k64( + _bk, + _b_ku, + gate_n_blk_list[_b_ni], + gate_n_intra_list[_b_ni], + ) + else: + _b_up_all[(_b_ku, _b_ni)] = load_b_packs_k64( + _bk, + _b_ku, + up_n_blk_list[_b_ni], + up_n_intra_list[_b_ni], + ) + + # A ds_reads + rocdl.sched_barrier(0) + for _a_j in range_constexpr(len(_pp_a_reads[_p])): + _ak, _ami = _pp_a_reads[_p][_a_j] + _a_all[(_ak, _ami)] = load_a_subtile( + _ak, + _ami, + lds_read, + ) + rocdl.sched_barrier(0) + + # MFMAs on prev data + rocdl.s_setprio(1) + for _m_j in range_constexpr(len(_pp_mfma[_p])): + _k_idx, _ni_idx, _ikxdl, _inxdl, _ku128 = _pp_mfma[_p][_m_j] + _up_b_single = ( + ( + prev_up_w[_k_idx][0][_ni_idx], + prev_up_w[_k_idx][1][_ni_idx], + ) + if not gate_only + else None + ) + compute_bmajor_mfma_phase( + prev_a_tile, + ( + prev_gate_w[_k_idx][0][_ni_idx], + prev_gate_w[_k_idx][1][_ni_idx], + ), + _up_b_single, + _prev_asvs, + _prev_gsv, + _prev_usv if not gate_only else None, + acc_gate, + acc_up, + _k_idx, + _ni_idx, + _ikxdl, + _inxdl, + ) + rocdl.s_setprio(0) + rocdl.sched_barrier(0) + + # ---- Assemble loaded data for next half-iteration ---- + cur_a_tile = [] + for _k in range_constexpr(k_unroll): + for _mi in range_constexpr(m_repeat): + cur_a_tile.append(_a_all[(_k, _mi)]) + + cur_gate_w = [] + cur_up_w = None if gate_only else [] + for ku in range_constexpr(k_unroll): + g_packs0, g_packs1 = [], [] + u_packs0, u_packs1 = [], [] + for ni in range_constexpr(num_acc_n): + g = _b_gate_all[(ku, ni)] + g_packs0.append(g[0]) + g_packs1.append(g[1]) + if not gate_only: + u = _b_up_all[(ku, ni)] + u_packs0.append(u[0]) + 
u_packs1.append(u[1]) + cur_gate_w.append((g_packs0, g_packs1)) + if not gate_only: + cur_up_w.append((u_packs0, u_packs1)) + + cur_a_scale = [] + for _mi_p in range_constexpr(m_repeat_packed): + cur_a_scale.append( + vector.from_elements( + T.vec(1, T.i32), + [_new_as_list[_mi_p]], + ) + ) + cur_gate_bs = [vector.from_elements(T.vec(1, T.i32), [_new_gs])] + if not gate_only: + cur_up_bs = [vector.from_elements(T.vec(1, T.i32), [_new_us])] + else: + cur_up_bs = None + + if not use_async_copy: + store_x_tile_to_lds(_x_regs, lds_write) + + return ( + cur_a_tile, + cur_gate_w, + cur_up_w, + cur_a_scale, + cur_gate_bs, + cur_up_bs, + acc_gate, + acc_up, + ) + + # Pipeline (split ping/pong allocators) + rocdl.sched_barrier(0) + + k0 = k_base_idx + if use_async_copy: + prefetch_x_to_lds(k0, lds_x_pong) + else: + x_regs0 = load_x_tile(k0) + store_x_tile_to_lds(x_regs0, lds_x_pong) + rocdl.sched_barrier(0) + _k0_scale = k_base_idx // arith.constant(pack_K * 128, index=True) a_scale_pong, gate_bs_pong, up_bs_pong = prefetch_ab_scale_tile( - k0 // 2 + _k0_scale ) - store_x_tile_to_lds(x_regs0, lds_base_cur) - gpu.barrier() + _c_tile_m_idx = arith.constant(tile_m, index=True) + _tid_in_range = arith.cmpi(CmpIPredicate.ult, tx, _c_tile_m_idx) + _if_tid = scf.IfOp(_tid_in_range) + with ir.InsertionPoint(_if_tid.then_block): + _tid_row = bx_m + tx + _tid_val = buffer_ops.buffer_load( + sorted_rsrc, _tid_row, vec_width=1, dtype=T.i32 + ) + _tid_vec1 = vector.from_elements(T.vec(1, T.i32), [_tid_val]) + vector.store(_tid_vec1, lds_tid, [tx]) + scf.YieldOp([]) - # Loop-carried ping/pong state. 
- lds_base_pong = lds_base_cur - lds_base_ping = lds_base_nxt - gate_w_pong = gate_w0 - up_w_pong = up_w0 + acc_gate = [acc_init] * num_acc_n * m_repeat + acc_up = [acc_init] * num_acc_n * m_repeat if not gate_only else None + + _k1 = k_base_idx + arith.constant(tile_k, index=True) + rocdl.sched_barrier(0) + if use_async_copy: + prefetch_x_to_lds(_k1, lds_x_ping) + else: + _x_regs_prime = load_x_tile(_k1) + store_x_tile_to_lds(_x_regs_prime, lds_x_ping) - a0_prefetch_pong = None + _k0_b = k_base_idx // arith.constant(2, index=True) + gate_w0, up_w0 = load_b_tile(_k0_b) + # Prime the deep pipeline: DMA K=tile_k -> ping (1 tile ahead) + # rocdl.s_waitcnt(8) + gpu.barrier() + rocdl.sched_barrier(0) + a_tile_pong = prefetch_full_a_from_lds(lds_x_pong) - if os.environ.get("FLYDSL_STAGE1_EARLY_RETURN", "0") == "1": - return + rocdl.sched_barrier(0) + rocdl.s_waitcnt(6) - num_k_tiles_py = int(model_dim) // int(tile_k) + num_k_tiles_py = int(_k_dim) // int(tile_k) odd_k_tiles = (num_k_tiles_py % 2) == 1 tail_tiles = 1 if odd_k_tiles else 2 k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) if k_main2_py < 0: k_main2_py = 0 + gate_w_pong = gate_w0 + up_w_pong = up_w0 - _skip_compute = ( - os.environ.get("FLYDSL_STAGE1_SKIP_COMPUTE", "0") - == "1" - ) - if k_main2_py > 0: - for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): - k_iv = k_iv_py - next_k1 = k_iv + tile_k - x_regs_ping = load_x_tile(next_k1) - gate_w_ping, up_w_ping = load_b_tile(next_k1 // 2) - a_scale_ping, gate_bs_ping, up_bs_ping = ( - prefetch_ab_scale_tile(next_k1 // pack_K // 128) - ) + def _sched_hints_stage1_gate_up(): + """Stage1 hot-loop scheduler adapted from the gate/up gufusion pipeline. 
+ + The original hot loop doubles the B-side VMEM and MFMA streams: + - gate B load + up B load + - gate B-scale load + up B-scale load + - gate MFMA + up MFMA + + The scheduler API here is less expressive than the original + `__builtin_amdgcn_sched_group_barrier`, so we encode the same + idea with a compact heuristic: + - always double MFMA groups (`num_acc_n * 2`) + - use 2 VMEM groups only when the N tile is wide enough to + sustain the extra B-side traffic (`num_acc_n >= 4`) + - otherwise keep 1 VMEM group to avoid over-throttling the + smaller `tile_n=128` kernels + """ + # mfma_group = num_acc_n * 2 + # mfma_total = (k_unroll * 2) * m_repeat * mfma_group + # mfma_per_iter = 2 * mfma_group + # sche_iters = ( + # 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) + # ) + + # # Approximate the doubled B-side prefetch pressure. + # vmem_groups = 2 if int(num_acc_n) >= 4 else 1 + + # rocdl.sched_dsrd(2) + # rocdl.sched_mfma(2) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(1) + + # dswr_tail = num_x_loads + # if dswr_tail > sche_iters: + # dswr_tail = sche_iters + # dswr_start = sche_iters - dswr_tail + + # for sche_i in range_constexpr(sche_iters): + # rocdl.sched_vmem(vmem_groups) + # rocdl.sched_mfma(mfma_group) + # rocdl.sched_dsrd(1) + # rocdl.sched_mfma(mfma_group) + # if sche_i >= dswr_start - 1: + # rocdl.sched_dswr(1) + # rocdl.sched_barrier(0) + + if use_async_copy: + a_vmem_load = max(1, tile_m // 32) + mfma_group = a_vmem_load + rocdl.sched_vmem(a_vmem_load) - if _skip_compute: - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - gpu.barrier() - a0_prefetch_ping = None - next_k2 = k_iv + (tile_k * 2) - x_regs_pong = load_x_tile(next_k2) - gate_w_pong, up_w_pong = load_b_tile(next_k2 // 2) - a_scale_pong, gate_bs_pong, up_bs_pong = ( - prefetch_ab_scale_tile(next_k2 // pack_K // 128) - ) - store_x_tile_to_lds(x_regs_pong, lds_base_pong) - gpu.barrier() - a0_prefetch_pong = None - continue + 
rocdl.sched_mfma(mfma_group) - acc_gate, acc_up, _ = compute_f8f6f4_tile( + b_vmem_total = k_unroll * num_acc_n * 2 + vmem_count = b_vmem_total + 2 + a_vmem_load + + if tile_m == 16: + for i in range_constexpr(2): + rocdl.sched_dsrd(1) + rocdl.sched_mfma(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + for i in range_constexpr(9): + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + else: + for i in range_constexpr(a_vmem_load * 4): + rocdl.sched_dsrd(1) + rocdl.sched_mfma(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(mfma_group) + + if tile_m == 32: + for i in range_constexpr(vmem_count - a_vmem_load * 4): + rocdl.sched_vmem(1) + rocdl.sched_mfma(mfma_group) + elif tile_m == 64: + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(2) + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + rocdl.sched_vmem(1) + rocdl.sched_mfma(2) + + rocdl.sched_barrier(0) + + rocdl.sched_barrier(0) + + if k_main2_py > 0: + for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): + next_k_load_1 = k_iv_py + tile_k + next_k_load_2 = k_iv_py + tile_k * 2 + next_k_dma_1 = k_iv_py + tile_k * 2 + next_k_dma_2 = k_iv_py + tile_k * 3 + + # Half 1: read ping (DMA'd prev half), DMA->pong, MFMA(pong) + ( + a_tile_ping, + gate_w_ping, + up_w_ping, + a_scale_ping, + gate_bs_ping, + up_bs_ping, acc_gate, acc_up, + ) = _interleaved_half( + lds_x_ping, + lds_x_pong, + next_k_dma_1, + next_k_load_1, + a_tile_pong, gate_w_pong, up_w_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - gate_b_scale=gate_bs_pong, - up_b_scale=up_bs_pong, - ) - a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - gpu.barrier() - - a0_prefetch_ping = None - - next_k2 = k_iv + (tile_k * 2) - x_regs_pong = load_x_tile(next_k2) - gate_w_pong, up_w_pong = load_b_tile(next_k2 // 2) - a_scale_pong, gate_bs_pong, up_bs_pong = ( - prefetch_ab_scale_tile(next_k2 // pack_K // 128) + a_scale_pong, + gate_bs_pong, + up_bs_pong, + acc_gate, + acc_up, ) - acc_gate, 
acc_up, _ = compute_f8f6f4_tile( + # Half 2: read pong (DMA'd Half 1), DMA->ping, MFMA(ping) + ( + a_tile_pong, + gate_w_pong, + up_w_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, acc_gate, acc_up, + ) = _interleaved_half( + lds_x_pong, + lds_x_ping, + next_k_dma_2, + next_k_load_2, + a_tile_ping, gate_w_ping, up_w_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - gate_b_scale=gate_bs_ping, - up_b_scale=up_bs_ping, + a_scale_ping, + gate_bs_ping, + up_bs_ping, + acc_gate, + acc_up, ) - a0_prefetch_ping = None - store_x_tile_to_lds(x_regs_pong, lds_base_pong) - gpu.barrier() - a0_prefetch_pong = None + # _wave_mod2_b = wave_id % arith.constant(2, index=True) + # _wave_odd = arith.cmpi( + # CmpIPredicate.eq, _wave_mod2_b, arith.constant(1, index=True) + # ) + # _if_wave_odd = scf.IfOp(_wave_odd) + # with ir.InsertionPoint(_if_wave_odd.then_block): + # # gpu.barrier() + # _barrier() + # scf.YieldOp([]) if odd_k_tiles: - acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( + acc_gate, acc_up, epilogue_pf = compute_tile( acc_gate, acc_up, gate_w_pong, up_w_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - gate_b_scale=gate_bs_pong, - up_b_scale=up_bs_pong, + a_tile_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, prefetch_epilogue=True, ) else: - k_tail1 = k_in - tile_k - x_regs_ping = load_x_tile(k_tail1) - gate_w_ping, up_w_ping = load_b_tile(k_tail1 // 2) + _k_tail_rel = arith.constant(_k_dim - tile_k, index=True) + k_tail1 = k_base_idx + _k_tail_rel + if use_async_copy: + prefetch_x_to_lds(k_tail1, lds_x_ping) + else: + x_regs_ping = load_x_tile(k_tail1) + gate_w_ping, up_w_ping = load_b_tile( + k_tail1 // arith.constant(2, index=True) + ) a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( - k_tail1 // pack_K // 128 + k_tail1 // arith.constant(pack_K * 128, index=True) ) - - acc_gate, acc_up, _ = compute_f8f6f4_tile( + acc_gate, acc_up, _ = compute_tile( acc_gate, acc_up, 
gate_w_pong, up_w_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - gate_b_scale=gate_bs_pong, - up_b_scale=up_bs_pong, + a_tile_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, ) - a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - gpu.barrier() - - a0_prefetch_ping = None - - acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( + if not use_async_copy: + store_x_tile_to_lds(x_regs_ping, lds_x_ping) + rocdl.s_waitcnt(0) + _barrier() + a_tile_ping = prefetch_full_a_from_lds(lds_x_ping) + acc_gate, acc_up, epilogue_pf = compute_tile( acc_gate, acc_up, gate_w_ping, up_w_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - gate_b_scale=gate_bs_ping, - up_b_scale=up_bs_ping, + a_tile_ping, + a_scale_ping, + gate_bs_ping, + up_bs_ping, prefetch_epilogue=True, ) - # Store epilogue to out[t, slot, inter] - topk_i32_v = topk_i32 - inter_i32_v = fx.Int32(inter_dim) - mask24_i32 = fx.Int32(0xFFFFFF) + # silu(gate) * up in f32 before epilogue + # silu(x) = x * sigmoid(x); use HW fast path: exp2, rcp + def _silu_mul_vec4(gate_v4, up_v4): + """Element-wise silu(gate) * up on vec4_f32.""" + result_elems = [] + for ei in range_constexpr(4): + g = vector.extract( + gate_v4, static_position=[ei], dynamic_position=[] + ) + u = vector.extract( + up_v4, static_position=[ei], dynamic_position=[] + ) + neg_log2e = arith.constant(-1.4426950408889634, type=f32) + t = g * neg_log2e + emu = llvm.call_intrinsic( + f32, "llvm.amdgcn.exp2.f32", [t], [], [] + ) + one = arith.constant(1.0, type=f32) + den = one + emu + sig = llvm.call_intrinsic( + f32, "llvm.amdgcn.rcp.f32", [den], [], [] + ) + result_elems.append(g * sig * u) + return vector.from_elements(vec4_f32, result_elems) + + if not _is_splitk: + acc = [None] * (int(num_acc_n) * int(m_repeat)) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + acc[_aidx] = _silu_mul_vec4(acc_gate[_aidx], 
acc_up[_aidx]) + + # ---- Epilogue: CShuffle + direct store (accumulate=False) ---- + # Output: out[(t*topk+s) * inter_dim + col] = silu(gate) * up + # For split-K: skip silu, output gate/up separately with atomic add + tw_pf = None + if epilogue_pf is not None: + _, tw_pf, _ = epilogue_pf - # Epilogue hoists to keep IR + Python build time small: - col_i32_list = [] - for ni in range_constexpr(num_acc_n): - col_i32_list.append(arith.index_cast(T.i32, col_g_list[ni])) + mask24_i32 = arith.constant(0xFFFFFF) + topk_i32_v = topk_i32 + tokens_i32_v = tokens_i32 - _lane_div_16_mul4 = lane_div_16 * arith.index(4) - inter_i32_local = inter_i32_v + from flydsl._mlir.dialects import fly as _fly - # Optional: CK-style CShuffle epilogue for better global store coalescing. - # Uses EVec=4 (buffer store "x4" of fp16 elements). - _use_cshuffle_epilog = (out_dtype == "fp8") or bool( - use_cshuffle_epilog + _llvm_ptr_ty = ir.Type.parse("!llvm.ptr") + out_base_ptr = _fly.extract_aligned_pointer_as_index( + _llvm_ptr_ty, arg_out ) + out_base_i64 = llvm.ptrtoint(T.i64, out_base_ptr) + out_base_idx = arith.index_cast(ir.IndexType.get(), out_base_i64) - _mask_even_i32 = fx.Int32(0xFFFFFFFE) + if lds_out is None: + raise RuntimeError("CShuffle epilogue requires lds_out") - if _use_cshuffle_epilog: - if lds_out is None: - raise RuntimeError( - "CShuffle epilogue enabled but lds_out is not allocated/aliased." - ) + _apply_weight = doweight_stage1 and not _is_splitk - def write_row_to_lds( - *, - mi: int, - ii: int, - row_in_tile, - row, - row_base_lds, - col_base_local, - num_acc_n: int, - lds_out, - ): - # `row` is the sorted-row index (bx_m + row_in_tile). 
- row_i32 = arith.index_cast(T.i32, row) - row_valid0 = arith.cmpi( - arith.CmpIPredicate.ult, row_i32, max_token_id_i32 - ) - row_safe = arith.select( - row_valid0, row, fx.Index(0) - ) - fused2 = buffer_ops.buffer_load( - sorted_rsrc, row_safe, vec_width=1, dtype=T.i32 - ) - _t2 = fused2 & mask24_i32 - - # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: + def write_row_to_lds( + *, + mi: int, + ii: int, + row_in_tile, + row, + row_base_lds, + col_base_local, + num_acc_n: int, + lds_out, + ): + if _apply_weight: + tw_idx = (mi * 4) + ii + if tw_pf is not None: + tw = tw_pf[tw_idx] + else: tw = buffer_ops.buffer_load( - sorted_w_rsrc, row_safe, vec_width=1, dtype=T.f32 + sorted_w_rsrc, row, vec_width=1, dtype=f32 ) + for ni in range_constexpr(num_acc_n): + col_local = col_base_local + (ni * 16) + acc_idx = mi * num_acc_n + ni + v = vector.extract( + acc[acc_idx], static_position=[ii], dynamic_position=[] + ) + if _apply_weight: + v = v * tw + if _need_quant: + lds_idx = row_base_lds + col_local + vec1_f32 = T.vec(1, f32) + v1 = vector.from_elements(vec1_f32, [v]) + vector.store(v1, lds_out, [lds_idx], alignment=4) + else: + v_out = arith.trunc_f(out_elem(), v) + lds_idx = row_base_lds + col_local + vec1_out = T.vec(1, out_elem()) + v1 = vector.from_elements(vec1_out, [v_out]) + vector.store(v1, lds_out, [lds_idx], alignment=2) + + _out_row_stride = ( + inter_dim * 2 * out_elem_bytes + if _is_splitk + else (inter_dim // 2 if _need_quant else inter_dim * out_elem_bytes) + ) - for ni in range_constexpr(num_acc_n): - col_local = col_base_local + (ni * 16) + def precompute_row(*, row_local, row): + fused2 = memref.load(lds_tid, [row_local]) + row_i32 = arith.index_cast(T.i32, row) + row_valid0 = arith.cmpi(CmpIPredicate.ult, row_i32, num_valid_i32) + t = fused2 & mask24_i32 + s = fused2 >> 24 + t_ok = arith.cmpi(CmpIPredicate.ult, t, tokens_i32_v) + s_ok = arith.cmpi(CmpIPredicate.ult, s, topk_i32_v) + row_valid = 
arith.andi(row_valid0, arith.andi(t_ok, s_ok)) + t_idx = arith.index_cast(ir.IndexType.get(), t) + s_idx = arith.index_cast(ir.IndexType.get(), s) + ts_idx = t_idx * arith.constant(topk, index=True) + s_idx + row_byte_base = out_base_idx + ts_idx * arith.constant( + _out_row_stride, index=True + ) + return ((fused2, row_byte_base), row_valid) + + def _idx_to_llvm_ptr(idx_val, addr_space=1): + idx_v = idx_val._value if hasattr(idx_val, "_value") else idx_val + i64_v = arith.index_cast(T.i64, idx_v) + i64_raw = i64_v._value if hasattr(i64_v, "_value") else i64_v + ptr_ty = ir.Type.parse(f"!llvm.ptr<{addr_space}>") + return llvm.inttoptr(ptr_ty, i64_raw) + + _e_vec = _e_vec_s1 + _e_vec_sk = 2 + _cshuffle_nlane = min(32, tile_n // _e_vec) + _cshuffle_nlane_sk = min(32, tile_n // _e_vec_sk) + _num_threads_per_quant_blk = _num_threads_per_quant_blk_s1 + + _c0_i32 = arith.constant(0, type=T.i32) + _c1_i32 = arith.constant(1, type=T.i32) + _c2_i32 = arith.constant(2, type=T.i32) + _c3_i32 = arith.constant(3, type=T.i32) + _c4_i32 = arith.constant(4, type=T.i32) + _c5_i32 = arith.constant(5, type=T.i32) + _c7_i32 = arith.constant(7, type=T.i32) + _c15_i32 = arith.constant(15, type=T.i32) + _c21_i32 = arith.constant(21, type=T.i32) + _c23_i32 = arith.constant(23, type=T.i32) + _c28_i32 = arith.constant(28, type=T.i32) + _c31_i32 = arith.constant(31, type=T.i32) + _c32_i32 = arith.constant(32, type=T.i32) + _c64_i32 = arith.constant(64, type=T.i32) + _c126_i32 = arith.constant(126, type=T.i32) + _c127_i32 = arith.constant(127, type=T.i32) + _c254_i32 = arith.constant(254, type=T.i32) + _c256_i32 = arith.constant(256, type=T.i32) + _c0xFF_i32 = arith.constant(0xFF, type=T.i32) + _c0x200000_i32 = arith.constant(0x200000, type=T.i32) + _c0xFF800000_i32 = arith.constant(0xFF800000, type=T.i32) + _c0x400000_i32 = arith.constant(0x400000, type=T.i32) + _c0x7FFFFF_i32 = arith.constant(0x7FFFFF, type=T.i32) + _c0x80000000_i32 = arith.constant(0x80000000, type=T.i32) + _c0_f32 = 
arith.constant(0.0, type=T.f32) + + def _f32_to_e2m1(qx_f32): + """Convert a scaled f32 value to fp4 (e2m1) 4-bit integer.""" + qx = qx_f32.bitcast(T.i32) + s = qx & _c0x80000000_i32 + e = (qx >> _c23_i32) & _c0xFF_i32 + m = qx & _c0x7FFFFF_i32 + adj_exp = arith.maxsi(_c126_i32 - e, _c0_i32) + m_denorm = (_c0x400000_i32 | (m >> _c1_i32)) >> adj_exp + is_denorm = arith.cmpi(CmpIPredicate.ult, e, _c127_i32) + m = arith.select(is_denorm, m_denorm, m) + e = arith.maxsi(e - _c126_i32, _c0_i32) + combined = (e << _c2_i32) | (m >> _c21_i32) + rounded = (combined + _c1_i32) >> _c1_i32 + e2m1 = arith.minui(rounded, _c7_i32) + return (s >> _c28_i32) | e2m1 + + if _need_sort: + _n32_sort = _sorted_scale_cols_i32 * _c32_i32 + + # Mutable slot for split-K N-offset (gate=0, up=inter_dim) + _sk_n_offset = [0] - acc_idx = mi * num_acc_n + ni - vg = vector.extract( - acc_gate[acc_idx], - static_position=[ii], - dynamic_position=[], - ) - vu = vector.extract( - acc_up[acc_idx], - static_position=[ii], - dynamic_position=[], + def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + fused, row_byte_base = row_ctx + if _need_quant and not _is_splitk: + frag_vals = [] + for i in range_constexpr(_e_vec): + frag_vals.append( + vector.extract( + frag, static_position=[i], dynamic_position=[] + ) ) - if enable_bias: - gate_bias_list, up_bias_list = epilogue_pf - vg = vg + gate_bias_list[ni] - vu = vu + up_bias_list[ni] - - if act == "swiglu": - y = swiglu(vg, vu) - else: - y = silu(vg) * vu - - if doweight_stage1: - y = y * tw - - lds_idx = row_base_lds + col_local - v1 = vector.from_elements(vec1_f32, [y]) - vector.store(v1, lds_out, [lds_idx], alignment=1) + local_max = _c0_f32 + for i in range_constexpr(_e_vec): + abs_v = llvm.call_intrinsic( + f32, "llvm.fabs.f32", [frag_vals[i]], [], [] + ) + local_max = arith.maximumf(local_max, abs_v) + + for _si in range_constexpr(_num_shuffle_steps_s1): + off = arith.constant(_shuffle_dists_s1[_si], type=T.i32) + peer = 
local_max.shuffle_xor(off, _c64_i32) + local_max = arith.maximumf(local_max, peer) + + max_i32 = local_max.bitcast(T.i32) + max_rounded = (max_i32 + _c0x200000_i32) & _c0xFF800000_i32 + exp_field = max_rounded >> _c23_i32 + e8m0_biased = arith.maxsi(exp_field - _c2_i32, _c0_i32) + + quant_exp = _c254_i32 - e8m0_biased + quant_scale = (quant_exp << _c23_i32).bitcast(T.f32) + + fp4_vals = [] + for i in range_constexpr(_e_vec): + scaled_v = frag_vals[i] * quant_scale + fp4_vals.append(_f32_to_e2m1(scaled_v)) + + packed_i32 = fp4_vals[0] | (fp4_vals[1] << _c4_i32) + for k in range_constexpr(1, _e_vec // 2): + byte_k = fp4_vals[2 * k] | (fp4_vals[2 * k + 1] << _c4_i32) + packed_i32 = packed_i32 | ( + byte_k << arith.constant(k * 8, type=T.i32) + ) - def precompute_row(*, row_local, row): - row_i32 = arith.index_cast(T.i32, row) - row_valid0 = arith.cmpi( - arith.CmpIPredicate.ult, row_i32, max_token_id_i32 + ptr_addr_idx = row_byte_base + col_g0 / arith.constant( + 2, index=True ) - row_safe = arith.select( - row_valid0, row, fx.Index(0) - ) - fused2 = buffer_ops.buffer_load( - sorted_rsrc, row_safe, vec_width=1, dtype=T.i32 - ) - t2 = fused2 & mask24_i32 - s2 = fused2 >> 24 - t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32) - s_valid = arith.cmpi(arith.CmpIPredicate.ult, s2, topk_i32_v) - ts_valid = row_valid0 & (t_valid & s_valid) - t2_safe = arith.select(ts_valid, t2, fx.Int32(0)) - s2_safe = arith.select(ts_valid, s2, fx.Int32(0)) - idx0 = (t2_safe * topk_i32_v + s2_safe) * inter_i32_local - return idx0, ts_valid - - def store_pair( - *, row_local, row, row_ctx, col_pair0, col_g0, frag - ): - idx0 = row_ctx - col_i32 = arith.index_cast(T.i32, col_g0) - idx_out = idx0 + col_i32 - if out_dtype == "fp8": - frag = vector.bitcast(vec4_f32, frag) - frag0 = vector.extract( - frag, static_position=[0], dynamic_position=[] + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + _pack_bytes = _e_vec // 2 + if _pack_bytes == 1: + store_val = arith.TruncIOp(T.i8, 
packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val ) - frag1 = vector.extract( - frag, static_position=[1], dynamic_position=[] + llvm.StoreOp( + store_raw, out_ptr_v, alignment=1, nontemporal=True ) - frag2 = vector.extract( - frag, static_position=[2], dynamic_position=[] + elif _pack_bytes == 2: + store_val = arith.TruncIOp(T.i16, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val ) - frag3 = vector.extract( - frag, static_position=[3], dynamic_position=[] + llvm.StoreOp( + store_raw, out_ptr_v, alignment=2, nontemporal=True ) - - out_fp8 = fx.Int32(0) - out_fp8 = rocdl.cvt_pk_fp8_f32( - src_a=arith.unwrap(frag0), - src_b=arith.unwrap(frag1), - old=arith.unwrap(out_fp8), - word_sel=0, - res=T.i32, + else: + packed_raw = ( + packed_i32._value + if hasattr(packed_i32, "_value") + else packed_i32 ) - out_fp8 = rocdl.cvt_pk_fp8_f32( - src_a=arith.unwrap(frag2), - src_b=arith.unwrap(frag3), - old=arith.unwrap(out_fp8), - word_sel=1, - res=T.i32, + llvm.StoreOp( + packed_raw, out_ptr_v, alignment=4, nontemporal=True ) - buffer_ops.buffer_store(out_fp8, out_rsrc, idx_out // 4) - else: - out_vec_ty = T.vec(4, _out_elem_type()) - out_vals = [] - for fi in range_constexpr(4): - frag_i = vector.extract( - frag, static_position=[fi], dynamic_position=[] + + if _need_sort: + col_g0_i32 = arith.index_cast(T.i32, col_g0) + is_scale_writer = arith.cmpi( + CmpIPredicate.eq, col_g0_i32 & _c31_i32, _c0_i32 + ) + _if_scale = scf.IfOp(is_scale_writer) + with ir.InsertionPoint(_if_scale.then_block): + row_i32_s = arith.index_cast(T.i32, row) + col_s_i32 = col_g0_i32 >> _c5_i32 + d0 = row_i32_s >> _c5_i32 + d1 = (row_i32_s >> _c4_i32) & _c1_i32 + d2 = row_i32_s & _c15_i32 + d3 = col_s_i32 >> _c3_i32 + d4 = (col_s_i32 >> _c2_i32) & _c1_i32 + d5 = col_s_i32 & _c3_i32 + byte_off = ( + d0 * _n32_sort + + d3 * _c256_i32 + + d5 * _c64_i32 + + d2 * _c4_i32 + + d4 * _c2_i32 + + d1 ) - 
out_vals.append( - arith.trunc_f(_out_elem_type(), frag_i) + e8m0_i8 = arith.TruncIOp(T.i8, e8m0_biased) + buffer_ops.buffer_store( + e8m0_i8, + sorted_scale_rsrc, + byte_off, + offset_is_bytes=True, ) - out_vec = vector.from_elements(out_vec_ty, out_vals) - buffer_ops.buffer_store(out_vec, out_rsrc, idx_out) + scf.YieldOp([]) + elif _is_splitk: + col_idx = col_g0 + arith.constant(_sk_n_offset[0], index=True) + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + out_ptr_v, + frag_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=_e_vec_sk * out_elem_bytes, + ) + else: + col_idx = col_g0 + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.StoreOp( + frag_v, + out_ptr_v, + alignment=_e_vec * out_elem_bytes, + nontemporal=True, + ) + + _frag_elem = ( + ir.F32Type.get() + if _need_quant + else (ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get()) + ) + + if gate_only: + # gate_only: single pass, by_n covers full [0, 2*inter_dim) + _eff_e_vec = _e_vec_sk + acc = acc_gate + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_eff_e_vec, + cshuffle_nlane=_cshuffle_nlane_sk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + ) + elif _is_splitk: + # Two-pass epilogue: gate then up, 
each with atomic add + _eff_e_vec = _e_vec_sk + + # Pass 1: gate + acc = acc_gate + _sk_n_offset[0] = 0 + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_eff_e_vec, + cshuffle_nlane=_cshuffle_nlane_sk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + ) + + gpu.barrier() - mfma_epilog( - use_cshuffle=True, + # Pass 2: up + acc = acc_up + _sk_n_offset[0] = inter_dim + c_shuffle_epilog( arith=arith, vector=vector, gpu=gpu, @@ -1332,7 +2145,9 @@ def store_pair( range_constexpr=range_constexpr, tile_m=tile_m, tile_n=tile_n, - e_vec=4, + e_vec=_eff_e_vec, + cshuffle_nlane=_cshuffle_nlane_sk, + block_size=total_threads, m_repeat=m_repeat, num_acc_n=num_acc_n, tx=tx, @@ -1342,85 +2157,51 @@ def store_pair( by_n=by_n, n_tile_base=n_tile_base, lds_out=lds_out, - frag_elem_type=T.f32, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + ) + else: + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_e_vec, + cshuffle_nlane=_cshuffle_nlane, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, ) - return - - def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): - # `row` is the sorted-row index (bx_m + row_in_tile). 
- row_i32 = arith.index_cast(T.i32, row) - row_valid0 = arith.cmpi( - arith.CmpIPredicate.ult, row_i32, max_token_id_i32 - ) - row_safe = arith.select( - row_valid0, row, fx.Index(0) - ) - fused2 = buffer_ops.buffer_load( - sorted_rsrc, row_safe, vec_width=1, dtype=T.i32 - ) - t2 = fused2 & mask24_i32 - s2 = fused2 >> 24 - t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32) - s_valid = arith.cmpi(arith.CmpIPredicate.ult, s2, topk_i32_v) - ts_valid = row_valid0 & (t_valid & s_valid) - t2_safe = arith.select(ts_valid, t2, fx.Int32(0)) - s2_safe = arith.select(ts_valid, s2, fx.Int32(0)) - - # out linear index base = ((t*topk + s)*inter_dim) (invariant across ni) - idx0 = (t2_safe * topk_i32_v + s2_safe) * inter_i32_local - - # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: - tw = buffer_ops.buffer_load( - sorted_w_rsrc, row_safe, vec_width=1, dtype=T.f32 - ) - - _if_valid = scf.IfOp(ts_valid) - with _if_then(_if_valid): - for ni in range_constexpr(num_acc_n): - col_i32 = col_i32_list[ni] - acc_idx = mi * num_acc_n + ni - vg = vector.extract( - acc_gate[acc_idx], - static_position=[ii], - dynamic_position=[], - ) - vu = vector.extract( - acc_up[acc_idx], - static_position=[ii], - dynamic_position=[], - ) - if enable_bias: - gate_bias_list, up_bias_list = epilogue_pf - vg = vg + gate_bias_list[ni] - vu = vu + up_bias_list[ni] - - if act == "swiglu": - y = swiglu(vg, vu) - else: - y = silu(vg) * vu - - if doweight_stage1: - y = y * tw - y = arith.trunc_f(_out_elem_type(), y) - idx_out = idx0 + col_i32 - buffer_ops.buffer_store(y, out_rsrc, idx_out) + _if_blk = scf.IfOp(blk_valid) + with ir.InsertionPoint(_if_blk.then_block): + _ifexpert_of = scf.IfOp(exp_valid) + with ir.InsertionPoint(_ifexpert_of.then_block): + _moe_gemm1_body() + scf.YieldOp([]) + scf.YieldOp([]) - mfma_epilog( - use_cshuffle=False, - arith=arith, - range_constexpr=range_constexpr, - m_repeat=m_repeat, - lane_div_16=lane_div_16, - bx_m=bx_m, - 
body_row=_stage1_store_row, - ) + gpu.barrier() + scf.YieldOp([]) + _for_ip.__exit__(None, None, None) - # -- Host launcher (flyc.jit + .launch) -------------------------------- + # -- Host launcher -- _cache_tag = ( module_name, a_dtype, @@ -1435,6 +2216,13 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): model_dim_pad, inter_dim_pad, use_cshuffle_epilog, + persist_m, + fuse_fp4_quant, + fuse_sort_scale, + use_async_copy, + waves_per_eu, + k_batch, + gate_only, ) @flyc.jit @@ -1456,14 +2244,28 @@ def launch_mixed_moe_gemm1( stream: fx.Stream, ): _ = _cache_tag - allocator.finalized = False + allocator_pong.finalized = False + allocator_ping.finalized = False ctx = CompilationContext.get_current() with ir.InsertionPoint(ctx.gpu_module_body): - allocator.finalize() + allocator_pong.finalize() + allocator_ping.finalize() - inter_in = arith.index_cast(T.index, i32_inter_in) - gx = inter_in // fx.Index(tile_n) - gy = arith.index_cast(T.index, i32_size_expert_ids_in) + inter_in = arith.index_cast(ir.IndexType.get(), i32_inter_in.ir_value()) + if gate_only: + gx = inter_in / arith.constant(tile_n, index=True) + else: + gx = ( + inter_in + / arith.constant(2, index=True) + / arith.constant(tile_n, index=True) + ) + _c_pm_l = arith.constant(persist_m, index=True) + gy = ( + arith.index_cast(ir.IndexType.get(), i32_size_expert_ids_in.ir_value()) + + _c_pm_l + - arith.constant(1, index=True) + ) / _c_pm_l moe_gemm1( arg_out, @@ -1480,16 +2282,12 @@ def launch_mixed_moe_gemm1( i32_inter_in, i32_k_in, i32_size_expert_ids_in, - ).launch( - grid=(gx, gy, 1), - block=(256, 1, 1), - stream=stream, - ) + ).launch(grid=(gx, gy, k_batch), block=(total_threads, 1, 1), stream=stream) return launch_mixed_moe_gemm1 -@functools.lru_cache(maxsize=1024) +@functools.lru_cache(maxsize=None) def compile_mixed_moe_gemm2( *, model_dim: int, @@ -1512,9 +2310,15 @@ def compile_mixed_moe_gemm2( model_dim_pad: int = 0, inter_dim_pad: int = 0, persist_m: int = 4, + sort_block_m: int = 
0, ): """Compile stage2 kernel (`moe_gemm2`) and return the compiled executable. + persist_m: + - > 0: legacy mode -- each CTA processes exactly persist_m consecutive M tiles. + - <= 0: **persistent mode** -- grid_y = cu_num (auto-detected), each CTA + round-robins over M tiles with stride cu_num. + a_dtype: - "fp8": A2 is fp8 - "fp16": A2 is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) @@ -1534,9 +2338,19 @@ def compile_mixed_moe_gemm2( `use_cshuffle_epilog` controls whether we use the LDS CShuffle epilogue before global atomics (recommended for performance). + + `sort_block_m` is the block_size used by moe_sorting / stage1. When 0 (default), + assumed equal to `tile_m`. When set, stage2 can use a different tile_m from + sorting/stage1. Requires sort_block_m % tile_m == 0. """ + _sort_block_m = tile_m if sort_block_m <= 0 else sort_block_m + if _sort_block_m != tile_m and _sort_block_m % tile_m != 0: + raise ValueError( + f"sort_block_m ({_sort_block_m}) must be a multiple of tile_m ({tile_m})" + ) + gpu_arch = get_hip_arch() - allocator = SmemAllocator(None, arch=gpu_arch) + allocator = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") if a_dtype not in ("fp8", "fp16", "int8", "fp4"): raise ValueError( @@ -1554,9 +2368,13 @@ def compile_mixed_moe_gemm2( is_f4_a = a_dtype == "fp4" is_f4_b = b_dtype == "fp4" - pack_M = 2 - pack_N = 2 - pack_K = 2 + _scale_pack_m = 2 # physical mn_pack in preshuffle microscale layout + _scale_pack_n = 2 + _scale_pack_k = 2 # physical k_pack in preshuffle scale layout + pack_M = min(_scale_pack_m, tile_m // 16) + pack_N = min(_scale_pack_n, tile_n // 64) + _k_unroll_raw = (int(tile_k) * (2 if a_dtype == "fp16" else 1)) // 128 + pack_K = min(_scale_pack_k, _k_unroll_raw) elem_bytes = 1 @@ -1568,6 +2386,22 @@ def compile_mixed_moe_gemm2( cbsz = 0 if is_f8_a else 4 blgp = 4 + # ---- Static B preshuffle strides (compile-time) ---- + # All values below are Python ints computable at kernel-compile time. 
+ # Using them in an explicit multiply-add replaces the fly dialect's + # dynamic ``crd2idx`` path which emits Barrett reduction for the + # non-power-of-2 ``n0 = experts*model_dim//16`` shape. + _b_kpack_bytes_s = 8 if (b_dtype == "int4") else 16 + _b_kpack_elems_s = _b_kpack_bytes_s // b_elem_bytes + _b_c_k_s = inter_dim // _scale_pack_k + _b_c_k0_s = (_b_c_k_s * b_elem_bytes) // 64 + _b_stride_nlane = _b_kpack_elems_s # 16 + _b_stride_klane = 16 * _b_stride_nlane # 256 + _b_stride_k0 = 4 * _b_stride_klane # 1024 + _b_stride_n0 = _b_c_k0_s * _b_stride_k0 # c_k0 * 1024 + assert model_dim % 16 == 0, "model_dim must be divisible by 16" + _expert_b_stride = (model_dim // 16) * _b_stride_n0 + # K64-byte micro-step: always 64 bytes per `ku`. For fp16, this is 32 elements (2xK16 MFMA). if (tile_k_bytes % 64) != 0: raise ValueError( @@ -1587,6 +2421,32 @@ def compile_mixed_moe_gemm2( "compile_moe_gemm2(accumulate=False) only supports out_dtype in {'f16','bf16'}" ) is_int4 = b_dtype == "int4" + # INT4 here means W4A8: A2 is int8, W is packed int4 and unpacked to int8 in-kernel. + is_int8 = False + + mfma_i32_k32 = None + if is_int8: + mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( + rocdl, "mfma_i32_16x16x32_i8", None + ) + if mfma_i32_k32 is None: + raise AttributeError( + "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " + "(or `rocdl.mfma_i32_16x16x32_i8`)." 
+ ) + + def _x_elem_type(): + if is_f4_b: + return T.f8 if is_f8_a else T.i8 + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + + def _w_elem_type(): + if is_f4_b: + return T.i8 + return T.f16 if is_f16_b else (T.i8 if is_int8 else T.f8) + + def _scale_elem_type(): + return T.i32 total_threads = 256 bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes) @@ -1597,13 +2457,15 @@ def compile_mixed_moe_gemm2( ) bytes_per_thread_x = bytes_x_per_tile // total_threads - pad_k = 0 + _use_lds128 = os.environ.get("FLIR_CK_LDS128", "1") in ( + "1", + "true", + "True", + "YES", + "yes", + ) + pad_k = 0 if _use_lds128 else 8 lds_stride = tile_k + pad_k - # gfx950+ has buffer_atomic_pk_add_bf16; gfx942 uses global atomics via raw pointer. - _has_buffer_atomic_bf16 = str(gpu_arch).startswith(("gfx95", "gfx12")) - _needs_global_atomic_bf16 = out_is_bf16 and not _has_buffer_atomic_bf16 - if out_is_bf16 and not supports_bf16_global_atomics(gpu_arch): - raise ValueError(f"out_dtype='bf16' requires bf16 global atomics, got arch={gpu_arch!r}") if out_is_f32: # Match origin/dev_a16w4: f32 output uses scalar atomics and does NOT use the CShuffle epilogue. @@ -1616,7 +2478,7 @@ def compile_mixed_moe_gemm2( ) else: if use_cshuffle_epilog is None: - _use_cshuffle_epilog = os.environ.get("FLYDSL_MOE_STAGE2_CSHUFFLE", "1") in ( + _use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE2_CSHUFFLE", "1") in ( "1", "true", "True", @@ -1627,7 +2489,7 @@ def compile_mixed_moe_gemm2( _use_cshuffle_epilog = bool(use_cshuffle_epilog) if not _use_cshuffle_epilog: raise ValueError( - "stage2 f16 output currently requires CShuffle epilogue (FLYDSL_MOE_STAGE2_CSHUFFLE=1)." + "stage2 f16 output currently requires CShuffle epilogue (FLIR_MOE_STAGE2_CSHUFFLE=1)." ) # NOTE: Keep this as a callable so we don't require an MLIR Context at Python-time. 
@@ -1638,13 +2500,26 @@ def out_elem(): # IMPORTANT: include tiling in the module name to avoid accidentally reusing a compiled # binary for a different (tile_m, tile_n, tile_k) configuration. # See stage1 note: include ABI tag to prevent binary reuse across signature changes. - # IMPORTANT: module name participates in FlyDSL's compile cache key. + # IMPORTANT: module name participates in the compiler cache key. # Dynamic-shape variant: safe to reuse across (tokens/sorted_size/size_expert_ids) at runtime. # Keep a distinct ABI tag so the compile cache never mixes with historical signatures. + _persistent = persist_m <= 0 + if _persistent: + try: + from aiter.jit.utils.chip_info import get_cu_num + except ImportError: + def get_cu_num(): + return 304 + + _cu_num = get_cu_num() + else: + _cu_num = 0 + _sbm_tag = "" if _sort_block_m == tile_m else f"_sbm{_sort_block_m}" + _pm_tag = f"_persist_cu{_cu_num}" if _persistent else f"_pm{persist_m}" module_name = ( f"mfma_moe2_a{a_dtype}_w{b_dtype}_{out_s}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" - f"_vscale_fix3" + f"_vscale_fix3{_pm_tag}{_sbm_tag}" ).replace("-", "_") # -- LDS sizing (pure Python; no MLIR Context needed) --------------------- # Reuse a single allocation for both: @@ -1659,7 +2534,7 @@ def out_elem(): lds_total_elems = lds_total_bytes if a_elem_bytes == 1 else (lds_total_bytes // 2) def x_lds_elem(): - return T.f16 if is_f16_a else T.f8 + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) lds_alloc_bytes = int(lds_total_elems) * int(a_elem_bytes) lds_alloc_offset = allocator._align(allocator.ptr, 16) @@ -1684,43 +2559,48 @@ def moe_gemm2( i32_k_in: fx.Int32, i32_size_expert_ids_in: fx.Int32, ): - tokens_in = arith.index_cast(T.index, i32_tokens_in) - n_in = arith.index_cast(T.index, i32_n_in) - k_in = arith.index_cast(T.index, i32_k_in) + + tokens_in = arith.index_cast(ir.IndexType.get(), i32_tokens_in.ir_value()) + n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) + k_in = 
arith.index_cast(ir.IndexType.get(), i32_k_in.ir_value()) size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) - x_elem = T.f16 if is_f16_a else T.f8 - vec4_f32 = T.vec(4, T.f32) - vec4_i32 = T.vec(4, T.i32) + x_elem = T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + # For int4, weights are stored as packed bytes (i8) and unpacked to i8 packs. + f32 = T.f32 + i32 = T.i32 + i64 = T.i64 + vec4_f32 = T.vec(4, f32) + vec4_i32 = T.vec(4, i32) vec16_elems = 16 if a_elem_bytes == 1 else 8 vec8_elems = 8 if a_elem_bytes == 1 else 4 vec4_elems = 4 if a_elem_bytes == 1 else 2 vec16_x = T.vec(vec16_elems, x_elem) - vec2_i64 = T.vec(2, T.i64) + vec2_i64 = T.vec(2, i64) - acc_init = arith.constant_vector(0.0, vec4_f32) + acc_init = ( + arith.constant_vector(0, vec4_i32) + if is_int8 + else arith.constant_vector(0.0, vec4_f32) + ) # A2 layout (flatten token-slot -> M; use i32 for fly.make_shape). - topk_idx = fx.Index(topk) + topk_idx = arith.constant(topk, index=True) m_in = tokens_in * topk_idx - # fly.make_shape requires i32/i64, not index - m_i32_v = arith.index_cast(T.i32, m_in) - k_i32_v = i32_k_in # B preshuffle layout: [experts*model_dim, inter_dim] - c_n_total = fx.Index(experts * model_dim) + c_n_total = arith.constant(experts * model_dim, index=True) kpack_bytes = 8 if is_int4 else 16 - b_layout = make_preshuffle_b_layout( - arith, - c_n=c_n_total, - c_k=k_in // pack_K, - kpack_bytes=kpack_bytes, - elem_bytes=b_elem_bytes, - ) - layout_b = b_layout.layout_b + from kernels.layout_utils import _div_pow2, _mod_pow2 + + def check_c_n_valid_gate(base_n): + return arith.cmpi(CmpIPredicate.ult, base_n, model_dim - model_dim_pad) + + def check_c_k_valid_gate(base_k): + return arith.cmpi(CmpIPredicate.ult, base_k, inter_dim - inter_dim_pad) # A&B's scale preshuffle layout # For fp4, k_in is already packed (inter_dim // a_elem_vec_pack), so we need original inter_dim - c_k_orig = fx.Index(inter_dim) + c_k_orig = arith.constant(inter_dim, index=True) 
layout_a_scale = make_preshuffle_scale_layout( arith, c_mn=m_in, c_k=c_k_orig ) @@ -1733,14 +2613,14 @@ def moe_gemm2( layout_lds = fx.make_layout(shape_lds, stride_lds) tx = gpu.thread_id("x") - # Align with Aiter launch mapping: - # - blockIdx.x -> N dimension (tile along model_dim) - # - blockIdx.y -> expert-block id / M dimension (tile along sorted M) - by = gpu.block_id("x") # tile along model_dim - bx_persist = gpu.block_id("y") # tile along sorted M + by = gpu.block_id("x") # tile along model_dim (N-dim) + bx_persist = gpu.block_id("y") # persistent WG index (M-dim) # XOR16 swizzle parameter (in bytes; constant, power-of-two in our configs). - k_blocks16 = fx.Index(tile_k_bytes // 16) + k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) + layout_tx_wave_lane = fx.make_layout((4, 64), stride=(64, 1)) + layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) + base_ptr = allocator.get_base() lds_x_ptr = SmemPtr( base_ptr, @@ -1772,14 +2652,15 @@ def moe_gemm2( # Buffer resources. # For dynamic memrefs, `max_size=False` cannot infer the logical size from the memref *type*, # so we should pass `num_records_bytes` explicitly for stable hardware OOB behavior. - c_topk = fx.Index(topk) + c_topk = arith.constant(topk, index=True) # X(A2): buffer size in bytes, accounting for FP4 packing (2 elements per byte). 
# fp8/int8: 1 byte per element -> bytes = tokens*topk * K # fp4: 2 elements per byte -> bytes = tokens*topk * K / 2 - c_a_pack = fx.Index(int(a_elem_vec_pack)) - c_elem_bytes = fx.Index(int(a_elem_bytes)) - x_nbytes_idx = ((tokens_in * c_topk) * k_in * c_elem_bytes) // c_a_pack + c_elem_bytes = arith.constant(int(a_elem_bytes), index=True) + x_nbytes_idx = _div_pow2( + (tokens_in * c_topk) * k_in * c_elem_bytes, int(a_elem_vec_pack) + ) x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource( arg_x, max_size=False, num_records_bytes=x_nbytes_i32 @@ -1790,14 +2671,14 @@ def moe_gemm2( # OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens. out_elem_bytes = 4 if out_is_f32 else 2 out_nbytes_idx = ( - tokens_in * n_in * fx.Index(out_elem_bytes) + tokens_in * n_in * arith.constant(out_elem_bytes, index=True) ) if not bool(accumulate): out_nbytes_idx = ( tokens_in * arith.index(topk) * n_in - * fx.Index(out_elem_bytes) + * arith.constant(out_elem_bytes, index=True) ) out_nbytes_i32 = arith.index_cast(T.i32, out_nbytes_idx) out_rsrc = buffer_ops.create_buffer_resource( @@ -1808,22 +2689,26 @@ def moe_gemm2( numids_rsrc = buffer_ops.create_buffer_resource( arg_num_valid_ids, max_size=False, - num_records_bytes=fx.Int32(4), + num_records_bytes=arith.constant(4, type=T.i32), ) num_valid_i32 = buffer_ops.buffer_load( - numids_rsrc, fx.Index(0), vec_width=1, dtype=T.i32 + numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=T.i32 ) - num_valid_idx = arith.index_cast(T.index, num_valid_i32) + # num_valid_ids is a scalar (same value for all lanes) loaded into + # VGPR. Promote to SGPR so downstream buffer resource descriptors + # that use it for num_records stay in SGPRs, eliminating the + # expensive waterfall loop the compiler would otherwise emit. 
+ num_valid_i32 = rocdl.ReadfirstlaneOp(T.i32, num_valid_i32).res + num_valid_idx = arith.index_cast(ir.IndexType.get(), num_valid_i32) # fp16 path ignores scales completely (implicit scale=1.0). if is_f16_a: sx_rsrc = None else: if is_f4_a: - # A2 microscale: packed i32 holding e8m0 bytes for [sorted_size, K/32]. - c32 = fx.Index(32) - kblk = k_in // c32 - # Total bytes = num_valid_ids * kblk. + # A2 microscale: e8m0 in sorted layout [sorted_size, K/32]. + # Caller must pre-scatter a2_scale via moe_mxfp4_sort. + kblk = _div_pow2(k_in, 32) sx_nbytes_idx = num_valid_idx * kblk sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) sx_rsrc = buffer_ops.create_buffer_resource( @@ -1831,7 +2716,7 @@ def moe_gemm2( ) else: # scale_x (A2 scale): [tokens*topk] f32 -> bytes = tokens*topk*4 - sx_nbytes_idx = (tokens_in * c_topk) * fx.Index(4) + sx_nbytes_idx = (tokens_in * c_topk) * arith.constant(4, index=True) sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 @@ -1842,20 +2727,19 @@ def moe_gemm2( else: # Weight microscale buffer (packed i32 holding e8m0 bytes). # Use an exact descriptor size so hardware OOB checking works. 
- c32 = fx.Index(32) - kblk_w = k_in // c32 # K/32 - mn_w = fx.Index(experts * model_dim) + kblk_w = _div_pow2(k_in, 32) # K/32 + mn_w = arith.constant(experts * model_dim, index=True) sw_nbytes_idx = mn_w * kblk_w # bytes (e8m0) sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) sw_rsrc = buffer_ops.create_buffer_resource( arg_scale_w, max_size=False, num_records_bytes=sw_nbytes_i32 ) - # sorted_token_ids / sorted_weights: [blocks*tile_m] (CK-style padded length) + # sorted_token_ids / sorted_weights: [blocks*tile_m] (padded length) sorted_nbytes_idx = ( size_expert_ids_in - * fx.Index(tile_m) - * fx.Index(4) + * arith.constant(tile_m, index=True) + * arith.constant(4, index=True) ) sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes_idx) sorted_rsrc = buffer_ops.create_buffer_resource( @@ -1867,8 +2751,14 @@ def moe_gemm2( arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_i32 ) - # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 - eid_nbytes_idx = size_expert_ids_in * fx.Index(4) + # expert ids: [sort_blocks] i32. + _c_sbm = arith.constant(_sort_block_m, index=True) + _c_tm = arith.constant(tile_m, index=True) + _c1 = arith.constant(1, index=True) + _sort_blocks_ub = _div_pow2( + size_expert_ids_in * _c_tm + _c_sbm - _c1, _sort_block_m + ) + eid_nbytes_idx = _sort_blocks_ub * arith.constant(4, index=True) eid_nbytes_i32 = arith.index_cast(T.i32, eid_nbytes_idx) expert_rsrc = buffer_ops.create_buffer_resource( arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 @@ -1879,33 +2769,100 @@ def moe_gemm2( else None ) - # ---- persist_m loop ---- - _PERSIST_M = persist_m - _c0_p = arith.index(0) - _c1_p = arith.index(1) - _c_pm = arith.index(_PERSIST_M) - _for_persist = scf.ForOp(_c0_p, _c_pm, _c1_p) + # ---- persist loop ---- + _c0_p = arith.constant(0, index=True) + _c1_p = arith.constant(1, index=True) + + if _persistent: + # Expert-phase scheduling: contiguous M-tile dispatch. 
+ # grid_y = cu_num, each CTA handles a contiguous chunk of M-tiles: + # [bx_persist * tiles_per_block, ..., (bx_persist+1) * tiles_per_block - 1] + # Adjacent blocks process adjacent M-tiles -> same expert -> B weight L2 reuse. + _c_cu = arith.constant(_cu_num, index=True) + _c_tm_p = arith.constant(tile_m, index=True) + _num_valid_idx = arith.index_cast(ir.IndexType.get(), num_valid_i32) + _total_m_tiles = (_num_valid_idx + _c_tm_p - _c1_p) / _c_tm_p + _tiles_per_block = (_total_m_tiles + _c_cu - _c1_p) / _c_cu + _i1 = ir.IntegerType.get_signless(1) + _init_active = arith.constant(1, type=_i1) + _for_persist = scf.ForOp(_c0_p, _tiles_per_block, _c1_p, [_init_active]) + else: + # Legacy mode: fixed persist_m consecutive tiles. + _c_pm = arith.constant(persist_m, index=True) + _init_prev_expert = arith.constant(0, type=T.i32) + _init_prev_b_base = arith.constant(0, index=True) + _for_persist = scf.ForOp( + _c0_p, + _c_pm, + _c1_p, + [_init_prev_expert, _init_prev_b_base], + ) + _for_ip = ir.InsertionPoint(_for_persist.body) _for_ip.__enter__() _mi_p = _for_persist.induction_variable - bx = bx_persist * _c_pm + _mi_p - bx_m = bx * fx.Index(tile_m) + + if _persistent: + _still_active = _for_persist.inner_iter_args[0] + bx = bx_persist * _tiles_per_block + _mi_p + else: + _prev_expert_i32 = _for_persist.inner_iter_args[0] + _prev_expert_b_base = _for_persist.inner_iter_args[1] + bx = bx_persist * arith.constant(persist_m, index=True) + _mi_p + + bx_m = bx * arith.constant(tile_m, index=True) # Early-exit guard: skip garbage expert blocks beyond `num_valid_ids`. 
bx_m_i32 = arith.index_cast(T.i32, bx_m) - blk_valid = arith.cmpi(arith.CmpIPredicate.ult, bx_m_i32, num_valid_i32) + blk_valid = arith.cmpi(CmpIPredicate.ult, bx_m_i32, num_valid_i32) + sort_blk = _div_pow2(bx_m, _sort_block_m) expert_i32 = buffer_ops.buffer_load( - expert_rsrc, bx, vec_width=1, dtype=T.i32 + expert_rsrc, sort_blk, vec_width=1, dtype=T.i32 ) expert_idx = arith.index_cast(T.index, expert_i32) exp_valid = arith.cmpi( - arith.CmpIPredicate.ult, expert_i32, fx.Int32(experts) + CmpIPredicate.ult, expert_i32, arith.constant(experts, type=T.i32) + ) + + if _persistent: + # Absolute B-base: no cross-iteration state needed. + _expert_b_base = expert_idx * arith.constant( + _expert_b_stride, index=True + ) + else: + # Legacy incremental B-base: delta = (cur - prev) * stride + _delta_expert = arith.subi(expert_i32, _prev_expert_i32) + _delta_expert_idx = arith.index_cast(ir.IndexType.get(), _delta_expert) + _delta_b = _delta_expert_idx * arith.constant( + _expert_b_stride, index=True + ) + _expert_b_base = _prev_expert_b_base + _delta_b + + # Early-exit: if the first row of this tile is a sentinel (all-padding tile), + # skip the entire GEMM. + _first_tok = buffer_ops.buffer_load( + sorted_rsrc, bx_m, vec_width=1, dtype=T.i32 ) + _first_tid = arith.andi(_first_tok, arith.constant(0xFFFFFF, type=T.i32)) + _tokens_i32_guard = arith.index_cast(T.i32, tokens_in) + tile_has_tokens = arith.cmpi( + CmpIPredicate.ult, _first_tid, _tokens_i32_guard + ) + + # For tile_m < 32 (pack_M < _scale_pack_m): shift a_scale i32 so the + # correct bytes land at the op_sel positions we use. + if pack_M < _scale_pack_m: + _m_off = _mod_pow2(_div_pow2(bx_m, 16), _scale_pack_m) + _m_scale_shift_i32 = arith.index_cast( + T.i32, _m_off * arith.constant(8, index=True) + ) + else: + _m_scale_shift_i32 = None def _moe_gemm2_then_body(): # Expert id for this M tile. 
- n_idx = fx.Index(model_dim) + n_idx = arith.constant(model_dim, index=True) expert_off_idx = expert_idx * n_idx # index # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- @@ -1930,25 +2887,24 @@ def _moe_gemm2_then_body(): ) num_x_loads = bytes_per_thread_x // x_load_bytes chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) - vec4_i32 = T.vec(4, T.i32) - vec2_i32 = T.vec(2, T.i32) - vec1_i32 = T.vec(1, T.i32) + vec4_i32 = T.vec(4, i32) - c_k_div4 = ( - (k_in // c_a_pack) * fx.Index(int(a_elem_bytes)) - ) // arith.index(4) - c_k_div4_i32 = arith.index_cast(T.i32, c_k_div4) + c_k_div4 = _div_pow2( + _div_pow2(k_in, int(a_elem_vec_pack)) + * arith.constant(int(a_elem_bytes), index=True), + 4, + ) tile_k_dwords = (int(tile_k) * int(a_elem_bytes)) // ( 4 * int(a_elem_vec_pack) ) layout_x_tile_div4 = fx.make_layout( (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) ) - c_chunk_i32 = fx.Index(chunk_i32) + c_chunk_i32 = arith.constant(chunk_i32, index=True) tx_i32_base = tx * c_chunk_i32 - topk_i32 = fx.Int32(topk) - mask24 = fx.Int32(0xFFFFFF) + topk_i32 = arith.constant(topk) + mask24 = arith.constant(0xFFFFFF) # Sentinel clamp uses `tokens` as the upper bound: t_valid = (t < tokens). tokens_i32 = arith.index_cast(T.i32, tokens_in) @@ -1962,6 +2918,8 @@ def x_tile_chunk_coord_i32(i: int): chunk_i32=chunk_i32, ) + vec1_i32 = T.vec(1, i32) + vec2_i32 = T.vec(2, i32) x_load_vec_elems = ( x_load_bytes if a_elem_bytes == 1 else x_load_bytes // a_elem_bytes ) @@ -2009,26 +2967,25 @@ def load_x(idx_i32): fused_i = buffer_ops.buffer_load( sorted_rsrc, sorted_row_i, vec_width=1, dtype=T.i32 ) - t_i32 = fused_i & mask24 - s_i32 = fused_i >> fx.Int32(24) - # Keep `blk_valid` only; remove per-row token validity checks. 
- - t_valid = arith.cmpi(arith.CmpIPredicate.ult, t_i32, tokens_i32) - s_valid = arith.cmpi(arith.CmpIPredicate.ult, s_i32, topk_i32) - ts_valid = t_valid & s_valid - t_safe = arith.select(ts_valid, t_i32, fx.Int32(0)) - s_safe = arith.select(ts_valid, s_i32, fx.Int32(0)) + t_i32 = arith.andi(fused_i, mask24) + s_i32 = arith.shrui(fused_i, arith.constant(24)) + + t_valid = arith.cmpi(CmpIPredicate.ult, t_i32, tokens_i32) + s_valid = arith.cmpi(CmpIPredicate.ult, s_i32, topk_i32) + ts_valid = arith.andi(t_valid, s_valid) + t_safe = arith.select(ts_valid, t_i32, arith.constant(0)) + s_safe = arith.select(ts_valid, s_i32, arith.constant(0)) row_ts_i32 = t_safe * topk_i32 + s_safe row_ts_idx = arith.index_cast(T.index, row_ts_i32) - # Base row offset in dword units: row_ts_idx * (k_in/4) x_row_base_div4.append(row_ts_idx * c_k_div4) def load_x_tile(base_k): - base_k_div4 = ( - (base_k // c_a_pack) - * fx.Index(int(a_elem_bytes)) - ) // arith.index(4) + base_k_div4 = _div_pow2( + _div_pow2(base_k, int(a_elem_vec_pack)) + * arith.constant(int(a_elem_bytes), index=True), + 4, + ) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] @@ -2043,35 +3000,43 @@ def load_x_tile(base_k): return parts # tx -> wave/lane (GEMM-style decomposition). - wave_id, lane_id = split_row_major_2d(tx, fx.Index(64)) - lane_div_16, lane_mod_16 = split_row_major_2d( - lane_id, fx.Index(16) - ) + coord_wl = idx2crd(tx, layout_tx_wave_lane) + wave_id = layout_get(coord_wl, 0) + lane_id = layout_get(coord_wl, 1) + coord_l16 = idx2crd(lane_id, layout_lane16) + lane_div_16 = layout_get(coord_l16, 0) + lane_mod_16 = layout_get(coord_l16, 1) row_a_lds = lane_mod_16 - col_offset_base = lane_div_16 * fx.Index(16) + col_offset_base = lane_div_16 * arith.constant(16, index=True) # Dynamic N tiling within block. 
- by_n = by * fx.Index(tile_n) num_waves = 4 n_per_wave = tile_n // num_waves num_acc_n = n_per_wave // 16 - c_n_per_wave = fx.Index(n_per_wave) - wave_mod_4 = wave_id % fx.Index(4) + c_n_per_wave = arith.constant(n_per_wave, index=True) + wave_mod_4 = _mod_pow2(wave_id, 4) n_tile_base = wave_mod_4 * c_n_per_wave - # Precompute (n_blk, n_intra) for B, and col indices for output. - n_intra_list = [] - n_blk_list = [] + by_n = by * arith.constant(tile_n, index=True) + + if pack_N < _scale_pack_n: + _global_n_base = expert_off_idx + by_n + n_tile_base + _n_off = _mod_pow2(_div_pow2(_global_n_base, 16), _scale_pack_n) + _n_scale_shift_i32 = arith.index_cast( + T.i32, _n_off * arith.constant(8, index=True) + ) + else: + _n_scale_shift_i32 = None + n_intra_list = [None] * num_acc_n + n_blk_list = [None] * num_acc_n for i in range_constexpr(num_acc_n): offset = i * 16 - c_offset = fx.Index(offset) + c_offset = arith.constant(offset, index=True) global_n = by_n + n_tile_base + c_offset + lane_mod_16 - row_w = expert_off_idx + global_n - n_blk, n_intra = split_row_major_2d(row_w, fx.Index(16)) - n_blk_list.append(n_blk) - n_intra_list.append(n_intra) + n_blk_list[i] = _div_pow2(global_n, 16) + n_intra_list[i] = _mod_pow2(global_n, 16) m_repeat = tile_m // 16 k_unroll = tile_k_bytes // 128 # K64-byte micro-step (2x MFMA) @@ -2084,21 +3049,24 @@ def load_x_tile(base_k): # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- def load_b_packs_k64(base_k, ku: int, ni: int): """Load one K64-byte B micro-step: single 16B load, split into 2x i64.""" - c64 = fx.Index(64) base_k_bytes = base_k * arith.constant( int(b_elem_bytes), index=True ) - k0_base = base_k_bytes // c64 - k0 = k0_base + fx.Index(ku) + k0_base = _div_pow2(base_k_bytes, 64) + k0 = k0_base + arith.constant(ku, index=True) k1 = lane_div_16 - coord_pack = ( - n_blk_list[ni], - k0, - k1, - n_intra_list[ni], - fx.Index(0), + # Incremental B addressing: _expert_b_base carries the + # expert's preshuffle offset 
(updated via delta each + # persist_m iteration); local n_blk/n_intra contribute + # the per-lane within-tile offset. All strides are + # compile-time constants -> shift/mul, no Barrett. + idx_pack = ( + _expert_b_base + + n_blk_list[ni] * arith.constant(_b_stride_n0, index=True) + + k0 * arith.constant(_b_stride_k0, index=True) + + k1 * arith.constant(_b_stride_klane, index=True) + + n_intra_list[ni] * arith.constant(_b_stride_nlane, index=True) ) - idx_pack = crd2idx(coord_pack, layout_b) vec_elems = kpack_bytes // int(b_elem_bytes) b16 = _buffer_load_vec( @@ -2106,7 +3074,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): vector, w_rsrc, idx_pack, - elem_type=_w_elem_type(is_f4_b=is_f4_b, is_f16_b=is_f16_b), + elem_type=_w_elem_type(), vec_elems=vec_elems, elem_bytes=b_elem_bytes, offset_in_bytes=(b_elem_bytes == 1), @@ -2132,6 +3100,35 @@ def load_b_tile(base_k): b_tile.append((packs0, packs1)) return b_tile + _b_split_enabled = k_unroll >= 2 + _b_split_ku = k_unroll // 2 if _b_split_enabled else k_unroll + + def load_b_tile_lo(base_k): + """Load first half of B tile (ku < _b_split_ku).""" + b_tile = [] + for ku in range_constexpr(_b_split_ku): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + b0, b1 = load_b_packs_k64(base_k, ku, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + + def load_b_tile_hi(base_k): + """Load second half of B tile (ku >= _b_split_ku).""" + b_tile = [] + for ku in range_constexpr(_b_split_ku, k_unroll): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + b0, b1 = load_b_packs_k64(base_k, ku, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + def load_scale(arg_scale, rsrc, scale_info, ku, mni): k_lane = lane_div_16 n_lane = lane_mod_16 @@ -2145,7 +3142,16 @@ def load_scale(arg_scale, rsrc, scale_info, ku, mni): s = buffer_ops.buffer_load(rsrc, idx_pack, vec_width=1, dtype=T.i32) return 
vector.from_elements(T.vec(1, T.i32), [s]) - def load_b_scale_tile(base_k): + def _apply_k_shift(scale_vec, k_shift_bits): + if k_shift_bits > 0: + val = vector.extract( + scale_vec, static_position=[0], dynamic_position=[] + ) + val = arith.shrui(val, arith.constant(k_shift_bits, type=T.i32)) + return vector.from_elements(T.vec(1, T.i32), [val]) + return scale_vec + + def load_b_scale_tile(base_k, k_shift_bits=0): b_scale_tile = [] for ku in range_constexpr(k_unroll_packed): for ni in range_constexpr(num_acc_n_packed): @@ -2155,12 +3161,19 @@ def load_b_scale_tile(base_k): layout_b_scale, ku + base_k, ni - + (expert_off_idx + by_n + n_tile_base) // pack_N // 16, + + _div_pow2( + _div_pow2( + expert_off_idx + by_n + n_tile_base, + _scale_pack_n, + ), + 16, + ), ) + scale = _apply_k_shift(scale, k_shift_bits) b_scale_tile.append(scale) return b_scale_tile - def load_a_scale_tile(base_k): + def load_a_scale_tile(base_k, k_shift_bits=0): a_scale_tile = [] for ku in range_constexpr(k_unroll_packed): for mi in range_constexpr(m_repeat_packed): @@ -2169,13 +3182,17 @@ def load_a_scale_tile(base_k): sx_rsrc, layout_a_scale, ku + base_k, - mi + bx_m // pack_M // 16, + mi + _div_pow2(_div_pow2(bx_m, _scale_pack_m), 16), ) + scale = _apply_k_shift(scale, k_shift_bits) a_scale_tile.append(scale) return a_scale_tile - def prefetch_ab_scale_tile(base_k): - return [load_a_scale_tile(base_k), load_b_scale_tile(base_k)] + def prefetch_ab_scale_tile(base_k, k_shift_bits=0): + return [ + load_a_scale_tile(base_k, k_shift_bits), + load_b_scale_tile(base_k, k_shift_bits), + ] vec8_x = T.vec(vec8_elems, x_elem) vec4_x_lds = T.vec(vec4_elems, x_elem) @@ -2240,14 +3257,11 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 - else (col_base_swz_bytes // arith.index(2)) - ) - idx_a16 = lds_row_major_idx( - curr_row_a_lds, - col_base_swz, - fx.Index(lds_stride), - lds_base, + else (col_base_swz_bytes / arith.index(2)) ) + # 
Pass as list so layout_utils.crd2idx uses static arith path + idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) + idx_a16 = idx_a16 + lds_base loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) a0 = vector.extract( @@ -2267,9 +3281,17 @@ def compute_tile( *, prefetch_epilogue: bool = False, a0_prefetch=None, + a1_prefetch=None, + b_hi_loader=None, ): + if b_hi_loader is not None: + b_tile_full = [None] * k_unroll + for i in range_constexpr(_b_split_ku): + b_tile_full[i] = b_tile_in[i] + else: + b_tile_full = b_tile_in acc_list = list(acc_in) - mfma_res_ty = vec4_f32 + mfma_res_ty = vec4_i32 if is_int8 else vec4_f32 epilogue_pf = None bias = None @@ -2281,7 +3303,7 @@ def compute_tile( bias_offset = expert_off_idx + global_n bias.append( buffer_ops.buffer_load( - bias_rsrc, bias_offset, vec_width=1, dtype=T.f32 + bias_rsrc, bias_offset, vec_width=1, dtype=f32 ) ) tw_pf = None @@ -2289,10 +3311,10 @@ def compute_tile( tw_pf = [] lane_div_16_mul4_pf = lane_div_16 * arith.index(4) ii_idx_list_pf = [ - fx.Index(ii) for ii in range(4) + arith.constant(ii, index=True) for ii in range(4) ] for mi in range_constexpr(m_repeat): - mi_base_pf = fx.Index(mi * 16) + mi_base_pf = arith.constant(mi * 16, index=True) for ii in range_constexpr(4): row_off_pf = ( lane_div_16_mul4_pf + ii_idx_list_pf[ii] @@ -2304,7 +3326,7 @@ def compute_tile( sorted_w_rsrc, sorted_row_pf, vec_width=1, - dtype=T.f32, + dtype=f32, ) ) epilogue_pf = (None, tw_pf, bias) @@ -2317,13 +3339,34 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) return vector.bitcast(vec8_i32, v4) - # fp4 path - for ku128 in range_constexpr(k_unroll_packed): + # fp4 path -- single k_idx loop [0, k_unroll). + # b_hi load is issued at the very start so all k_unroll + # MFMAs can overlap the VMEM latency. 
+ _pack_K_shift = (pack_K - 1).bit_length() + _pack_K_mask = pack_K - 1 + + if b_hi_loader is not None: + _b_hi = b_hi_loader() + for _bhi_i in range_constexpr(len(_b_hi)): + b_tile_full[_b_split_ku + _bhi_i] = _b_hi[_bhi_i] + + for k_idx in range_constexpr(k_unroll): + ku128 = k_idx >> _pack_K_shift + ikxdl = k_idx & _pack_K_mask + + b_packs0, b_packs1 = b_tile_full[k_idx] + + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack + for mi in range_constexpr(m_repeat_packed): a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] a_scale_val = vector.extract( a_scale_i32, static_position=[0], dynamic_position=[] ) + if _m_scale_shift_i32 is not None: + a_scale_val = arith.shrui( + a_scale_val, _m_scale_shift_i32 + ) for ni in range_constexpr(num_acc_n_packed): b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] b_scale_val = vector.extract( @@ -2331,76 +3374,77 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): static_position=[0], dynamic_position=[], ) - for ikxdl in range_constexpr(pack_K): - k_idx = ku128 * pack_K + ikxdl - - b_packs0, b_packs1 = b_tile_in[k_idx] - - col_base = ( - col_offset_base - + (k_idx * 128) // a_elem_vec_pack + if _n_scale_shift_i32 is not None: + b_scale_val = arith.shrui( + b_scale_val, _n_scale_shift_i32 ) - for imxdl in range_constexpr(pack_M): - col_base0 = col_base - mi_idx = mi * pack_M + imxdl - mi_val = fx.Index(mi_idx * 16) - curr_row_a_lds = row_a_lds + mi_val - - if ( - (a0_prefetch is not None) - and (k_idx == 0) - and (mi_idx == 0) - ): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64( - curr_row_a_lds, col_base0, lds_base - ) + for imxdl in range_constexpr(pack_M): + col_base0 = col_base + mi_idx = mi * pack_M + imxdl + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if ( + (a0_prefetch is not None) + and (k_idx == 0) + and (mi_idx == 0) + ): + a0, a1 = a0_prefetch + elif ( + (a1_prefetch is not None) + and (k_idx == 1) + and (mi_idx == 0) + ): + a0, a1 = a1_prefetch + 
else: + a0, a1 = lds_load_packs_k64( + curr_row_a_lds, col_base0, lds_base + ) - if is_f8_a: - col_base1 = col_base + 64 - a2, a3 = lds_load_packs_k64( - curr_row_a_lds, col_base1, lds_base - ) - a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - else: - a128 = pack_i64x4_to_i32x8( - a0, a1, c0_i64, c0_i64 - ) + if is_f8_a: + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64( + curr_row_a_lds, col_base1, lds_base + ) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + else: + a128 = pack_i64x4_to_i32x8( + a0, a1, c0_i64, c0_i64 + ) - for inxdl in range_constexpr(pack_N): - ni_idx = ni * pack_N + inxdl + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl - b0 = b_packs0[ni_idx] - b1 = b_packs1[ni_idx] - b128 = pack_i64x4_to_i32x8( - b0, b1, c0_i64, c0_i64 - ) + b0 = b_packs0[ni_idx] + b1 = b_packs1[ni_idx] + b128 = pack_i64x4_to_i32x8( + b0, b1, c0_i64, c0_i64 + ) - acc_idx = mi_idx * num_acc_n + ni_idx - rocdl.sched_barrier(0) - acc_list[acc_idx] = ( - rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - b128, - acc_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - b_scale_val, - ], - ) + acc_idx = mi_idx * num_acc_n + ni_idx + rocdl.sched_barrier(0) + acc_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + acc_list[acc_idx], + cbsz, + blgp, + ikxdl * _scale_pack_m + imxdl, + a_scale_val, + ikxdl * _scale_pack_n + inxdl, + b_scale_val, + ], ) + ) return acc_list, epilogue_pf # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- - lds_tile_elems = fx.Index(tile_m * lds_stride) + lds_tile_elems = arith.constant(tile_m * lds_stride, index=True) lds_base_cur = arith.index(0) lds_base_nxt = lds_tile_elems @@ -2451,15 +3495,18 @@ def hot_loop_scheduler(): rocdl.sched_barrier(0) - # Prologue. 
- k0 = arith.index(0) - x_regs0 = load_x_tile(k0) - b_cur = load_b_tile(k0) - a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(k0 // pack_K // 128) - store_x_tile_to_lds(x_regs0, lds_base_cur) + def _k_shift_bits(k_py): + if pack_K >= _scale_pack_k: + return 0 + return ((k_py // 128) % _scale_pack_k) * _scale_pack_m * 8 + + def _k_base(k_py): + return k_py // _scale_pack_k // 128 + # Preload sorted_idx into lds_tid for epilogue precompute_row - _c_tile_m_idx = fx.Index(tile_m) - _tid_in_range = arith.cmpi(arith.CmpIPredicate.ult, tx, _c_tile_m_idx) + # (N-independent; placed before N-tile loop so it's done once per M-tile.) + _c_tile_m_idx = arith.constant(tile_m, index=True) + _tid_in_range = arith.cmpi(CmpIPredicate.ult, tx, _c_tile_m_idx) _if_tid = scf.IfOp(_tid_in_range) with ir.InsertionPoint(_if_tid.then_block): _tid_row = bx_m + tx @@ -2469,17 +3516,42 @@ def hot_loop_scheduler(): _tid_vec1 = vector.from_elements(T.vec(1, T.i32), [_tid_val]) vector.store(_tid_vec1, lds_tid, [tx]) scf.YieldOp([]) + + gpu.barrier() + + # Prologue -- B-first. + k0 = arith.index(0) + if _b_split_enabled: + b_cur = load_b_tile_lo(k0) + else: + b_cur = load_b_tile(k0) + a_scale_pong, b_scale_pong = prefetch_ab_scale_tile( + _k_base(0), _k_shift_bits(0) + ) + # scheduling fence to prevent LLVM from deferring + # the scale buffer_loads past the upcoming barrier. + rocdl.sched_barrier(0) + x_regs0 = load_x_tile(k0) + store_x_tile_to_lds(x_regs0, lds_base_cur) gpu.barrier() acc = [acc_init] * num_acc_n * m_repeat lds_base_pong = lds_base_cur lds_base_ping = lds_base_nxt - # Cross-tile A0 LDS prefetch (default-on): prefetch the first A-pack (K64) for the - # tile we are about to compute from LDS, to overlap with upcoming VMEM. + # Cross-tile A0+A1 LDS prefetch: issue both ds_reads back-to-back + # (×2) so LDS bandwidth is fully utilized and the second read + # completes during MFMA #1/#2 execution.
a0_prefetch_pong = lds_load_packs_k64( row_a_lds, col_offset_base, lds_base_pong ) + _a1_col_base = col_offset_base + 128 // a_elem_vec_pack + a1_prefetch_pong = ( + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_base_pong) + if pack_K >= 2 + else None + ) + # Main loop: process K tiles in 2-tile ping-pong steps. # # IMPORTANT: for odd number of K tiles, leave **1** tail tile; for even, leave **2**. @@ -2492,20 +3564,31 @@ def hot_loop_scheduler(): if k_main2_py < 0: k_main2_py = 0 - c2_tile_k = fx.Index(tile_k * 2) + c2_tile_k = arith.constant(tile_k * 2, index=True) b_pong = b_cur + k0_pong_bk = k0 + # Only emit the scf.for when there are actually iterations to run. # When k_main2_py == 0 the loop body is empty; emitting an scf.for # would create a region whose internal SSA values cannot be used # by the post-loop tail code. + def _make_b_hi_loader(base_k): + """Create a b_hi_loader callable for a given base_k.""" + return lambda _bk=base_k: load_b_tile_hi(_bk) + if k_main2_py > 0: for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): - k_iv = k_iv_py + k_iv = arith.index(k_iv_py) next_k1 = k_iv + tile_k + next_k1_bk = next_k1 // 2 x_regs_ping = load_x_tile(next_k1) - b_ping = load_b_tile(next_k1 // 2) + b_ping_lo = ( + load_b_tile_lo(next_k1_bk) + if _b_split_enabled + else load_b_tile(next_k1_bk) + ) a_scale_ping, b_scale_ping = prefetch_ab_scale_tile( - next_k1 // pack_K // 128 + _k_base(next_k1), _k_shift_bits(next_k1) ) acc, _ = compute_tile( @@ -2515,6 +3598,12 @@ def hot_loop_scheduler(): a_scale_pong, b_scale_pong, a0_prefetch=a0_prefetch_pong, + a1_prefetch=a1_prefetch_pong, + b_hi_loader=( + _make_b_hi_loader(k0_pong_bk) + if _b_split_enabled + else None + ), ) store_x_tile_to_lds(x_regs_ping, lds_base_ping) # hot_loop_scheduler() @@ -2524,22 +3613,40 @@ def hot_loop_scheduler(): a0_prefetch_ping = lds_load_packs_k64( row_a_lds, col_offset_base, lds_base_ping ) + a1_prefetch_ping = ( + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_base_ping) + 
if pack_K >= 2 + else None + ) next_k2 = k_iv + c2_tile_k + next_k2_py = k_iv_py + tile_k * 2 + next_k2_bk = next_k2 // 2 x_regs_pong = load_x_tile(next_k2) - b_pong = load_b_tile(next_k2 // 2) + b_pong = ( + load_b_tile_lo(next_k2_bk) + if _b_split_enabled + else load_b_tile(next_k2_bk) + ) a_scale_pong, b_scale_pong = prefetch_ab_scale_tile( - next_k2 // pack_K // 128 + _k_base(next_k2_py), _k_shift_bits(next_k2_py) ) acc, _ = compute_tile( acc, - b_ping, + b_ping_lo, lds_base_ping, a_scale_ping, b_scale_ping, a0_prefetch=a0_prefetch_ping, + a1_prefetch=a1_prefetch_ping, + b_hi_loader=( + _make_b_hi_loader(next_k1_bk) + if _b_split_enabled + else None + ), ) + k0_pong_bk = next_k2_bk store_x_tile_to_lds(x_regs_pong, lds_base_pong) # hot_loop_scheduler() gpu.barrier() @@ -2548,6 +3655,11 @@ def hot_loop_scheduler(): a0_prefetch_pong = lds_load_packs_k64( row_a_lds, col_offset_base, lds_base_pong ) + a1_prefetch_pong = ( + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_base_pong) + if pack_K >= 2 + else None + ) if odd_k_tiles: # Tail: single remaining tile (already in `b_cur` / `lds_base_pong`). @@ -2558,16 +3670,28 @@ def hot_loop_scheduler(): a_scale_pong, b_scale_pong, a0_prefetch=a0_prefetch_pong, + a1_prefetch=a1_prefetch_pong, prefetch_epilogue=True, + b_hi_loader=( + _make_b_hi_loader(k0_pong_bk) if _b_split_enabled else None + ), ) else: # Tail: 2 remaining tiles. 
k_tail1 = (k_in + tile_k - 1) // tile_k * tile_k - tile_k + k_tail1_py = ( + int(inter_dim) + tile_k - 1 + ) // tile_k * tile_k - tile_k + k_tail1_bk = k_tail1 // 2 x_regs_ping = load_x_tile(k_tail1) - b_ping = load_b_tile(k_tail1 // 2) + b_ping_lo = ( + load_b_tile_lo(k_tail1_bk) + if _b_split_enabled + else load_b_tile(k_tail1_bk) + ) a_scale_ping, b_scale_ping = prefetch_ab_scale_tile( - k_tail1 // pack_K // 128 + _k_base(k_tail1_py), _k_shift_bits(k_tail1_py) ) acc, _ = compute_tile( @@ -2577,6 +3701,10 @@ def hot_loop_scheduler(): a_scale_pong, b_scale_pong, a0_prefetch=a0_prefetch_pong, + a1_prefetch=a1_prefetch_pong, + b_hi_loader=( + _make_b_hi_loader(k0_pong_bk) if _b_split_enabled else None + ), ) store_x_tile_to_lds(x_regs_ping, lds_base_ping) @@ -2587,23 +3715,37 @@ def hot_loop_scheduler(): a0_prefetch_ping = lds_load_packs_k64( row_a_lds, col_offset_base, lds_base_ping ) + a1_prefetch_ping = ( + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_base_ping) + if pack_K >= 2 + else None + ) acc, epilogue_pf = compute_tile( acc, - b_ping, + b_ping_lo, lds_base_ping, a_scale_ping, b_scale_ping, a0_prefetch=a0_prefetch_ping, + a1_prefetch=a1_prefetch_ping, prefetch_epilogue=True, + b_hi_loader=( + _make_b_hi_loader(k_tail1_bk) if _b_split_enabled else None + ), ) # ---------------- Epilogue: LDS CShuffle + atomic half2 (x2) ---------------- # Reuse the shared helper so GEMM / MoE kernels share the exact same CShuffle skeleton. 
- model_i32 = fx.Int32(model_dim) - zero_i32 = fx.Int32(0) - c2_i32 = fx.Int32(2) - mask_even_i32 = fx.Int32(0xFFFFFFFE) + tw_pf = None + bias_pf = None + if epilogue_pf is not None: + _, tw_pf, bias_pf = epilogue_pf + + mask24_i32 = arith.constant(0xFFFFFF) + topk_i32_v = topk_i32 + + zero_i32 = arith.constant(0) def atomic_add_f16x2(val_f16x2, byte_off_i32): rocdl.raw_ptr_buffer_atomic_fadd( @@ -2614,24 +3756,24 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): zero_i32, ) - sw_pf = None - tw_pf = None - bias_pf = None - if epilogue_pf is not None: - sw_pf, tw_pf, bias_pf = epilogue_pf - - mask24_i32 = fx.Int32(0xFFFFFF) - topk_i32_v = topk_i32 - # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). if lds_out is None: raise RuntimeError( - "FLYDSL_MOE_STAGE2_CSHUFFLE=1 but lds_out is not allocated/aliased." + "FLIR_MOE_STAGE2_CSHUFFLE=1 but lds_out is not allocated/aliased." ) - out_base_idx = None - if _needs_global_atomic_bf16: - out_base_idx = buffer_ops.extract_base_index(arg_out) + # Precompute the output base address (i64 index) for ALL paths. + # Both accumulate=True (global atomic) and accumulate=False (global store) + # need 64-bit addressing to avoid i32 offset overflow when + # tokens * model_dim * elem_bytes > INT32_MAX (~150K tokens for model_dim=7168). 
+ from flydsl._mlir.dialects import fly as _fly + + _llvm_ptr_ty = ir.Type.parse("!llvm.ptr") + out_base_ptr = _fly.extract_aligned_pointer_as_index( + _llvm_ptr_ty, arg_out + ) + out_base_i64 = llvm.ptrtoint(T.i64, out_base_ptr) + out_base_idx = arith.index_cast(T.index, out_base_i64) def write_row_to_lds( *, @@ -2650,7 +3792,7 @@ def write_row_to_lds( tw = tw_pf[tw_idx] else: tw = buffer_ops.buffer_load( - sorted_w_rsrc, row, vec_width=1, dtype=T.f32 + sorted_w_rsrc, row, vec_width=1, dtype=f32 ) for ni in range_constexpr(num_acc_n): @@ -2659,6 +3801,8 @@ def write_row_to_lds( v = vector.extract( acc[acc_idx], static_position=[ii], dynamic_position=[] ) + if is_int8: + v = arith.sitofp(f32, v) if enable_bias: v = v + bias_pf[ni] @@ -2677,52 +3821,69 @@ def precompute_row(*, row_local, row): # to avoid extra VMEM round-trips in the epilogue. fused2 = memref.load(lds_tid, [row_local]) row_i32 = arith.index_cast(T.i32, row) - row_valid0 = arith.cmpi(arith.CmpIPredicate.ult, row_i32, num_valid_i32) + row_valid0 = arith.cmpi(CmpIPredicate.ult, row_i32, num_valid_i32) t = fused2 & mask24_i32 s = fused2 >> 24 - t_ok = arith.cmpi(arith.CmpIPredicate.ult, t, tokens_i32) - s_ok = arith.cmpi(arith.CmpIPredicate.ult, s, topk_i32_v) - row_valid = row_valid0 & (t_ok & s_ok) + t_ok = arith.cmpi(CmpIPredicate.ult, t, tokens_i32) + s_ok = arith.cmpi(CmpIPredicate.ult, s, topk_i32_v) + row_valid = arith.andi(row_valid0, arith.andi(t_ok, s_ok)) + t_idx = arith.index_cast(ir.IndexType.get(), t) + s_idx = arith.index_cast(ir.IndexType.get(), s) + ts_idx = t_idx * arith.constant(topk, index=True) + s_idx + if accumulate: + row_byte_base = out_base_idx + t_idx * arith.constant( + model_dim * out_elem_bytes, index=True + ) + else: + row_byte_base = out_base_idx + ts_idx * arith.constant( + model_dim * out_elem_bytes, index=True + ) + return ((fused2, row_byte_base), row_valid) - return (fused2, row_valid) + def _idx_to_llvm_ptr(idx_val, addr_space=1): + """Convert an index-typed byte 
address to !llvm.ptr.""" + idx_v = idx_val._value if hasattr(idx_val, "_value") else idx_val + i64_v = arith.index_cast(T.i64, idx_v) + i64_raw = i64_v._value if hasattr(i64_v, "_value") else i64_v + ptr_ty = ir.Type.parse(f"!llvm.ptr<{addr_space}>") + return llvm.inttoptr(ptr_ty, i64_raw) def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): - fused = row_ctx - t = fused & mask24_i32 - s = fused >> 24 - idx0 = t * model_i32 + fused, row_byte_base = row_ctx if not bool(accumulate): - ts = t * topk_i32_v + s - idx0 = ts * model_i32 - col_i32 = arith.index_cast(T.i32, col_g0) - idx_elem = idx0 + col_i32 - idx_elem_even = idx_elem & mask_even_i32 - if _needs_global_atomic_bf16: - if bool(accumulate): - byte_off = idx_elem_even * c2_i32 - byte_off_idx = arith.index_cast(T.index, byte_off) - ptr_addr_idx = out_base_idx + byte_off_idx - out_ptr = buffer_ops.create_llvm_ptr(ptr_addr_idx, address_space=1) - out_ptr_v = out_ptr._value if hasattr(out_ptr, "_value") else out_ptr - frag_v = frag._value if hasattr(frag, "_value") else frag - - llvm.AtomicRMWOp( - llvm.AtomicBinOp.fadd, - out_ptr_v, - frag_v, - llvm.AtomicOrdering.monotonic, - syncscope="agent", - alignment=4, - ) - else: - buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) + # ---- 64-bit global store path (avoids i32 offset overflow) ---- + col_idx = col_g0 + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.StoreOp( + frag_v, + out_ptr_v, + alignment=_e_vec * out_elem_bytes, + nontemporal=True, + ) else: - byte_off = idx_elem_even * c2_i32 - if bool(accumulate): - atomic_add_f16x2(frag, byte_off) - else: - buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) + # ---- accumulate=True: 64-bit global atomic path ---- + col_idx = col_g0 + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + 
ptr_addr_idx = row_byte_base + byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + out_ptr_v, + frag_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=_e_vec * out_elem_bytes, + ) + _e_vec = 2 if accumulate else min(tile_n // 32, 8) c_shuffle_epilog( arith=arith, vector=vector, @@ -2731,7 +3892,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): range_constexpr=range_constexpr, tile_m=tile_m, tile_n=tile_n, - e_vec=2, + e_vec=_e_vec, m_repeat=m_repeat, num_acc_n=num_acc_n, tx=tx, @@ -2741,22 +3902,38 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): by_n=by_n, n_tile_base=n_tile_base, lds_out=lds_out, - frag_elem_type=(T.bf16 if out_is_bf16 else T.f16), + frag_elem_type=( + ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get() + ), write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, ) - _if_blk = scf.IfOp(blk_valid) - with ir.InsertionPoint(_if_blk.then_block): - _ifexpert_of = scf.IfOp(exp_valid) - with ir.InsertionPoint(_ifexpert_of.then_block): + _all_valid = arith.andi(blk_valid, arith.andi(exp_valid, tile_has_tokens)) + + if _persistent: + # Short-circuit: contiguous tiles are monotonically increasing, + # so once bx_m >= num_valid_ids all remaining tiles are invalid. 
+ _cur_active = arith.andi(_still_active, blk_valid) + _do_gemm = arith.andi( + _cur_active, arith.andi(exp_valid, tile_has_tokens) + ) + _if_valid = scf.IfOp(_do_gemm) + with ir.InsertionPoint(_if_valid.then_block): _moe_gemm2_then_body() scf.YieldOp([]) - scf.YieldOp([]) - gpu.barrier() - scf.YieldOp([]) + gpu.barrier() + scf.YieldOp([_cur_active]) + else: + _if_valid = scf.IfOp(_all_valid) + with ir.InsertionPoint(_if_valid.then_block): + _moe_gemm2_then_body() + scf.YieldOp([]) + + gpu.barrier() + scf.YieldOp([expert_i32, _expert_b_base]) _for_ip.__exit__(None, None, None) # -- Host launcher (flyc.jit + .launch) -------------------------------- @@ -2775,6 +3952,8 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): inter_dim_pad, use_cshuffle_epilog, persist_m, + _sort_block_m, + _cu_num if _persistent else 0, ) @flyc.jit @@ -2795,19 +3974,23 @@ def launch_mixed_moe_gemm2( i32_size_expert_ids_in: fx.Int32, stream: fx.Stream, ): + _ = _cache_tag allocator.finalized = False ctx = CompilationContext.get_current() with ir.InsertionPoint(ctx.gpu_module_body): allocator.finalize() n_in = arith.index_cast(T.index, i32_n_in) - gx = n_in // fx.Index(tile_n) - _c_pm_l = fx.Index(persist_m) - gy = ( - arith.index_cast(T.index, i32_size_expert_ids_in) - + _c_pm_l - - fx.Index(1) - ) / _c_pm_l + gx = n_in / arith.constant(tile_n, index=True) + if _persistent: + gy = arith.constant(_cu_num, index=True) + else: + _c_pm_l = arith.constant(persist_m, index=True) + gy = ( + arith.index_cast(ir.IndexType.get(), i32_size_expert_ids_in.ir_value()) + + _c_pm_l + - arith.constant(1, index=True) + ) / _c_pm_l moe_gemm2( arg_out, diff --git a/kernels/silu_and_mul_fq.py b/kernels/silu_and_mul_fq.py new file mode 100644 index 00000000..ba6ff390 --- /dev/null +++ b/kernels/silu_and_mul_fq.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Fused silu_and_mul + MXFP4 quantization + sorted-scale write kernel 
(FlyDSL). + +Designed for split-K MOE stage1 post-processing: + + input : tmp_out (token_num * topk, inter_dim * 2) bf16 + sorted : sorted_token_ids (sorted_len,) i32 -- packed (token<<0 | slot<<24) + num_valid_ids (1,) i32 + output : out_fp4 raw byte buffer -- FP4x2 packed, row stride = inter_dim//2 + out_scale_sorted raw byte buffer -- tiled E8M0 scale (same layout as moe_mxfp4_sort) + +Grid: (num_sorted_rows, 1, 1) -- one workgroup per sorted row (including blockM padding). +Block: (BLOCK_THREADS, 1, 1) + +Each workgroup: + 1. Loads sorted_token_ids[bid] -> (token_id, slot_id) -> row = token_id * topk + slot_id + 2. If bid < num_valid_ids and token_id < token_num and slot_id < topk (valid row): + a. Reads gate = tmp_out[row, 0:inter_dim], up = tmp_out[row, inter_dim:2*inter_dim] + b. Computes silu(gate) * up in f32 + c. Per-1x32 MXFP4 quant -> writes packed FP4 + E8M0 scale in tiled layout + 3. Otherwise (blockM padding row or out-of-range id): + a. Skips the FP4 write (stage2 gather-loads by token_id, so padding rows are never read) + b. Writes zero E8M0 scale to out_scale_sorted (keeps tiled layout consistent) +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.expr import arith, vector, range_constexpr +from flydsl.expr.typing import T, Int32 +from flydsl.expr.arith import ArithValue, CmpIPredicate +from flydsl.compiler.kernel_function import CompilationContext + +from flydsl._mlir import ir +from flydsl._mlir.dialects import llvm, scf +from flydsl.expr import buffer_ops + +BLOCK_THREADS = 256 +WARP_SIZE = 64 + + +def build_silu_and_mul_fq_module(inter_dim: int, topk: int): + """Return a JIT launcher for fused silu_and_mul + mxfp4 quant + scale sort. + + Parameters + ---------- + inter_dim : int + Output columns of stage1 (after activation). Input has inter_dim*2 cols. + Must be divisible by 32 (MXFP4 block size). + topk : int + Number of expert slots per token.
+ """ + assert inter_dim % 32 == 0, f"inter_dim={inter_dim} must be divisible by 32" + + scale_cols = inter_dim // 32 + ELEMS_PER_THREAD = (inter_dim + BLOCK_THREADS - 1) // BLOCK_THREADS + # VEC: number of f32 elements each thread handles; must be even for FP4 packing + VEC = max(ELEMS_PER_THREAD, 2) + if VEC % 2 != 0: + VEC += 1 + assert 32 % VEC == 0, f"VEC={VEC} must divide 32 evenly" + # threads that actually participate in a 32-element quant group + THREADS_PER_QUANT_BLK = 32 // VEC + # shuffle distances for intra-group reduction + SHUFFLE_DISTS = [] + d = 1 + while d < THREADS_PER_QUANT_BLK: + SHUFFLE_DISTS.append(d) + d *= 2 + + elem_bytes_bf16 = 2 + + @flyc.kernel + def silu_and_mul_fq_kernel( + x: fx.Tensor, # (token_num*topk, inter_dim*2) bf16 + out_fp4: fx.Tensor, # raw byte buffer for packed FP4 output + out_scale_sorted: fx.Tensor, # raw byte buffer for sorted E8M0 scales + sorted_ids: fx.Tensor, # (sorted_len,) i32 + num_valid_ids: fx.Tensor, # (1,) i32 + token_num: Int32, # host scalar + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + f32 = T.f32 + i32 = T.i32 + + c0_i32 = arith.constant(0, type=i32) + c1_i32 = arith.constant(1, type=i32) + c2_i32 = arith.constant(2, type=i32) + c3_i32 = arith.constant(3, type=i32) + c4_i32 = arith.constant(4, type=i32) + c5_i32 = arith.constant(5, type=i32) + c7_i32 = arith.constant(7, type=i32) + c15_i32 = arith.constant(15, type=i32) + c21_i32 = arith.constant(21, type=i32) + c23_i32 = arith.constant(23, type=i32) + c28_i32 = arith.constant(28, type=i32) + c31_i32 = arith.constant(31, type=i32) + c32_i32 = arith.constant(32, type=i32) + c64_i32 = arith.constant(64, type=i32) + c126_i32 = arith.constant(126, type=i32) + c127_i32 = arith.constant(127, type=i32) + c254_i32 = arith.constant(254, type=i32) + c256_i32 = arith.constant(256, type=i32) + c0xFF_i32 = arith.constant(0xFF, type=i32) + c0x200000_i32 = arith.constant(0x200000, type=i32) + c0xFF800000_i32 = arith.constant(0xFF800000, type=i32) + 
c0x400000_i32 = arith.constant(0x400000, type=i32) + c0x7FFFFF_i32 = arith.constant(0x7FFFFF, type=i32) + c0x80000000_i32 = arith.constant(0x80000000, type=i32) + c0_f32 = arith.constant(0.0, type=f32) + c1_f32 = arith.constant(1.0, type=f32) + + scale_cols_i32 = arith.constant(scale_cols, type=i32) + inter_dim_i32 = arith.constant(inter_dim, type=i32) + topk_i32 = arith.constant(topk, type=i32) + n32_sort = scale_cols_i32 * c32_i32 + + # Buffer resources + in_rsrc = buffer_ops.create_buffer_resource(x, max_size=True) + out_rsrc = buffer_ops.create_buffer_resource(out_fp4, max_size=True) + scale_rsrc = buffer_ops.create_buffer_resource(out_scale_sorted, max_size=True) + tid_rsrc = buffer_ops.create_buffer_resource(sorted_ids, max_size=True) + nv_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True) + + num_valid = buffer_ops.buffer_load(nv_rsrc, c0_i32, vec_width=1, dtype=i32) + token_num_i32 = ArithValue(token_num) + bid_i32 = ArithValue(bid) + + row_in_range = arith.cmpi(CmpIPredicate.ult, bid_i32, num_valid) + fused_tid_val = buffer_ops.buffer_load( + tid_rsrc, bid_i32, vec_width=1, dtype=i32 + ) + mask24 = arith.constant(0xFFFFFF, type=i32) + token_id = fused_tid_val & mask24 + slot_id = ArithValue(fused_tid_val) >> arith.constant(24, type=i32) + t_ok = arith.cmpi(CmpIPredicate.ult, token_id, token_num_i32) + s_ok = arith.cmpi(CmpIPredicate.ult, slot_id, topk_i32) + is_valid = arith.andi(row_in_range, arith.andi(t_ok, s_ok)) + + def _f32_to_e2m1(qx_f32): + """Convert a scaled f32 value to fp4 (e2m1) 4-bit integer.""" + qx = qx_f32.bitcast(i32) + s = qx & c0x80000000_i32 + e = (qx >> c23_i32) & c0xFF_i32 + m = qx & c0x7FFFFF_i32 + adj_exp = arith.maxsi(c126_i32 - e, c0_i32) + m_denorm = (c0x400000_i32 | (m >> c1_i32)) >> adj_exp + is_denorm = arith.cmpi(CmpIPredicate.ult, e, c127_i32) + m = arith.select(is_denorm, m_denorm, m) + e = arith.maxsi(e - c126_i32, c0_i32) + combined = (e << c2_i32) | (m >> c21_i32) + rounded = (combined + c1_i32) >> 
c1_i32 + e2m1 = arith.minui(rounded, c7_i32) + return (s >> c28_i32) | e2m1 + + thread_id = ArithValue(tid) + + COLS_PER_ITER = BLOCK_THREADS * VEC + + for iter_idx in range_constexpr( + (inter_dim + COLS_PER_ITER - 1) // COLS_PER_ITER + ): + col0 = thread_id * arith.constant(VEC, type=i32) + arith.constant( + iter_idx * COLS_PER_ITER, type=i32 + ) + + col_valid = arith.cmpi(CmpIPredicate.ult, col0, inter_dim_i32) + _if_col = scf.IfOp(col_valid) + with ir.InsertionPoint(_if_col.then_block): + + _if_valid = scf.IfOp(is_valid, has_else=True) + with ir.InsertionPoint(_if_valid.then_block): + in_row = token_id * topk_i32 + slot_id + # FP4 output in token order: row = token_id * topk + slot_id + out_row_byte_base = in_row * arith.constant( + inter_dim // 2, type=i32 + ) + fp4_byte_off = out_row_byte_base + (col0 >> c1_i32) + in_row_byte_base = in_row * arith.constant( + inter_dim * 2 * elem_bytes_bf16, type=i32 + ) + up_byte_offset = arith.constant( + inter_dim * elem_bytes_bf16, type=i32 + ) + + gate_byte = in_row_byte_base + col0 * arith.constant( + elem_bytes_bf16, type=i32 + ) + up_byte = gate_byte + up_byte_offset + gate_dw = gate_byte >> c2_i32 + up_dw = up_byte >> c2_i32 + vec_dw = VEC * elem_bytes_bf16 // 4 + + gate_raw = buffer_ops.buffer_load( + in_rsrc, gate_dw, vec_width=vec_dw, dtype=i32 + ) + up_raw = buffer_ops.buffer_load( + in_rsrc, up_dw, vec_width=vec_dw, dtype=i32 + ) + + vec_bf16_ty = T.vec(VEC, T.bf16) + vec_f32_ty = T.vec(VEC, f32) + if vec_dw == 1: + vec1_i32_ty = T.vec(1, i32) + gate_vec = vector.from_elements(vec1_i32_ty, [gate_raw]) + up_vec = vector.from_elements(vec1_i32_ty, [up_raw]) + gate_bf16 = vector.bitcast(vec_bf16_ty, gate_vec) + up_bf16 = vector.bitcast(vec_bf16_ty, up_vec) + else: + gate_bf16 = vector.bitcast(vec_bf16_ty, gate_raw) + up_bf16 = vector.bitcast(vec_bf16_ty, up_raw) + gate_f32 = gate_bf16.extf(vec_f32_ty) + up_f32 = up_bf16.extf(vec_f32_ty) + + neg_log2e = arith.constant(-1.4426950408889634, type=f32) + act_vals = [] + 
for vi in range_constexpr(VEC): + g = vector.extract( + gate_f32, static_position=[vi], dynamic_position=[] + ) + u = vector.extract( + up_f32, static_position=[vi], dynamic_position=[] + ) + t = g * neg_log2e + emu = llvm.call_intrinsic( + f32, "llvm.amdgcn.exp2.f32", [t], [], [] + ) + den = c1_f32 + emu + sig = llvm.call_intrinsic( + f32, "llvm.amdgcn.rcp.f32", [den], [], [] + ) + act_vals.append(g * sig * u) + + local_max = c0_f32 + for vi in range_constexpr(VEC): + abs_v = llvm.call_intrinsic( + f32, "llvm.fabs.f32", [act_vals[vi]], [], [] + ) + local_max = arith.maximumf(local_max, abs_v) + + for sh_dist in SHUFFLE_DISTS: + off = arith.constant(sh_dist, type=i32) + peer = local_max.shuffle_xor(off, c64_i32) + local_max = arith.maximumf(local_max, peer) + + max_i32_v = local_max.bitcast(i32) + max_rounded = (max_i32_v + c0x200000_i32) & c0xFF800000_i32 + exp_field = max_rounded >> c23_i32 + e8m0_biased = arith.maxsi(exp_field - c2_i32, c0_i32) + + quant_exp = c254_i32 - e8m0_biased + quant_scale = (quant_exp << c23_i32).bitcast(f32) + + fp4_vals = [] + for vi in range_constexpr(VEC): + scaled_v = act_vals[vi] * quant_scale + fp4_vals.append(_f32_to_e2m1(scaled_v)) + + packed_i32 = fp4_vals[0] | (fp4_vals[1] << c4_i32) + for k in range_constexpr(1, VEC // 2): + byte_k = fp4_vals[2 * k] | (fp4_vals[2 * k + 1] << c4_i32) + packed_i32 = packed_i32 | ( + byte_k << arith.constant(k * 8, type=i32) + ) + + _pack_bytes = VEC // 2 + if _pack_bytes == 1: + store_val = arith.TruncIOp(T.i8, packed_i32) + buffer_ops.buffer_store( + store_val, out_rsrc, fp4_byte_off, offset_is_bytes=True + ) + elif _pack_bytes == 2: + store_val = arith.TruncIOp(T.i16, packed_i32) + buffer_ops.buffer_store( + store_val, out_rsrc, fp4_byte_off, offset_is_bytes=True + ) + else: + buffer_ops.buffer_store( + packed_i32, out_rsrc, fp4_byte_off, offset_is_bytes=True + ) + + lane_in_blk = col0 & c31_i32 + _if_sw = scf.IfOp(arith.cmpi(CmpIPredicate.eq, lane_in_blk, c0_i32)) + with 
ir.InsertionPoint(_if_sw.then_block): + row_s = bid_i32 + col_s = col0 >> c5_i32 + d0 = row_s >> c5_i32 + d1 = (row_s >> c4_i32) & c1_i32 + d2 = row_s & c15_i32 + d3 = col_s >> c3_i32 + d4 = (col_s >> c2_i32) & c1_i32 + d5 = col_s & c3_i32 + s_byte_off = ( + d0 * n32_sort + + d3 * c256_i32 + + d5 * c64_i32 + + d2 * c4_i32 + + d4 * c2_i32 + + d1 + ) + e8m0_i8 = arith.TruncIOp(T.i8, e8m0_biased) + buffer_ops.buffer_store( + e8m0_i8, scale_rsrc, s_byte_off, offset_is_bytes=True + ) + scf.YieldOp([]) + scf.YieldOp([]) + + with ir.InsertionPoint(_if_valid.else_block): + # Padding row: skip FP4 write (stage2 gather-loads by token_id, + # so padding rows are never read). Only write zero scale. + lane_in_blk_p = col0 & c31_i32 + _if_sw_p = scf.IfOp( + arith.cmpi(CmpIPredicate.eq, lane_in_blk_p, c0_i32) + ) + with ir.InsertionPoint(_if_sw_p.then_block): + row_s_p = bid_i32 + col_s_p = col0 >> c5_i32 + d0_p = row_s_p >> c5_i32 + d1_p = (row_s_p >> c4_i32) & c1_i32 + d2_p = row_s_p & c15_i32 + d3_p = col_s_p >> c3_i32 + d4_p = (col_s_p >> c2_i32) & c1_i32 + d5_p = col_s_p & c3_i32 + s_byte_off_p = ( + d0_p * n32_sort + + d3_p * c256_i32 + + d5_p * c64_i32 + + d2_p * c4_i32 + + d4_p * c2_i32 + + d1_p + ) + c0_i8 = arith.TruncIOp(T.i8, c0_i32) + buffer_ops.buffer_store( + c0_i8, scale_rsrc, s_byte_off_p, offset_is_bytes=True + ) + scf.YieldOp([]) + scf.YieldOp([]) + scf.YieldOp([]) + + @flyc.jit + def launch_silu_and_mul_fq( + x: fx.Tensor, + out_fp4: fx.Tensor, + out_scale_sorted: fx.Tensor, + sorted_ids: fx.Tensor, + num_valid_ids: fx.Tensor, + token_num: fx.Int32, + num_sorted_rows: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + pass + + idx_rows = arith.index_cast(T.index, num_sorted_rows) + launcher = silu_and_mul_fq_kernel( + x, out_fp4, out_scale_sorted, sorted_ids, num_valid_ids, token_num + ) + launcher.launch( + grid=(idx_rows, 1, 1), + block=(BLOCK_THREADS, 1, 1), + 
stream=stream, + ) + + return launch_silu_and_mul_fq From c1bb97073ce6dbe8f5ab035f217bacd2585dc244 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Mon, 13 Apr 2026 10:13:27 +0000 Subject: [PATCH 2/3] Fix merge conflicts: restore PreshuffleBLayout and adapt MFMA API - Restore PreshuffleBLayout/make_preshuffle_b_layout and main's _unpack_int4_to_int8_pair/_pack_i32_pair_to_i64 that were lost during merge conflict resolution - Update moe_gemm_2stage.py mfma_fn calls to match new flydsl MFMA API: pass (res, a, b, c, cbsz, abid, blgp) as positional args instead of list, and access .result on the returned Operation Made-with: Cursor --- kernels/mfma_preshuffle_pipeline.py | 7 +------ kernels/moe_gemm_2stage.py | 28 ++++++++++++++-------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 28e228c0..1ec312e5 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -250,7 +250,6 @@ def _pack_i32_pair_to_i64(lo, hi, vector): return vector.extract(v64, static_position=[0], dynamic_position=[]) - def _i8x4_in_i32_to_bf16x4_i64(val_i32, arith, vector, scale_val=None): """Convert one i32 (4 signed int8 bytes) to 4 bf16 packed as i64. 
@@ -687,14 +686,10 @@ def lds_load_pack_k32( "make_preshuffle_b_layout", "make_preshuffle_scale_layout", "load_b_pack_k32", - "load_b_raw_w4a16", - "unpack_b_w4a16", - "load_b_raw_w4a16_groupwise", - "unpack_b_w4a16_groupwise", - "extract_bf16_scale", "split_row_major_2d", "swizzle_xor16", "tile_chunk_coord_i32", + "unpack_b_w4a16", ] diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index bce43ece..57dde7b1 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -841,23 +841,23 @@ def mfma_k64(acc_in, a0, a1, b0, b1): else: av = _i64x2_to_v8bf16(a0, a1) bv = _i64x2_to_v8bf16(b0, b1) - return mfma_fn(mfma_res_ty, [av, bv, acc_in, 0, 0, 0]) + return mfma_fn(mfma_res_ty, av, bv, acc_in, 0, 0, 0).result if is_f16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) b1v = _i64_to_v4f16(b1) - acc_mid = mfma_fn(mfma_res_ty, [a0v, b0v, acc_in, 0, 0, 0]) - return mfma_fn(mfma_res_ty, [a1v, b1v, acc_mid, 0, 0, 0]) + acc_mid = mfma_fn(mfma_res_ty, a0v, b0v, acc_in, 0, 0, 0).result + return mfma_fn(mfma_res_ty, a1v, b1v, acc_mid, 0, 0, 0).result if is_bf16: a0v = _i64_to_v4i16(a0) a1v = _i64_to_v4i16(a1) b0v = _i64_to_v4i16(b0) b1v = _i64_to_v4i16(b1) - acc_mid = mfma_fn(mfma_res_ty, [a0v, b0v, acc_in, 0, 0, 0]) - return mfma_fn(mfma_res_ty, [a1v, b1v, acc_mid, 0, 0, 0]) - acc_mid = mfma_fn(mfma_res_ty, [a0, b0, acc_in, 0, 0, 0]) - return mfma_fn(mfma_res_ty, [a1, b1, acc_mid, 0, 0, 0]) + acc_mid = mfma_fn(mfma_res_ty, a0v, b0v, acc_in, 0, 0, 0).result + return mfma_fn(mfma_res_ty, a1v, b1v, acc_mid, 0, 0, 0).result + acc_mid = mfma_fn(mfma_res_ty, a0, b0, acc_in, 0, 0, 0).result + return mfma_fn(mfma_res_ty, a1, b1, acc_mid, 0, 0, 0).result def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): """MFMA f32 partial -> scale -> add to f32 accumulator via math.fma on vector.""" @@ -2230,23 +2230,23 @@ def mfma_k64(acc0, a0, a1, b0, b1): else: av = _i64x2_to_v8bf16(a0, a1) bv = _i64x2_to_v8bf16(b0, b1) - return 
mfma_fn(mfma_res_ty, [av, bv, acc0, 0, 0, 0]) + return mfma_fn(mfma_res_ty, av, bv, acc0, 0, 0, 0).result if is_f16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) b1v = _i64_to_v4f16(b1) - acc1 = mfma_fn(mfma_res_ty, [a0v, b0v, acc0, 0, 0, 0]) - return mfma_fn(mfma_res_ty, [a1v, b1v, acc1, 0, 0, 0]) + acc1 = mfma_fn(mfma_res_ty, a0v, b0v, acc0, 0, 0, 0).result + return mfma_fn(mfma_res_ty, a1v, b1v, acc1, 0, 0, 0).result if is_bf16: a0v = _i64_to_v4i16(a0) a1v = _i64_to_v4i16(a1) b0v = _i64_to_v4i16(b0) b1v = _i64_to_v4i16(b1) - acc1 = mfma_fn(mfma_res_ty, [a0v, b0v, acc0, 0, 0, 0]) - return mfma_fn(mfma_res_ty, [a1v, b1v, acc1, 0, 0, 0]) - acc1 = mfma_fn(mfma_res_ty, [a0, b0, acc0, 0, 0, 0]) - return mfma_fn(mfma_res_ty, [a1, b1, acc1, 0, 0, 0]) + acc1 = mfma_fn(mfma_res_ty, a0v, b0v, acc0, 0, 0, 0).result + return mfma_fn(mfma_res_ty, a1v, b1v, acc1, 0, 0, 0).result + acc1 = mfma_fn(mfma_res_ty, a0, b0, acc0, 0, 0, 0).result + return mfma_fn(mfma_res_ty, a1, b1, acc1, 0, 0, 0).result def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): """MFMA f32 partial -> scale -> add to f32 accumulator via math.fma on vector.""" From b9aab3465f424fb5aa8817165ed70e0e6872b9d8 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Mon, 13 Apr 2026 12:50:19 +0000 Subject: [PATCH 3/3] Fix MFMA API compat: wrap new/old rocdl ops with unified mfma_fn New flydsl MFMA ops (mfma_f32_16x16x32_f16/bf16) use positional args and return Operation (need .result), while legacy ops (fp8, int8, k16) still use the old (res_type, [list]) calling convention. Introduce a thin mfma_fn wrapper at both call sites that dispatches correctly based on _use_mfma_k32, keeping all 14 call sites in list format: mfma_fn(res_ty, [a, b, c, 0, 0, 0]). 
Made-with: Cursor --- kernels/moe_gemm_2stage.py | 44 ++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 57dde7b1..73f5c5b2 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -781,9 +781,11 @@ def compute_tile( up_list = list(acc_up_in) mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 if _use_mfma_k32: - mfma_fn = rocdl.mfma_f32_16x16x32_f16 if is_f16 else rocdl.mfma_f32_16x16x32_bf16 + _raw_mfma = rocdl.mfma_f32_16x16x32_f16 if is_f16 else rocdl.mfma_f32_16x16x32_bf16 + def mfma_fn(res_ty, args): + return _raw_mfma(res_ty, *args).result else: - mfma_fn = ( + _raw_mfma = ( mfma_i32_k32 if is_int8 else ( @@ -792,6 +794,8 @@ def compute_tile( else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) ) ) + def mfma_fn(res_ty, args): + return _raw_mfma(res_ty, args) # Optional: prefetch epilogue scales while we are about to run the last MFMA tile, # matching the preshuffle GEMM pattern of overlapping scale loads with MFMA. 
@@ -841,23 +845,23 @@ def mfma_k64(acc_in, a0, a1, b0, b1): else: av = _i64x2_to_v8bf16(a0, a1) bv = _i64x2_to_v8bf16(b0, b1) - return mfma_fn(mfma_res_ty, av, bv, acc_in, 0, 0, 0).result + return mfma_fn(mfma_res_ty, [av, bv, acc_in, 0, 0, 0]) if is_f16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) b1v = _i64_to_v4f16(b1) - acc_mid = mfma_fn(mfma_res_ty, a0v, b0v, acc_in, 0, 0, 0).result - return mfma_fn(mfma_res_ty, a1v, b1v, acc_mid, 0, 0, 0).result + acc_mid = mfma_fn(mfma_res_ty, [a0v, b0v, acc_in, 0, 0, 0]) + return mfma_fn(mfma_res_ty, [a1v, b1v, acc_mid, 0, 0, 0]) if is_bf16: a0v = _i64_to_v4i16(a0) a1v = _i64_to_v4i16(a1) b0v = _i64_to_v4i16(b0) b1v = _i64_to_v4i16(b1) - acc_mid = mfma_fn(mfma_res_ty, a0v, b0v, acc_in, 0, 0, 0).result - return mfma_fn(mfma_res_ty, a1v, b1v, acc_mid, 0, 0, 0).result - acc_mid = mfma_fn(mfma_res_ty, a0, b0, acc_in, 0, 0, 0).result - return mfma_fn(mfma_res_ty, a1, b1, acc_mid, 0, 0, 0).result + acc_mid = mfma_fn(mfma_res_ty, [a0v, b0v, acc_in, 0, 0, 0]) + return mfma_fn(mfma_res_ty, [a1v, b1v, acc_mid, 0, 0, 0]) + acc_mid = mfma_fn(mfma_res_ty, [a0, b0, acc_in, 0, 0, 0]) + return mfma_fn(mfma_res_ty, [a1, b1, acc_mid, 0, 0, 0]) def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): """MFMA f32 partial -> scale -> add to f32 accumulator via math.fma on vector.""" @@ -2162,9 +2166,11 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False acc_list = list(acc_in) mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 if _use_mfma_k32: - mfma_fn = rocdl.mfma_f32_16x16x32_f16 if is_f16 else rocdl.mfma_f32_16x16x32_bf16 + _raw_mfma = rocdl.mfma_f32_16x16x32_f16 if is_f16 else rocdl.mfma_f32_16x16x32_bf16 + def mfma_fn(res_ty, args): + return _raw_mfma(res_ty, *args).result else: - mfma_fn = ( + _raw_mfma = ( mfma_i32_k32 if is_int8 else ( @@ -2173,6 +2179,8 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False else (rocdl.mfma_f32_16x16x16f16 if 
is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) ) ) + def mfma_fn(res_ty, args): + return _raw_mfma(res_ty, args) epilogue_pf = None if prefetch_epilogue and not use_groupwise_scale: @@ -2230,23 +2238,23 @@ def mfma_k64(acc0, a0, a1, b0, b1): else: av = _i64x2_to_v8bf16(a0, a1) bv = _i64x2_to_v8bf16(b0, b1) - return mfma_fn(mfma_res_ty, av, bv, acc0, 0, 0, 0).result + return mfma_fn(mfma_res_ty, [av, bv, acc0, 0, 0, 0]) if is_f16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) b1v = _i64_to_v4f16(b1) - acc1 = mfma_fn(mfma_res_ty, a0v, b0v, acc0, 0, 0, 0).result - return mfma_fn(mfma_res_ty, a1v, b1v, acc1, 0, 0, 0).result + acc1 = mfma_fn(mfma_res_ty, [a0v, b0v, acc0, 0, 0, 0]) + return mfma_fn(mfma_res_ty, [a1v, b1v, acc1, 0, 0, 0]) if is_bf16: a0v = _i64_to_v4i16(a0) a1v = _i64_to_v4i16(a1) b0v = _i64_to_v4i16(b0) b1v = _i64_to_v4i16(b1) - acc1 = mfma_fn(mfma_res_ty, a0v, b0v, acc0, 0, 0, 0).result - return mfma_fn(mfma_res_ty, a1v, b1v, acc1, 0, 0, 0).result - acc1 = mfma_fn(mfma_res_ty, a0, b0, acc0, 0, 0, 0).result - return mfma_fn(mfma_res_ty, a1, b1, acc1, 0, 0, 0).result + acc1 = mfma_fn(mfma_res_ty, [a0v, b0v, acc0, 0, 0, 0]) + return mfma_fn(mfma_res_ty, [a1v, b1v, acc1, 0, 0, 0]) + acc1 = mfma_fn(mfma_res_ty, [a0, b0, acc0, 0, 0, 0]) + return mfma_fn(mfma_res_ty, [a1, b1, acc1, 0, 0, 0]) def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): """MFMA f32 partial -> scale -> add to f32 accumulator via math.fma on vector."""