From 792d17ff4c19bdde67aeb3ae68539fbe1e659d12 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 13 Apr 2026 07:50:06 +0000 Subject: [PATCH] support split-k algo for moe_gemm_2stage --- kernels/moe_gemm_2stage.py | 230 ++++++++++++++++++++++++++++++--- tests/kernels/test_moe_gemm.py | 22 +++- 2 files changed, 233 insertions(+), 19 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index bce43ece..06eddb24 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -109,6 +109,7 @@ def compile_moe_gemm1( out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, scale_is_bf16: bool = False, + k_batch: int = 1, ): """Compile stage1 kernel (`moe_gemm1`) and return the compiled executable. @@ -123,6 +124,8 @@ def compile_moe_gemm1( - "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel - "int4_bf16": W4A16 path: X is bf16, W is packed int4 unpacked to bf16 in-kernel scale_is_bf16: When True, groupwise scales are bf16 (halves scale bandwidth). + k_batch: Split-K factor. When >1, K is partitioned across k_batch CTAs that + atomically accumulate gate/up partials. Caller must pre-zero output. """ gpu_arch = get_hip_arch() @@ -185,6 +188,21 @@ def compile_moe_gemm1( _is_gfx950 = "gfx95" in get_hip_arch() use_gfx950_cvt = is_int4_bf16 and _is_gfx950 + # Split-K validation + _is_splitk = k_batch > 1 + if _is_splitk: + _k_per_batch = model_dim // k_batch + assert model_dim % k_batch == 0, f"model_dim={model_dim} not divisible by k_batch={k_batch}" + assert _k_per_batch % tile_k == 0, f"K_per_batch={_k_per_batch} not divisible by tile_k={tile_k}" + # The ping-pong K-loop requires an even number of K tiles (>=4). + _k_tiles = _k_per_batch // tile_k + assert _k_tiles >= 4 and _k_tiles % 2 == 0, ( + f"K_per_batch/tile_k={_k_tiles} must be even and >=4 for the ping-pong pipeline. " + f"Try a different k_batch (model_dim={model_dim}, tile_k={tile_k})." + ) + else: + _k_per_batch = model_dim + mfma_i32_k32 = None if is_int8: mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( @@ -239,7 +257,8 @@ def compile_moe_gemm1( if use_cshuffle_epilog is None: use_cshuffle_epilog = os.environ.get("FLYDSL_MOE_STAGE1_CSHUFFLE", "1") in ("1", "true", "True", "YES", "yes") use_cshuffle_epilog = bool(use_cshuffle_epilog) - if out_dtype != "f16" and use_cshuffle_epilog: + # Split-K uses f32 atomic CShuffle regardless of out_dtype, so skip this check. + if out_dtype != "f16" and use_cshuffle_epilog and not _is_splitk: raise ValueError("stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')") epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" @@ -247,10 +266,11 @@ def compile_moe_gemm1( # Keep an explicit ABI tag so signature changes can't accidentally reuse an old binary. _gs_tag = f"_g{group_size}" if use_groupwise_scale else "" scale_tag = "_sbf16" if _scale_is_bf16 else "" + _split_k_tag = f"_splitk{k_batch}" if _is_splitk else "" module_name = ( f"mfma_moe1_{in_dtype}_{out_dtype}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" - f"{_gs_tag}{scale_tag}" + f"{_gs_tag}{scale_tag}{_split_k_tag}" f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults ).replace("-", "_") @@ -259,8 +279,12 @@ def compile_moe_gemm1( # - ping-pong X tiles (2 * tile_m * lds_stride bytes) # - optional epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) _use_cshuffle_epilog = bool(use_cshuffle_epilog) + # Split-K requires CShuffle epilogue (f32 atomic adds via store_pair callback) + if _is_splitk: + _use_cshuffle_epilog = True + _cshuffle_elem_bytes = 4 if _is_splitk else 2 # f32 for split-K, f16 otherwise lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(elem_bytes) - lds_out_bytes = 2 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 + lds_out_bytes = _cshuffle_elem_bytes * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 lds_total_bytes = max(lds_x_bytes, lds_out_bytes) lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) @@ -350,6 +374,12 @@ def silu(x): by = gpu.block_id("x") # tile along inter_dim bx = gpu.block_id("y") # tile along sorted M + if _is_splitk: + bz = gpu.block_id("z") # K-batch id + k_base_idx = bz * arith.index(_k_per_batch) + else: + k_base_idx = arith.index(0) + # Block validity: compute as early as possible so invalid blocks skip all buffer-resource # setup, LDS pointer math, and gmem prefetch work. bx_m = bx * fx.Index(tile_m) @@ -381,9 +411,11 @@ def silu(x): shape=(lds_total_elems,), ) lds_x = lds_x_ptr.get() - # Alias LDS bytes as fp16 for optional CShuffle epilogue. + # Alias LDS bytes for optional CShuffle epilogue. + # Split-K uses f32 (4B) per element for atomic accumulation; normal uses f16 (2B). + _lds_out_elem_type = T.f32 if _is_splitk else T.f16 lds_out = ( - SmemPtr(base_ptr, lds_x_ptr.byte_offset, T.f16, shape=(tile_m * tile_n,)).get() + SmemPtr(base_ptr, lds_x_ptr.byte_offset, _lds_out_elem_type, shape=(tile_m * tile_n,)).get() if _use_cshuffle_epilog else None ) @@ -401,9 +433,12 @@ def silu(x): w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) - # OUT: [tokens, topk, inter] f16/bf16 -> bytes = tokens*topk*inter*out_elem_bytes - out_elem_bytes = 2 # f16/bf16 - out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(out_elem_bytes) + # OUT: normal=[tokens, topk, inter] f16/bf16, split-K=[tokens*topk, 2*inter] f32 + out_elem_bytes = 4 if _is_splitk else 2 + if _is_splitk: + out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(2 * out_elem_bytes) + else: + out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(out_elem_bytes) out_rsrc = buffer_ops.create_buffer_resource( arg_out, max_size=False, num_records_bytes=out_nbytes_idx ) @@ -992,26 +1027,26 @@ def hot_loop_scheduler(): rocdl.sched_barrier(0) # Prologue: prefetch tile0, store to LDS(cur), sync. - k0 = fx.Index(0) + k0 = k_base_idx x_regs0 = load_x_tile(k0) b_gate_cur = load_b_tile(k0, n_blk_gate, n_intra_gate) b_up_cur = load_b_tile(k0, n_blk_up, n_intra_up) store_x_tile_to_lds(x_regs0, lds_base_cur) gpu.barrier() - + # Loop-carried ping/pong state. lds_base_pong = lds_base_cur # current/compute lds_base_ping = lds_base_nxt # next/load+store - + # Cross-tile A0 LDS prefetch (default-on): prefetch the first A-pack (K64) for the # tile we are about to compute from LDS, to overlap with upcoming VMEM. a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) - + # Ping-pong main loop (2 tiles per iteration), leaving 2 tail tiles. # Uses scf.for with loop-carried accumulators, B-tile prefetch, and A0 LDS prefetch. c2_tile_k = arith.index(tile_k * 2) c_tile_k = arith.index(tile_k) - total_tiles = int(model_dim) // int(tile_k) + total_tiles = int(_k_per_batch) // int(tile_k) pair_iters = max((total_tiles - 2) // 2, 0) # B-tile data layout per k_unroll entry (3 variants): @@ -1091,7 +1126,7 @@ def _unflatten_b_tile(vals): _bu = _unflatten_b_tile(list(state[_p_bu:_p_a0])) _a0pf = (state[_p_a0], state[_p_a0 + 1]) - k_iv = pair_iv * (c_tile_k + c_tile_k) + k_iv = k_base_idx + pair_iv * (c_tile_k + c_tile_k) # ---- stage 0: prefetch+store ping, compute pong ---- next_k1 = k_iv + c_tile_k @@ -1133,7 +1168,7 @@ def _unflatten_b_tile(vals): b_gate_cur = _unflatten_b_tile(list(loop_results[_p_bg:_p_bu])) b_up_cur = _unflatten_b_tile(list(loop_results[_p_bu:_p_a0])) a0_prefetch_pong = (loop_results[_p_a0], loop_results[_p_a0 + 1]) - k_tail1 = k_in - tile_k + k_tail1 = k_base_idx + arith.index(_k_per_batch - tile_k) x_regs_ping = load_x_tile(k_tail1) b_gate_ping = load_b_tile(k_tail1, n_blk_gate, n_intra_gate) b_up_ping = load_b_tile(k_tail1, n_blk_up, n_intra_up) @@ -1214,6 +1249,169 @@ def _unflatten_b_tile(vals): # Uses EVec=4 (buffer store "x4" of fp16 elements). use_cshuffle_epilog_flag = _use_cshuffle_epilog + # ─── Split-K epilogue: two-pass gate/up with f32 atomic fadd ─── + if _is_splitk: + if lds_out is None: + raise RuntimeError("Split-K epilogue requires lds_out (CShuffle)") + + out_base_idx = buffer_ops.extract_base_index(arg_out) + _split_k_out_row_stride = inter_dim * 2 * out_elem_bytes # bytes per row + _split_k_e_vec = 2 # f32 vec2 for atomic fadd + + # Mutable slot: 0 for gate pass, inter_dim for up pass + _split_k_n_offset = [0] + + # Mutable slots for two-pass gate/up selection + _split_k_acc = [acc_gate] + _split_k_sw_vals = [sw_gate_vals] + + def write_row_to_lds_splitk( + *, + mi: int, + ii: int, + row_in_tile, + row, + row_base_lds, + col_base_local, + num_acc_n: int, + lds_out, + ): + """Write scaled f32 partial sums to LDS (no silu, no doweight).""" + _acc = _split_k_acc[0] + _sw = _split_k_sw_vals[0] + # Load per-row scale_x (sx) — same logic as normal epilogue. + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=T.i32) + t2 = fused2 & mask24_i32 + t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32_v) + if x_is_token_slot: + s2 = fused2 >> 24 + ts2 = s2 * tokens_i32_v + t2 + sx = ( + fx.Float32(1.0) + if is_f16_or_bf16 + else arith.select( + t_valid, + buffer_ops.buffer_load(sx_rsrc, ts2, vec_width=1, dtype=T.f32), + fx.Float32(0.0), + ) + ) + else: + sx = ( + fx.Float32(1.0) + if is_f16_or_bf16 + else arith.select( + t_valid, + buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=T.f32), + fx.Float32(0.0), + ) + ) + for ni in range_constexpr(num_acc_n): + col_local = col_base_local + (ni * 16) + acc_idx = mi * num_acc_n + ni + v = vector.extract( + _acc[acc_idx], static_position=[ii], dynamic_position=[] + ) + if is_int8: + v = arith.sitofp(T.f32, v) + v = v * sx * _sw[ni] + lds_idx = row_base_lds + col_local + v1 = vector.from_elements(T.vec(1, T.f32), [v]) + vector.store(v1, lds_out, [lds_idx], alignment=4) + + def precompute_row_splitk(*, row_local, row): + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=T.i32) + t2 = fused2 & mask24_i32 + s2 = fused2 >> 24 + t_ok = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32_v) + t_idx = arith.index_cast(T.index, t2) + s_idx = arith.index_cast(T.index, s2) + ts_idx = t_idx * arith.index(topk) + s_idx + row_byte_base = out_base_idx + ts_idx * arith.index(_split_k_out_row_stride) + return (row_byte_base, t_ok) + + def store_pair_splitk(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + row_byte_base = row_ctx + col_idx = col_g0 + arith.index(_split_k_n_offset[0]) + byte_off_col = col_idx * arith.index(out_elem_bytes) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr = buffer_ops.create_llvm_ptr(ptr_addr_idx, address_space=1) + out_ptr_v = out_ptr._value if hasattr(out_ptr, "_value") else out_ptr + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + out_ptr_v, + frag_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=_split_k_e_vec * out_elem_bytes, + ) + + _cshuffle_nlane_splitk = min(32, tile_n // _split_k_e_vec) + _splitk_frag_elem = ir.F32Type.get() + + # Pass 1: gate (offset=0) + _split_k_acc[0] = acc_gate + _split_k_sw_vals[0] = sw_gate_vals + _split_k_n_offset[0] = 0 + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_split_k_e_vec, + cshuffle_nlane=_cshuffle_nlane_splitk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_splitk_frag_elem, + write_row_to_lds=write_row_to_lds_splitk, + precompute_row=precompute_row_splitk, + store_pair=store_pair_splitk, + ) + + gpu.barrier() + + # Pass 2: up (offset=inter_dim) + _split_k_acc[0] = acc_up + _split_k_sw_vals[0] = sw_up_vals + _split_k_n_offset[0] = inter_dim + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_split_k_e_vec, + cshuffle_nlane=_cshuffle_nlane_splitk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_splitk_frag_elem, + write_row_to_lds=write_row_to_lds_splitk, + precompute_row=precompute_row_splitk, + store_pair=store_pair_splitk, + ) + return + if use_cshuffle_epilog_flag: if lds_out is None: raise RuntimeError("CShuffle epilogue enabled but lds_out is not allocated/aliased.") @@ -1463,7 +1661,7 @@ def launch_moe_gemm1( i32_k_in, i32_size_expert_ids_in, ).launch( - grid=(gx, gy, 1), + grid=(gx, gy, k_batch), block=(256, 1, 1), stream=stream, ) diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index ca819bba..b979b3fe 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -336,6 +336,7 @@ def run_moe_stage1( scale_dtype: str = "f32", even_dispatch: bool = False, out_dtype: str = "f16", + k_batch: int = 1, ): assert model_dim % 64 == 0 assert model_dim % tile_k == 0 @@ -562,9 +563,13 @@ def run_moe_stage1( scale_w1_1d = scale_w1_flat.view(-1).contiguous() sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] - # Output: [tokens, topk, inter_dim] + # Output: normal=[tokens, topk, inter_dim] f16/bf16, split-K=[tokens*topk, 2*inter_dim] f32 _out_torch_dtype = torch.bfloat16 if out_dtype == "bf16" else torch.float16 - out = torch.empty((tokens, topk, inter_dim), device=device, dtype=_out_torch_dtype) + _is_splitk = k_batch > 1 + if _is_splitk: + out = torch.zeros((tokens * topk, 2 * inter_dim), device=device, dtype=torch.float32) + else: + out = torch.empty((tokens, topk, inter_dim), device=device, dtype=_out_torch_dtype) if is_fp4: exe = compile_mixed_moe_gemm1( @@ -605,9 +610,10 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): tile_n=tile_n, tile_k=tile_k, doweight_stage1=bool(doweight_stage1), - use_cshuffle_epilog=False, + use_cshuffle_epilog=None if _is_splitk else False, scale_is_bf16=(scale_dtype == "bf16"), out_dtype=out_dtype, + k_batch=k_batch, ) def _s1_args(o, x, w, sx, sw, st, eids, sw_sorted): @@ -618,6 +624,8 @@ def _s1_args(o, x, w, sx, sw, st, eids, sw_sorted): compiled_exe = flyc.compile(exe, *_s1_args(out, x_q, w_kernel, scale_x_1d, scale_w1_1d, sorted_token_ids, sorted_expert_ids, sorted_weights_1d)) def launch(o, x, w, sx, sw, st, eids, sw_sorted): + if _is_splitk: + o.zero_() compiled_exe(*_s1_args(o, x, w, sx, sw, st, eids, sw_sorted)) _, us = run_perftest( @@ -636,6 +644,14 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): ) torch.cuda.synchronize() + # Split-K post-processing: apply silu(gate)*up on host, reshape to [tokens, topk, inter_dim] + # Note: the gfx950 v_cvt_off_f32_i4 x16 correction is already applied per-CTA in the kernel + # epilogue (linear factor commutes with summation: sum(x_i*16) = 16*sum(x_i)). + if _is_splitk: + gate = out[:, :inter_dim] # [tokens*topk, inter_dim] f32 + up = out[:, inter_dim:] # [tokens*topk, inter_dim] f32 + out = (torch.nn.functional.silu(gate) * up).to(_out_torch_dtype).view(tokens, topk, inter_dim) + if not bool(skip_ref): if is_int8smooth: # x_q is slot-major [topk, tokens, K]; convert to [tokens, topk, K] for ref.