Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 214 additions & 16 deletions kernels/moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def compile_moe_gemm1(
out_dtype: str = "f16",
use_cshuffle_epilog: bool | None = None,
scale_is_bf16: bool = False,
k_batch: int = 1,
):
"""Compile stage1 kernel (`moe_gemm1`) and return the compiled executable.

Expand All @@ -123,6 +124,8 @@ def compile_moe_gemm1(
- "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel
- "int4_bf16": W4A16 path: X is bf16, W is packed int4 unpacked to bf16 in-kernel
scale_is_bf16: When True, groupwise scales are bf16 (halves scale bandwidth).
k_batch: Split-K factor. When >1, K (model_dim) is partitioned across k_batch CTAs
(launched along grid z) that atomically accumulate f32 gate/up partial sums into an
output laid out as [tokens*topk, 2*inter]. Caller must pre-zero output.
"""

gpu_arch = get_hip_arch()
Expand Down Expand Up @@ -185,6 +188,21 @@ def compile_moe_gemm1(
_is_gfx950 = "gfx95" in get_hip_arch()
use_gfx950_cvt = is_int4_bf16 and _is_gfx950

# Split-K validation
_is_splitk = k_batch > 1
if _is_splitk:
_k_per_batch = model_dim // k_batch
assert model_dim % k_batch == 0, f"model_dim={model_dim} not divisible by k_batch={k_batch}"
assert _k_per_batch % tile_k == 0, f"K_per_batch={_k_per_batch} not divisible by tile_k={tile_k}"
# The ping-pong K-loop requires an even number of K tiles (>=4).
_k_tiles = _k_per_batch // tile_k
assert _k_tiles >= 4 and _k_tiles % 2 == 0, (
f"K_per_batch/tile_k={_k_tiles} must be even and >=4 for the ping-pong pipeline. "
f"Try a different k_batch (model_dim={model_dim}, tile_k={tile_k})."
)
else:
_k_per_batch = model_dim

mfma_i32_k32 = None
if is_int8:
mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr(
Expand Down Expand Up @@ -239,18 +257,20 @@ def compile_moe_gemm1(
if use_cshuffle_epilog is None:
use_cshuffle_epilog = os.environ.get("FLYDSL_MOE_STAGE1_CSHUFFLE", "1") in ("1", "true", "True", "YES", "yes")
use_cshuffle_epilog = bool(use_cshuffle_epilog)
if out_dtype != "f16" and use_cshuffle_epilog:
# Split-K uses f32 atomic CShuffle regardless of out_dtype, so skip this check.
if out_dtype != "f16" and use_cshuffle_epilog and not _is_splitk:
raise ValueError("stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')")

epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct"
# IMPORTANT: module name participates in FlyDSL's compile cache key.
# Keep an explicit ABI tag so signature changes can't accidentally reuse an old binary.
_gs_tag = f"_g{group_size}" if use_groupwise_scale else ""
scale_tag = "_sbf16" if _scale_is_bf16 else ""
_split_k_tag = f"_splitk{k_batch}" if _is_splitk else ""
module_name = (
f"mfma_moe1_{in_dtype}_{out_dtype}_{epilog_tag}"
f"_t{tile_m}x{tile_n}x{tile_k}"
f"{_gs_tag}{scale_tag}"
f"{_gs_tag}{scale_tag}{_split_k_tag}"
f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults
).replace("-", "_")

Expand All @@ -259,8 +279,12 @@ def compile_moe_gemm1(
# - ping-pong X tiles (2 * tile_m * lds_stride bytes)
# - optional epilogue CShuffle tile (tile_m * tile_n elements; f16 -> 2 bytes each, f32 for split-K -> 4 bytes each)
_use_cshuffle_epilog = bool(use_cshuffle_epilog)
# Split-K requires CShuffle epilogue (f32 atomic adds via store_pair callback)
if _is_splitk:
_use_cshuffle_epilog = True
_cshuffle_elem_bytes = 4 if _is_splitk else 2 # f32 for split-K, f16 otherwise
lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(elem_bytes)
lds_out_bytes = 2 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0
lds_out_bytes = _cshuffle_elem_bytes * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0
lds_total_bytes = max(lds_x_bytes, lds_out_bytes)
lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2)

Expand Down Expand Up @@ -350,6 +374,12 @@ def silu(x):
by = gpu.block_id("x") # tile along inter_dim
bx = gpu.block_id("y") # tile along sorted M

if _is_splitk:
bz = gpu.block_id("z") # K-batch id
k_base_idx = bz * arith.index(_k_per_batch)
else:
k_base_idx = arith.index(0)

# Block validity: compute as early as possible so invalid blocks skip all buffer-resource
# setup, LDS pointer math, and gmem prefetch work.
bx_m = bx * fx.Index(tile_m)
Expand Down Expand Up @@ -381,9 +411,11 @@ def silu(x):
shape=(lds_total_elems,),
)
lds_x = lds_x_ptr.get()
# Alias LDS bytes as fp16 for optional CShuffle epilogue.
# Alias LDS bytes for optional CShuffle epilogue.
# Split-K uses f32 (4B) per element for atomic accumulation; normal uses f16 (2B).
_lds_out_elem_type = T.f32 if _is_splitk else T.f16
lds_out = (
SmemPtr(base_ptr, lds_x_ptr.byte_offset, T.f16, shape=(tile_m * tile_n,)).get()
SmemPtr(base_ptr, lds_x_ptr.byte_offset, _lds_out_elem_type, shape=(tile_m * tile_n,)).get()
if _use_cshuffle_epilog
else None
)
Expand All @@ -401,9 +433,12 @@ def silu(x):

w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False)

# OUT: [tokens, topk, inter] f16/bf16 -> bytes = tokens*topk*inter*out_elem_bytes
out_elem_bytes = 2 # f16/bf16
out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(out_elem_bytes)
# OUT: normal=[tokens, topk, inter] f16/bf16, split-K=[tokens*topk, 2*inter] f32
out_elem_bytes = 4 if _is_splitk else 2
if _is_splitk:
out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(2 * out_elem_bytes)
else:
out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(out_elem_bytes)
out_rsrc = buffer_ops.create_buffer_resource(
arg_out, max_size=False, num_records_bytes=out_nbytes_idx
)
Expand Down Expand Up @@ -992,26 +1027,26 @@ def hot_loop_scheduler():
rocdl.sched_barrier(0)

# Prologue: prefetch tile0, store to LDS(cur), sync.
k0 = fx.Index(0)
k0 = k_base_idx
x_regs0 = load_x_tile(k0)
b_gate_cur = load_b_tile(k0, n_blk_gate, n_intra_gate)
b_up_cur = load_b_tile(k0, n_blk_up, n_intra_up)
store_x_tile_to_lds(x_regs0, lds_base_cur)
gpu.barrier()

# Loop-carried ping/pong state.
lds_base_pong = lds_base_cur # current/compute
lds_base_ping = lds_base_nxt # next/load+store

# Cross-tile A0 LDS prefetch (default-on): prefetch the first A-pack (K64) for the
# tile we are about to compute from LDS, to overlap with upcoming VMEM.
a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong)

# Ping-pong main loop (2 tiles per iteration), leaving 2 tail tiles.
# Uses scf.for with loop-carried accumulators, B-tile prefetch, and A0 LDS prefetch.
c2_tile_k = arith.index(tile_k * 2)
c_tile_k = arith.index(tile_k)
total_tiles = int(model_dim) // int(tile_k)
total_tiles = int(_k_per_batch) // int(tile_k)
pair_iters = max((total_tiles - 2) // 2, 0)

# B-tile data layout per k_unroll entry (3 variants):
Expand Down Expand Up @@ -1091,7 +1126,7 @@ def _unflatten_b_tile(vals):
_bu = _unflatten_b_tile(list(state[_p_bu:_p_a0]))
_a0pf = (state[_p_a0], state[_p_a0 + 1])

k_iv = pair_iv * (c_tile_k + c_tile_k)
k_iv = k_base_idx + pair_iv * (c_tile_k + c_tile_k)

# ---- stage 0: prefetch+store ping, compute pong ----
next_k1 = k_iv + c_tile_k
Expand Down Expand Up @@ -1133,7 +1168,7 @@ def _unflatten_b_tile(vals):
b_gate_cur = _unflatten_b_tile(list(loop_results[_p_bg:_p_bu]))
b_up_cur = _unflatten_b_tile(list(loop_results[_p_bu:_p_a0]))
a0_prefetch_pong = (loop_results[_p_a0], loop_results[_p_a0 + 1])
k_tail1 = k_in - tile_k
k_tail1 = k_base_idx + arith.index(_k_per_batch - tile_k)
x_regs_ping = load_x_tile(k_tail1)
b_gate_ping = load_b_tile(k_tail1, n_blk_gate, n_intra_gate)
b_up_ping = load_b_tile(k_tail1, n_blk_up, n_intra_up)
Expand Down Expand Up @@ -1214,6 +1249,169 @@ def _unflatten_b_tile(vals):
# Uses EVec=4 (buffer store "x4" of fp16 elements).
use_cshuffle_epilog_flag = _use_cshuffle_epilog

# ─── Split-K epilogue: two-pass gate/up with f32 atomic fadd ───
if _is_splitk:
if lds_out is None:
raise RuntimeError("Split-K epilogue requires lds_out (CShuffle)")

out_base_idx = buffer_ops.extract_base_index(arg_out)
_split_k_out_row_stride = inter_dim * 2 * out_elem_bytes # bytes per row
_split_k_e_vec = 2 # f32 vec2 for atomic fadd

# Mutable slot: 0 for gate pass, inter_dim for up pass
_split_k_n_offset = [0]

# Mutable slots for two-pass gate/up selection
_split_k_acc = [acc_gate]
_split_k_sw_vals = [sw_gate_vals]

def write_row_to_lds_splitk(
    *,
    mi: int,
    ii: int,
    row_in_tile,
    row,
    row_base_lds,
    col_base_local,
    num_acc_n: int,
    lds_out,
):
    """Stage one accumulator row of the active split-K pass into LDS as f32.

    The accumulators and per-column weight scales are read from the mutable
    slots ``_split_k_acc[0]`` / ``_split_k_sw_vals[0]``, which the caller
    re-points between the gate pass and the up pass. Each value is multiplied
    by the row's activation scale ``sx`` and the column's weight scale and
    stored as a single f32 element; no silu and no routing-weight multiply
    is applied here.

    Args:
        mi: Compile-time M-repeat index selecting the accumulator group.
        ii: Compile-time element position inside each accumulator vector.
        row_in_tile: Unused by this callback.
        row: Index into the sorted-token buffer for this row.
        row_base_lds: LDS element offset of the start of this row.
        col_base_local: Tile-local column of this lane's first element.
        num_acc_n: Number of accumulator columns handled per lane.
        lds_out: LDS view that receives the f32 partial sums.
    """
    # Resolve which pass (gate or up) is currently being written.
    pass_acc = _split_k_acc[0]
    pass_w_scale = _split_k_sw_vals[0]

    # Packed sorted entry: token id in the masked low bits, slot in bits 24+.
    packed = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=T.i32)
    token_id = packed & mask24_i32
    token_in_range = arith.cmpi(arith.CmpIPredicate.ult, token_id, tokens_i32_v)

    # scale_x is indexed per (slot, token) or per token depending on the X layout.
    if x_is_token_slot:
        slot_id = packed >> 24
        sx_index = slot_id * tokens_i32_v + token_id
    else:
        sx_index = token_id

    # f16/bf16 activations use a unit scale; otherwise load it, forcing 0.0
    # for out-of-range (sentinel) token ids.
    if is_f16_or_bf16:
        sx = fx.Float32(1.0)
    else:
        sx = arith.select(
            token_in_range,
            buffer_ops.buffer_load(sx_rsrc, sx_index, vec_width=1, dtype=T.f32),
            fx.Float32(0.0),
        )

    for ni in range_constexpr(num_acc_n):
        # Consecutive accumulator columns sit 16 elements apart in the tile.
        col_local = col_base_local + (ni * 16)
        partial = vector.extract(
            pass_acc[mi * num_acc_n + ni], static_position=[ii], dynamic_position=[]
        )
        if is_int8:
            # Integer accumulators are converted before scaling.
            partial = arith.sitofp(T.f32, partial)
        partial = partial * sx * pass_w_scale[ni]
        lds_slot = row_base_lds + col_local
        packed_elem = vector.from_elements(T.vec(1, T.f32), [partial])
        vector.store(packed_elem, lds_out, [lds_slot], alignment=4)

def precompute_row_splitk(*, row_local, row):
    """Compute the output byte base and validity flag for one sorted row.

    Decodes the packed sorted entry into a token id (masked low bits) and a
    slot (bits 24 and up), forms the output row ``token * topk + slot`` and
    converts it to a byte address using the split-K row stride.

    Args:
        row_local: Unused by this callback.
        row: Index into the sorted-token buffer for this row.

    Returns:
        A ``(row_byte_base, token_ok)`` tuple: the byte address of the start
        of the output row, and whether the token id is below the token count.
    """
    packed = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=T.i32)
    token_i32 = packed & mask24_i32
    slot_i32 = packed >> 24
    token_ok = arith.cmpi(arith.CmpIPredicate.ult, token_i32, tokens_i32_v)

    token_idx = arith.index_cast(T.index, token_i32)
    slot_idx = arith.index_cast(T.index, slot_i32)
    out_row = token_idx * arith.index(topk) + slot_idx

    row_stride_bytes = arith.index(_split_k_out_row_stride)
    return (out_base_idx + out_row * row_stride_bytes, token_ok)

def store_pair_splitk(*, row_local, row, row_ctx, col_pair0, col_g0, frag):
    """Atomically add one fragment into the split-K output buffer.

    The destination address is the row byte base supplied in ``row_ctx`` plus
    the byte offset of column ``col_g0`` shifted by the active pass offset
    ``_split_k_n_offset[0]`` (0 for gate, ``inter_dim`` for up).

    Args:
        row_local: Unused by this callback.
        row: Unused by this callback.
        row_ctx: Byte address of the start of the output row.
        col_pair0: Unused by this callback.
        col_g0: Column index of the fragment's first element within the pass.
        frag: Value to accumulate with an atomic floating-point add.
    """
    out_col = col_g0 + arith.index(_split_k_n_offset[0])
    dst_addr = row_ctx + out_col * arith.index(out_elem_bytes)
    dst_ptr = buffer_ops.create_llvm_ptr(dst_addr, address_space=1)

    # Unwrap DSL wrappers to the underlying values when they expose `_value`.
    dst_raw = getattr(dst_ptr, "_value", dst_ptr)
    frag_raw = getattr(frag, "_value", frag)

    llvm.AtomicRMWOp(
        llvm.AtomicBinOp.fadd,
        dst_raw,
        frag_raw,
        llvm.AtomicOrdering.monotonic,
        syncscope="agent",
        alignment=_split_k_e_vec * out_elem_bytes,
    )

_cshuffle_nlane_splitk = min(32, tile_n // _split_k_e_vec)
_splitk_frag_elem = ir.F32Type.get()

# Pass 1: gate (offset=0)
_split_k_acc[0] = acc_gate
_split_k_sw_vals[0] = sw_gate_vals
_split_k_n_offset[0] = 0
c_shuffle_epilog(
arith=arith,
vector=vector,
gpu=gpu,
scf=scf,
range_constexpr=range_constexpr,
tile_m=tile_m,
tile_n=tile_n,
e_vec=_split_k_e_vec,
cshuffle_nlane=_cshuffle_nlane_splitk,
block_size=total_threads,
m_repeat=m_repeat,
num_acc_n=num_acc_n,
tx=tx,
lane_div_16=lane_div_16,
lane_mod_16=lane_mod_16,
bx_m=bx_m,
by_n=by_n,
n_tile_base=n_tile_base,
lds_out=lds_out,
frag_elem_type=_splitk_frag_elem,
write_row_to_lds=write_row_to_lds_splitk,
precompute_row=precompute_row_splitk,
store_pair=store_pair_splitk,
)

gpu.barrier()

# Pass 2: up (offset=inter_dim)
_split_k_acc[0] = acc_up
_split_k_sw_vals[0] = sw_up_vals
_split_k_n_offset[0] = inter_dim
c_shuffle_epilog(
arith=arith,
vector=vector,
gpu=gpu,
scf=scf,
range_constexpr=range_constexpr,
tile_m=tile_m,
tile_n=tile_n,
e_vec=_split_k_e_vec,
cshuffle_nlane=_cshuffle_nlane_splitk,
block_size=total_threads,
m_repeat=m_repeat,
num_acc_n=num_acc_n,
tx=tx,
lane_div_16=lane_div_16,
lane_mod_16=lane_mod_16,
bx_m=bx_m,
by_n=by_n,
n_tile_base=n_tile_base,
lds_out=lds_out,
frag_elem_type=_splitk_frag_elem,
write_row_to_lds=write_row_to_lds_splitk,
precompute_row=precompute_row_splitk,
store_pair=store_pair_splitk,
)
return

if use_cshuffle_epilog_flag:
if lds_out is None:
raise RuntimeError("CShuffle epilogue enabled but lds_out is not allocated/aliased.")
Expand Down Expand Up @@ -1463,7 +1661,7 @@ def launch_moe_gemm1(
i32_k_in,
i32_size_expert_ids_in,
).launch(
grid=(gx, gy, 1),
grid=(gx, gy, k_batch),
block=(256, 1, 1),
stream=stream,
)
Expand Down
Loading
Loading