[Perf] Port mixed_moe kernel optimizations for stage1/stage2#388
Open
Conversation
Port performance-critical optimizations from aiter's mixed_moe_gemm_2stage kernel body (both stage1 and stage2) into FlyDSL, along with supporting infrastructure changes.

Key changes:
- mixed_moe_gemm_2stage.py: Full kernel body replacement with the aiter version, featuring dual SmemAllocator (ping-pong), a unified MFMA pipeline schedule, _barrier() for fine-grained waitcnt control, and new parameters (persist_m, fuse_fp4_quant, fuse_sort_scale, use_async_copy, sort_block_m, etc.)
- layout_utils.py: New file ported from aiter for layout index arithmetic (crd2idx, idx2crd, _div_pow2, _mod_pow2)
- silu_and_mul_fq.py: New file ported from aiter for split-K + fp4 quant after silu fusion
- mfma_preshuffle_pipeline.py: Added k_major support, a cache_modifier param, a bitwise-AND optimization in swizzle_xor16, and PreshuffleScaleLayout additions
- kernels_common.py: Extracted the shared _if_then context manager and validate_moe_dtypes helper
- mfma_epilogues.py: Replaced the local _if_then with the shared import

Performance (DeepSeek TP8 FP4, 7168x256, E=256, K=8):
- Stage1 Decode t=1: 37.3 -> 26.2 us (-29.8%)
- Stage1 Decode t=8: 45.0 -> 31.0 us (-31.1%)
- Stage1 Prefill 8K: 561.8 -> 348.8 us (-37.9%)
- Stage2 Prefill 8K reduce: 569.1 -> 534.8 us (-6.0%)
- FP8 stage2: unchanged (within noise)

Made-with: Cursor
Resolve conflicts in mfma_preshuffle_pipeline.py:
- Keep PreshuffleBLayout with k_major support (ours)
- Keep the new _unpack_int4_to_int8_pair, _pack_i32_pair_to_i64, and groupwise W4A16 functions (theirs)
- Merge __all__ exports from both sides

Made-with: Cursor
- Restore PreshuffleBLayout/make_preshuffle_b_layout and main's _unpack_int4_to_int8_pair/_pack_i32_pair_to_i64 that were lost during merge conflict resolution
- Update moe_gemm_2stage.py mfma_fn calls to match the new flydsl MFMA API: pass (res, a, b, c, cbsz, abid, blgp) as positional args instead of a list, and access .result on the returned Operation

Made-with: Cursor
coderfeli
reviewed
Apr 13, 2026
c126_i32 = arith.constant(126, type=i32)
c127_i32 = arith.constant(127, type=i32)
c254_i32 = arith.constant(254, type=i32)
c256_i32 = arith.constant(256, type=i32)
coderfeli
reviewed
Apr 13, 2026
_if_col = scf.IfOp(col_valid)
with ir.InsertionPoint(_if_col.then_block):

_if_valid = scf.IfOp(is_valid, has_else=True)
coderfeli
reviewed
Apr 13, 2026
local_max = c0_f32
for vi in range_constexpr(VEC):
    abs_v = llvm.call_intrinsic(
        f32, "llvm.fabs.f32", [act_vals[vi]], [], []
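The snippet under review computes a per-thread running abs-max over a small vector of activations, the usual first step of a dynamic quantization scale (here presumably feeding the fp4 quant fusion). A plain-Python sketch of the same reduction, with `act_vals` standing in for the kernel's register values and `abs` for the emitted llvm.fabs.f32 intrinsic:

```python
def local_abs_max(act_vals):
    # Running maximum of |v| over the VEC-wide slice held by one thread.
    local_max = 0.0
    for v in act_vals:
        local_max = max(local_max, abs(v))  # mirrors the fabs + max loop
    return local_max
```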
coderfeli
reviewed
Apr 13, 2026
)

lane_in_blk = col0 & c31_i32
_if_sw = scf.IfOp(arith.cmpi(CmpIPredicate.eq, lane_in_blk, c0_i32))
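The reviewed lines use the standard power-of-two trick: with 32 lanes per block, `col0 & 31` equals `col0 % 32` but lowers to a single bitwise AND, and the `== 0` comparison then guards work to the first lane of each 32-wide block. A minimal sketch of the identity (names are illustrative, not the kernel's):

```python
def lane_in_block(col0: int) -> int:
    # Equivalent to col0 % 32 for non-negative col0, but a single AND.
    return col0 & 31

def is_first_lane(col0: int) -> bool:
    # True exactly on the first lane of each 32-wide block.
    return lane_in_block(col0) == 0
```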
New flydsl MFMA ops (mfma_f32_16x16x32_f16/bf16) take positional args and return an Operation (callers need .result), while legacy ops (fp8, int8, k16) still use the old (res_type, [list]) calling convention. Introduce a thin mfma_fn wrapper at both call sites that dispatches on _use_mfma_k32, keeping all 14 call sites in the list format: mfma_fn(res_ty, [a, b, c, 0, 0, 0]).

Made-with: Cursor
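The dispatch described above can be sketched as follows. The two calling conventions and the _use_mfma_k32 flag come from the commit message; `new_mfma_op` and `legacy_mfma_op` are placeholder callables standing in for the real flydsl ops, so this shows only the shape of the wrapper, not its actual implementation.

```python
def make_mfma_fn(_use_mfma_k32, new_mfma_op, legacy_mfma_op):
    def mfma_fn(res_ty, args):
        a, b, c, cbsz, abid, blgp = args
        if _use_mfma_k32:
            # New API: positional args, returns an Operation -> take .result.
            return new_mfma_op(res_ty, a, b, c, cbsz, abid, blgp).result
        # Legacy API: (res_type, [list]) convention, returns the value directly.
        return legacy_mfma_op(res_ty, [a, b, c, cbsz, abid, blgp])
    return mfma_fn
```

Call sites then stay uniform as `mfma_fn(res_ty, [a, b, c, 0, 0, 0])` regardless of which backend op is selected.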
Collaborator
Try to make the style of silu closer to the norm/quant/rope kernels, using the native layout API and internal DSL types.
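For reference, the core semantics that silu_and_mul_fq.py fuses (per the PR description, together with split-K reduction and fp4 quantization, which are omitted here) is the standard silu-and-mul: silu(gate) * up, with silu(x) = x * sigmoid(x). A minimal scalar sketch, not the DSL kernel:

```python
import math

def silu_and_mul(gate, up):
    # silu(g) = g * sigmoid(g); the fused op multiplies by the `up` half.
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]
```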
Collaborator
Also add more benchmark shapes related to your opts, so we can easily notice regressions.