
[Perf] Port mixed_moe kernel optimizations for stage1/stage2#388

Open
lalala-sh wants to merge 4 commits into main from port-aiter-mixed-moe

Conversation


lalala-sh (Contributor) commented Apr 13, 2026

Motivation

Key changes:

- mixed_moe_gemm_2stage.py: Full kernel body replacement with the aiter version, featuring dual SmemAllocator (ping-pong), a unified MFMA pipeline schedule, _barrier() for fine-grained waitcnt control, and new parameters (persist_m, fuse_fp4_quant, fuse_sort_scale, use_async_copy, sort_block_m, etc.)
- layout_utils.py: New file ported from aiter for layout index arithmetic (crd2idx, idx2crd, _div_pow2, _mod_pow2)
- silu_and_mul_fq.py: New file ported from aiter for split-K + FP4 quantization after silu fusion
- mfma_preshuffle_pipeline.py: Added k_major support, a cache_modifier parameter, a bitwise-AND optimization in swizzle_xor16, and PreshuffleScaleLayout additions
- kernels_common.py: Extracted the shared _if_then context manager and validate_moe_dtypes helper
- mfma_epilogues.py: Replaced the local _if_then with the shared import

Performance (DeepSeek TP8 FP4, 7168x256, E=256, K=8):

- Stage1 Decode t=1: 37.3 -> 26.2 us (-29.8%)
- Stage1 Decode t=8: 45.0 -> 31.0 us (-31.1%)
- Stage1 Prefill 8K: 561.8 -> 348.8 us (-37.9%)
- Stage2 Prefill 8K reduce: 569.1 -> 534.8 us (-6.0%)
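The quoted percentage deltas follow directly from the microsecond timings above; as a quick sanity check:

```python
# (before_us, after_us) pairs taken from the numbers quoted in this PR.
timings = {
    "stage1 decode t=1": (37.3, 26.2),
    "stage1 decode t=8": (45.0, 31.0),
    "stage1 prefill 8K": (561.8, 348.8),
    "stage2 prefill 8K reduce": (569.1, 534.8),
}
for name, (before, after) in timings.items():
    delta = (after - before) / before * 100.0
    print(f"{name}: {before} -> {after} us ({delta:+.1f}%)")
# Printed deltas match the list above: -29.8, -31.1, -37.9, -6.0.
```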


Port performance-critical optimizations from aiter's mixed_moe_gemm_2stage
kernel body (both stage1 and stage2) into FlyDSL, along with supporting
infrastructure changes.

FP8 stage2 perf is unchanged (within noise).

Made-with: Cursor
Resolve conflicts in mfma_preshuffle_pipeline.py:
- Keep PreshuffleBLayout with k_major support (ours)
- Keep new _unpack_int4_to_int8_pair, _pack_i32_pair_to_i64, and
  groupwise W4A16 functions (theirs)
- Merge __all__ exports from both sides

Made-with: Cursor
- Restore PreshuffleBLayout/make_preshuffle_b_layout and main's
  _unpack_int4_to_int8_pair/_pack_i32_pair_to_i64 that were lost
  during merge conflict resolution
- Update moe_gemm_2stage.py mfma_fn calls to match new flydsl MFMA
  API: pass (res, a, b, c, cbsz, abid, blgp) as positional args
  instead of list, and access .result on the returned Operation

Made-with: Cursor
c126_i32 = arith.constant(126, type=i32)
c127_i32 = arith.constant(127, type=i32)
c254_i32 = arith.constant(254, type=i32)
c256_i32 = arith.constant(256, type=i32)
use fx.int

_if_col = scf.IfOp(col_valid)
with ir.InsertionPoint(_if_col.then_block):

_if_valid = scf.IfOp(is_valid, has_else=True)
use raw if?

local_max = c0_f32
for vi in range_constexpr(VEC):
    abs_v = llvm.call_intrinsic(
        f32, "llvm.fabs.f32", [act_vals[vi]], [], []
    )

use rocdl.fabs?

lane_in_blk = col0 & c31_i32
_if_sw = scf.IfOp(arith.cmpi(CmpIPredicate.eq, lane_in_blk, c0_i32))
raw if?
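The `col0 & c31_i32` masking in the snippet above relies on the power-of-two identity `x & (2**k - 1) == x % 2**k` for non-negative x, here with k = 5 so that only lane 0 of each 32-wide block takes the guarded branch. A standalone check (names are illustrative, not from the kernel):

```python
def lane_in_block(col: int, wave_size: int = 32) -> int:
    # For power-of-two block sizes, masking with (wave_size - 1)
    # is equivalent to col % wave_size for non-negative col.
    assert wave_size & (wave_size - 1) == 0, "block size must be a power of two"
    return col & (wave_size - 1)

# The mask form agrees with the modulo form for every lane index.
assert all(lane_in_block(c) == c % 32 for c in range(256))
# Lane 0 of each 32-wide block is where the single-lane write fires.
assert [c for c in range(128) if lane_in_block(c) == 0] == [0, 32, 64, 96]
```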

New flydsl MFMA ops (mfma_f32_16x16x32_f16/bf16) use positional args
and return Operation (need .result), while legacy ops (fp8, int8, k16)
still use the old (res_type, [list]) calling convention.

Introduce a thin mfma_fn wrapper at both call sites that dispatches
correctly based on _use_mfma_k32, keeping all 14 call sites in list
format: mfma_fn(res_ty, [a, b, c, 0, 0, 0]).

Made-with: Cursor
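A minimal sketch of the dispatch the commit message describes. The `Operation`/`.result` shapes below are stand-ins for the flydsl objects, chosen only to illustrate how one wrapper can keep all call sites in the legacy list form:

```python
class _Op:
    # Stand-in for an MLIR Operation whose value is exposed on .result.
    def __init__(self, result):
        self.result = result

def _new_mfma(res_ty, a, b, c, cbsz, abid, blgp):
    # New-style op: positional args, returns an Operation.
    return _Op(("mfma", a, b, c))

def _legacy_mfma(res_ty, operands):
    # Legacy op: (res_type, [list]) convention, returns the value directly.
    a, b, c, *_ = operands
    return ("mfma", a, b, c)

def make_mfma_fn(use_mfma_k32: bool):
    # Call sites stay uniform: mfma_fn(res_ty, [a, b, c, 0, 0, 0]).
    if use_mfma_k32:
        return lambda res_ty, ops: _new_mfma(res_ty, *ops).result
    return _legacy_mfma

# Both conventions now look identical at the call site.
for flag in (True, False):
    mfma_fn = make_mfma_fn(flag)
    assert mfma_fn("f32x4", ["a", "b", "acc", 0, 0, 0]) == ("mfma", "a", "b", "acc")
```

The design point is that the branching on `_use_mfma_k32` happens once, when the wrapper is built, rather than at each of the 14 MFMA call sites.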
@coderfeli (Collaborator) commented:

Try to make the style of the silu kernel closer to the norm/quant/rope kernels, using the native layout API and internal DSL types.

@coderfeli (Collaborator) commented Apr 13, 2026:

Also add more benchmark shapes related to your optimizations, so regressions are easy to notice.
