From 3de24c686d94ecaa6a39de73421faf5fcb078366 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 12:20:18 -0700 Subject: [PATCH 1/3] Split-K2 SwiGLU GEMV --- docs/contrib_ops/cuda/moe_qmoe.md | 22 ++ .../contrib_ops/cuda/qmoe_gemv_experiments.md | 164 ++++++++++ .../contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu | 295 +++++++++++++++++- .../contrib_ops/cuda/llm/moe_gemm/moe_gemv.h | 1 + .../cuda/llm/moe_gemm/moe_kernels.cu | 39 ++- .../cuda/llm/moe_gemm/moe_kernels.h | 5 +- .../python/transformers/profile_qmoe_gemv.py | 16 + .../python/transformers/profile_qmoe_gemv.sh | 19 +- .../python/transformers/test_qmoe_cuda.py | 11 + 9 files changed, 557 insertions(+), 15 deletions(-) diff --git a/docs/contrib_ops/cuda/moe_qmoe.md b/docs/contrib_ops/cuda/moe_qmoe.md index d0d2b11548ca3..b81c930af5d45 100644 --- a/docs/contrib_ops/cuda/moe_qmoe.md +++ b/docs/contrib_ops/cuda/moe_qmoe.md @@ -989,6 +989,28 @@ per-column INT4, block-wise INT4/INT8, and interleaved-SwiGLU GEMV kernels. | Kernel instantiation | `moe_gemv.cu` adds `__nv_bfloat16` details/instantiations (group sizes 0/32/64/128, INT4/INT8, bias on/off) under `ENABLE_BF16`. | The custom FC1/FC2 GEMV kernels run for BF16; no grouped-GEMM fallback when the FP16 gate would route. | | Profiling | GPT-OSS-20B, Qwen3.6-35B-A3B, and Gemma model shapes profiled with `block_size=64` for both dtypes. | BF16 matches FP16 routing and latency within noise (about 1.3x–1.5x faster than grouped GEMM); SwiGLU BF16 parity tests pass. | +#### Split-K2 SwiGLU GEMV default path + +The fp16 INT4 interleaved-SwiGLU GEMV path uses a two-pass Split-K2 FC1 kernel by +default for supported decode shapes. The first pass computes two K-split FP32 +partials into QMoE workspace, and the second pass reduces those partials, adds +optional bias, and applies the interleaved SwiGLU epilogue. FC2 stays on the +regular `moe_gemv_kernel` path. + +Set `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` before process start to force the +previous single-kernel FC1 SwiGLU GEMV path for debugging, A/B benchmarking, or +bisecting numerical differences. On GPT-OSS-20B, Split-K2 reduced FC1 kernel +work from about 21.42 us to 19.98 us and improved repeated CUDA-graph decode +throughput by about 0.9% to 1.6% with valid focused-helper output. A 1000-sample +MMLU smoke matched the opt-out fallback within noise. A future autotuner can +replace this hand-selected default with per-shape route selection. + +```bash +onnxruntime/test/python/transformers/profile_qmoe_gemv.py \ + --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 \ + --disable-splitk2-swiglu --warmup 5 --repeat 100 --nvtx +``` + #### Experiments rejected after profiling | Experiment | Why it was rejected | diff --git a/docs/contrib_ops/cuda/qmoe_gemv_experiments.md b/docs/contrib_ops/cuda/qmoe_gemv_experiments.md index 33ff5eeb483c6..e2e8c16984f43 100644 --- a/docs/contrib_ops/cuda/qmoe_gemv_experiments.md +++ b/docs/contrib_ops/cuda/qmoe_gemv_experiments.md @@ -978,3 +978,167 @@ Every case reported `has_invalid_output=false`. per-column case for INT4 and INT8. - Per-column INT8 W8A16 decode shapes route to GEMV for both FP16 and BF16 and beat the grouped-GEMM fallback at every profiled shape. + +## 2026-06-19: Split-K2 Two-Pass SwiGLU GEMV Experiment + +### Change Under Test + +- Code commit: `f1d6718be719c1237be392c0389874b6a8926a3c` + (`Experiment QMoE split-K SwiGLU GEMV`). +- Added default Split-K2 route with opt-out env knob: + `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1`. +- Scope: FP16 INT4/interleaved-SwiGLU FC1 GEMV path for decode-shaped QMoE. +- Implementation: + - First pass launches `moe_gemv_splitk_partials_kernel` with `SplitK=2` and + writes FP32 partials into QMoE workspace. + - Second pass launches `moe_gemv_splitk_reduce_swiglu_kernel` to reduce the + partials, add optional bias, and apply SwiGLU. + - FC2 remains on the existing `moe_gemv_kernel`. + - Scratch is allocated only for the supported Split-K2 route. Setting + `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` restores the previous single-kernel + FC1 SwiGLU GEMV path. + +### Repro Notes + +- Build: `cmake --build build/cu130/Release --target onnxruntime_providers_cuda --parallel $(nproc)`. +- Important provider sync: Python tests importing from + `build/cu130/Release/onnxruntime` load + `build/cu130/Release/onnxruntime/capi/libonnxruntime_providers_cuda.so`, not + only the top-level `build/cu130/Release/libonnxruntime_providers_cuda.so` or + the venv copy. Sync all relevant copies before measuring: + + ```bash + cp build/cu130/Release/libonnxruntime_providers_cuda.so \ + build/cu130/Release/onnxruntime/capi/libonnxruntime_providers_cuda.so + cp build/cu130/Release/libonnxruntime_providers_cuda.so \ + .venv_cu130/lib/python3.14/site-packages/onnxruntime/capi/libonnxruntime_providers_cuda.so + ``` + +- Focused QMoE helper: + + ```bash + cd ~ + CUDA_VISIBLE_DEVICES=1 \ + LD_LIBRARY_PATH=~/onnxruntime/build/cu130/Release:~/cuda13.0/lib64:~/cudnn9.19_cuda13/lib:~/cudnn9.19_cuda13/lib64:${LD_LIBRARY_PATH:-} \ + PYTHONPATH=~/onnxruntime/build/cu130/Release:~/onnxruntime/onnxruntime/test/python/transformers \ + ~/onnxruntime/.venv_cu130/bin/python \ + ~/onnxruntime/onnxruntime/test/python/transformers/profile_qmoe_gemv.py \ + --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --warmup 3 --repeat 20 + ``` + +### Focused QMoE Smoke + +Both modes reported `has_invalid_output=false`. + +| Mode | Env | Latency ms | +|------|-----|------------| +| Baseline | `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` | 0.072344 | +| Split-K2 | none | 0.073816 | + +The short helper was slightly slower with split-K2, so Nsight was required to +confirm route selection and isolate kernel time. + +### Nsight Systems Kernel Results + +Artifacts: + +- Baseline: `/tmp/qmoe_gptoss_baseline_final.{nsys-rep,sqlite}` +- Split-K2: `/tmp/qmoe_gptoss_splitk_final.{nsys-rep,sqlite}` + +Command shape: + +```bash +~/cuda13.0/bin/nsys profile -t cuda,nvtx --force-overwrite true \ + -o /tmp/qmoe_gptoss_splitk_final --export=sqlite \ + ~/onnxruntime/.venv_cu130/bin/python \ + ~/onnxruntime/onnxruntime/test/python/transformers/profile_qmoe_gemv.py \ + --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --warmup 3 --repeat 30 --nvtx +``` + +Parsed with `parse_nsys.py --nvtx-range benchmark --pattern '%'`. + +| Mode | Kernel | Calls | Avg us | +|------|--------|-------|--------| +| Baseline | `moe_gemv_interleaved_swiglu_kernel` | 30 | 21.42 | +| Baseline | `moe_gemv_kernel` | 30 | 12.13 | +| Split-K2 | `moe_gemv_splitk_partials_kernel` | 30 | 17.59 | +| Split-K2 | `moe_gemv_splitk_reduce_swiglu_kernel` | 30 | 2.39 | +| Split-K2 | `moe_gemv_kernel` | 30 | 12.22 | + +Split-K2 reduced FC1 kernel work from about `21.42 us` to `17.59 + 2.39 = +19.98 us`, a net FC1 reduction of about `1.44 us` per QMoE invocation. End-to-end +under Nsight was effectively tied: + +| Mode | Helper latency ms | +|------|-------------------| +| Baseline | 0.079855 | +| Split-K2 | 0.079728 | + +### Model-Level Decode Benchmark With CUDA Graph + +The user requested model-level measurement assuming CUDA graph. Both runs used +the GPT-OSS-20B INT4 QMoE model package, CUDA graph enabled, XQA enabled, and +deterministic MoE tactic selection: + +```bash +MODEL=models/gpt-oss-20b/variants/cuda_int4_int4_qmoe_rtn_matmul_only \ +GPU=0 PROMPT_LEN=512 GEN_LEN=128 REPS=10 WARMUP=3 CUDA_GRAPH=1 XQA=1 SYNC_LIB=1 \ +ORT_FORCE_DETERMINISTIC_MOE=1 \ +bash scripts/bench_gpt_oss_ort_decode.sh +``` + +Baseline additionally set `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1`. + +| Run | Mode | Decode latency ms/token | Decode throughput tok/s | +|-----|------|-------------------------|-------------------------| +| R1, `REPS=5`, `WARMUP=2` | Baseline | 2.869450 | 348.498901 | +| R1, `REPS=5`, `WARMUP=2` | Split-K2 | 2.823800 | 354.132707 | +| R2, `REPS=10`, `WARMUP=3` | Baseline | 2.865840 | 348.937861 | +| R2, `REPS=10`, `WARMUP=3` | Split-K2 | 2.839335 | 352.195107 | + +The longer CUDA-graph pair showed about `+0.9%` decode throughput. The shorter +pair showed about `+1.6%`. Since the focused helper reported valid output and +the model-level gain repeated in the same direction, even this modest gain is +worth enabling for GPT-OSS-20B decode while keeping an opt-out for A/B checks. + +After flipping Split-K2 to the default and adding +`ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` as the opt-out, three more paired +CUDA-graph model runs were collected with `REPS=10`, `WARMUP=3`, prompt length +512, and generation length 128: + +| Run | Mode | Decode latency ms/token | Decode throughput tok/s | +|-----|------|-------------------------|-------------------------| +| R3 | Default Split-K2 | 3.017252 | 331.427448 | +| R3 | Split-K2 disabled | 3.055736 | 327.253380 | +| R4 | Default Split-K2 | 3.006739 | 332.586260 | +| R4 | Split-K2 disabled | 3.047570 | 328.130314 | +| R5 | Default Split-K2 | 3.009466 | 332.284898 | +| R5 | Split-K2 disabled | 3.047015 | 328.190090 | +| Average | Default Split-K2 | 3.011152 | 332.099536 | +| Average | Split-K2 disabled | 3.050107 | 327.857928 | + +The default Split-K2 route was faster in all three pairs, averaging `+1.29%` +decode throughput and `-1.28%` decode latency versus the opt-out fallback. + +### Accuracy Smoke + +A 1000-sample `match_mmlu` smoke was run with the local parallel eval harness on +all eight H200 GPUs, using the same GPT-OSS-20B INT4 QMoE model package and the +current ORT build package. The default Split-K2 run scored `0.8380` pooled +accuracy; the opt-out fallback with `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` +scored `0.8350`. The small positive difference is within smoke-test noise, and +there is no accuracy regression signal from enabling Split-K2 by default. + +### Decision + +- Enable Split-K2 by default for its supported fp16 INT4 interleaved-SwiGLU GEMV + scope. +- Keep `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` as the fallback and A/B knob. +- The 1000-sample MMLU smoke matched the opt-out fallback within noise, so the + default flip has an accuracy sanity check in addition to focused-helper valid + output. +- Future work: + - Add per-shape autotune so route selection is data-driven instead of a fixed + default. + - Try a launch-fused reduction strategy or cooperative approach to keep the + FC1 parallelism benefit without the extra reduce launch. diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu index dd6d7cecf9288..f8dc958de3e86 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu @@ -174,6 +174,205 @@ __device__ __forceinline__ void swiglu_epilogue(void* out, void* tile_acc, void* } } +template +__device__ __forceinline__ void partial_epilogue(float* partial_out, void* tile_acc) { + static constexpr int Interleave = Details::kInterleave; + static constexpr int ThreadsPerInterleavedTile = Details::kThreadsPerInterleavedTile; + static constexpr int WarpSize = Details::kWarpSize; + static constexpr int WarpNum = Threads / WarpSize; + static_assert(CtaM == 1); + static_assert(Threads % WarpSize == 0); + + __shared__ float shmem[CtaM * CtaN * Interleave * WarpNum]; + int tid = threadIdx.x; + int warp_id = tid / WarpSize; + int lane_id = tid % WarpSize; +#pragma unroll + for (int n = 0; n < CtaN; ++n) { + float v = static_cast(reinterpret_cast(tile_acc)[n]); + v = warp_reduce_sum(v); + if (lane_id < Interleave * ThreadsPerInterleavedTile && lane_id % ThreadsPerInterleavedTile == 0) { + shmem[warp_id * CtaN * Interleave + n * Interleave + lane_id / ThreadsPerInterleavedTile] = v; + } + } + __syncthreads(); + +#pragma unroll + for (int col = tid; col < CtaN * Interleave; col += Threads) { + float val = 0.f; +#pragma unroll + for (int warp = 0; warp < WarpNum; ++warp) { + val += shmem[warp * CtaN * Interleave + col]; + } + partial_out[col] = val; + } +} + +template +__global__ void moe_gemv_splitk_partials_kernel( + TypeA* act, uint8_t* weight, TypeA* scales, float* partials, + int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, + int64_t weight_expert_stride, int64_t scale_expert_stride, int n, int k, int64_t expanded_num_rows) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) + using AccessTypeA = typename Details::AccessTypeA; + using AccessTypeW = typename Details::AccessTypeW; + + static constexpr bool Mandatory = true; + static constexpr int CtaM = 1; + static constexpr int StepK = Details::kStepK; + static constexpr int CtaK = StepK * Threads; + static_assert(CtaN % 2 == 0); + if constexpr (GroupSize != 0) { + static_assert((CtaK / Details::kInterleave) % GroupSize == 0); + } + + int const row = blockIdx.x; + int expert = permuted_row_to_expert != nullptr ? permuted_row_to_expert[row] : 0; +#pragma unroll 1 + for (int e = 0; e < num_experts && permuted_row_to_expert == nullptr; ++e) { + if (row >= static_cast(expert_first_token_offset[e + 1])) { + expert = e + 1; + continue; + } + break; + } + if (expert < 0 || expert >= num_experts) { + return; + } + + weight += expert * weight_expert_stride; + scales += static_cast(expert) * scale_expert_stride; + + int const origin_k = k; + int const interleaved_k = k * Details::kInterleave; + int const tile_id_m = row; + int const tile_id_n = blockIdx.y; + int const split_id = blockIdx.z; + int const tid = threadIdx.x; + int const offset_m = tile_id_m * CtaM; + int const interleaved_offset_n = tile_id_n * CtaN; + int const real_offset_n = interleaved_offset_n * Details::kInterleave + + ((tid * StepK / Details::LayoutDetails::kTileSize) % Details::kInterleave); + int const real_offset_k = + (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * + Details::LayoutDetails::kTileSize + + ((tid * StepK) % Details::LayoutDetails::kTileSize); + + GMemIterator act_iterator( + act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); + GMemIterator weight_iterator( + weight, (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, + CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW); + GMemIterator scales_iterator( + scales, (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + + partials += (static_cast(split_id) * expanded_num_rows + offset_m) * n + + tile_id_n * CtaN * Details::kInterleave; + + AccT tile_acc[CtaM * CtaN]; + fill(tile_acc, static_cast(0.f)); + + TypeA vec_scale[CtaN]; + if constexpr (GroupSize == 0) { +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + scales_iterator.load(vec_scale + i, 0, i); + } + } + + int const num_iters = (interleaved_k + CtaK - 1) / CtaK; + for (int iter = split_id; iter < num_iters; iter += SplitK) { + int const idx_k = iter * CtaK + tid * StepK; + if (idx_k >= interleaved_k) { + continue; + } + TypeA tile_a[StepK]; + TypeA tile_w[StepK]; + TypeA tile_w_pack2[CtaN * StepK]; + uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; + if constexpr (GroupSize != 0) { +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + scales_iterator.load(vec_scale + i, iter, i); + } + } +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + weight_iterator.load(tile_w_quantized, iter, i); + dequantize(tile_w, tile_w_quantized, vec_scale + i, nullptr, 1.0f); + pack_to_vec2(tile_w_pack2, tile_w, i); + } + act_iterator.load(tile_a, iter, 0); + mma(tile_acc, tile_w_pack2, tile_a); + } + partial_epilogue(partials, tile_acc); +#endif +} + +template +__global__ void moe_gemv_splitk_reduce_swiglu_kernel( + float const* partials, TypeA* bias, TypeA* out, + int const* permuted_row_to_expert, int num_experts, int64_t const* expert_first_token_offset, + int inter_size, int split_k, int64_t expanded_num_rows, + cutlass_kernels::ActivationParams activation_params) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) + int const row = blockIdx.x; + int const col = blockIdx.y * Threads + threadIdx.x; + if (col >= inter_size) { + return; + } + + int expert = permuted_row_to_expert != nullptr ? permuted_row_to_expert[row] : 0; +#pragma unroll 1 + for (int e = 0; e < num_experts && permuted_row_to_expert == nullptr; ++e) { + if (row >= static_cast(expert_first_token_offset[e + 1])) { + expert = e + 1; + continue; + } + break; + } + if (expert < 0 || expert >= num_experts) { + return; + } + + float const* alpha = activation_params.swiglu_alpha; + float const* beta = activation_params.swiglu_beta; + float const* limit = activation_params.swiglu_limit; + float const act_alpha = alpha ? alpha[expert] : activation_params.alpha; + float const act_beta = beta ? beta[expert] : activation_params.beta; + float const act_limit = limit ? limit[expert] : activation_params.limit; + + int const n = inter_size * 2; + int const gate_idx = col * 2; + int const linear_idx = gate_idx + 1; + int64_t const row_base = static_cast(row) * n; + int64_t const split_stride = expanded_num_rows * n; + float gate = 0.f; + float linear = 0.f; + for (int split = 0; split < split_k; ++split) { + int64_t const base = static_cast(split) * split_stride + row_base; + gate += partials[base + gate_idx]; + linear += partials[base + linear_idx]; + } + + if constexpr (EnableBias) { + bias += static_cast(expert) * n; + gate += static_cast(bias[gate_idx]); + linear += static_cast(bias[linear_idx]); + } + if (isfinite(act_limit)) { + gate = fminf(gate, act_limit); + linear = fminf(fmaxf(linear, -act_limit), act_limit); + } + linear += act_beta; + float const sigmoid = 1.0f / (1.0f + expf(-act_alpha * gate)); + out[static_cast(row) * inter_size + col] = static_cast(gate * sigmoid * linear); +#endif +} + template __global__ void moe_gemv_interleaved_swiglu_kernel( @@ -326,6 +525,76 @@ static void launch_moe_gemv_interleaved_swiglu( } } +template +static void launch_moe_gemv_splitk_twopass_swiglu( + TypeA* act, uint8_t* weight, TypeA* scales, TypeA* bias, TypeA* out, + int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, + int64_t expanded_num_rows, int64_t inter_size, int64_t k, + cutlass_kernels::ActivationParams activation_params, float* partials, cudaStream_t stream) { + static constexpr int StepK = Details::kStepK; + static constexpr int CtaK = StepK * Threads; + int64_t const n = inter_size * 2; + int64_t const interleaved_k = k * Details::kInterleave; + int const num_iters = static_cast((interleaved_k + CtaK - 1) / CtaK); + if (partials == nullptr || num_iters < 2) { + launch_moe_gemv_interleaved_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, stream); + return; + } + + int64_t const weight_expert_stride = n * k / Details::kElemsPerByteW; + int64_t const scale_expert_stride = GroupSize == 0 ? n : ((k + GroupSize - 1) / GroupSize) * n; + dim3 grid1(static_cast(expanded_num_rows), + static_cast(n / (CtaN * Details::kInterleave)), SplitK); + dim3 block1(Threads); + moe_gemv_splitk_partials_kernel + <<>>( + act, weight, scales, partials, expert_first_token_offset, permuted_row_to_expert, num_experts, + weight_expert_stride, scale_expert_stride, static_cast(n), static_cast(k), expanded_num_rows); + + static constexpr int kReduceThreads = 256; + dim3 grid2(static_cast(expanded_num_rows), + static_cast((inter_size + kReduceThreads - 1) / kReduceThreads)); + dim3 block2(kReduceThreads); + if (bias != nullptr) { + moe_gemv_splitk_reduce_swiglu_kernel<<>>( + partials, bias, out, permuted_row_to_expert, num_experts, expert_first_token_offset, + static_cast(inter_size), SplitK, expanded_num_rows, activation_params); + } else { + moe_gemv_splitk_reduce_swiglu_kernel<<>>( + partials, bias, out, permuted_row_to_expert, num_experts, expert_first_token_offset, + static_cast(inter_size), SplitK, expanded_num_rows, activation_params); + } +} + +template +static void dispatch_moe_gemv_splitk_twopass_swiglu_group_size( + TypeA* act, uint8_t* weight, TypeA* scales, TypeA* bias, TypeA* out, + int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, + int64_t expanded_num_rows, int64_t inter_size, int64_t k, int group_size, + cutlass_kernels::ActivationParams activation_params, float* partials, cudaStream_t stream) { + if (group_size <= 0) { + launch_moe_gemv_splitk_twopass_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, partials, stream); + } else if (group_size == 32) { + launch_moe_gemv_splitk_twopass_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, partials, stream); + } else if (group_size == 64) { + launch_moe_gemv_splitk_twopass_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, partials, stream); + } else if (group_size == 128) { + launch_moe_gemv_splitk_twopass_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, partials, stream); + } else { + ORT_THROW("unsupported MoE GEMV split-K group_size: ", group_size); + } +} + template static void dispatch_moe_gemv_group_size(TypeA* act, uint8_t* weight, TypeA* scales, TypeA* bias, TypeA* out, int64_t const* expert_first_token_offset, @@ -528,10 +797,23 @@ void launch_moe_gemv_int_symmetric_interleaved_swiglu( T const* act, WeightType const* weight, T const* scales, T const* bias, T* out, int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, int64_t expanded_num_rows, int64_t inter_size, int64_t k, int group_size, int sm, - cutlass_kernels::ActivationParams activation_params, cudaStream_t stream) { + cutlass_kernels::ActivationParams activation_params, float* splitk_partials, cudaStream_t stream) { ORT_UNUSED_PARAMETER(sm); using Details = typename DetailsForTAndWeight::Details; using TypeA = typename DetailsForTAndWeight::TypeA; + if (splitk_partials != nullptr) { + if constexpr (std::is_same_v) { + fiv::dispatch_moe_gemv_splitk_twopass_swiglu_group_size( + const_cast(reinterpret_cast(act)), + const_cast(reinterpret_cast(weight)), + const_cast(reinterpret_cast(scales)), + const_cast(reinterpret_cast(bias)), + reinterpret_cast(out), + expert_first_token_offset, permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, group_size, + activation_params, splitk_partials, stream); + return; + } + } // Accumulate in fp32 by default (see launch_moe_gemv_int_symmetric for the policy). bool const use_fp32_accum = !std::is_same_v || !MoeGemvUseFp16Accum(); auto launch = [&](auto acc_tag) { @@ -570,7 +852,8 @@ void launch_moe_gemv_int4_per_channel_interleaved_swiglu( cutlass_kernels::ActivationParams activation_params, cudaStream_t stream) { launch_moe_gemv_int_symmetric_interleaved_swiglu( act, reinterpret_cast(weight), scales, bias, out, expert_first_token_offset, - permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, 0, sm, activation_params, stream); + permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, 0, sm, activation_params, + nullptr, stream); } template void launch_moe_gemv_int_symmetric( @@ -581,10 +864,10 @@ template void launch_moe_gemv_int_symmetric( int64_t, int64_t, int64_t, int, int, cudaStream_t); template void launch_moe_gemv_int_symmetric_interleaved_swiglu( half const*, cutlass::uint4b_t const*, half const*, half const*, half*, int64_t const*, int const*, int, - int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, cudaStream_t); + int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, float*, cudaStream_t); template void launch_moe_gemv_int_symmetric_interleaved_swiglu( half const*, uint8_t const*, half const*, half const*, half*, int64_t const*, int const*, int, - int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, cudaStream_t); + int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, float*, cudaStream_t); template void launch_moe_gemv_int4_per_channel(half const*, uint8_t const*, half const*, half const*, half*, int64_t const*, int const*, int, int64_t, int64_t, int64_t, @@ -603,11 +886,11 @@ template void launch_moe_gemv_int_symmetric<__nv_bfloat16, uint8_t>( template void launch_moe_gemv_int_symmetric_interleaved_swiglu<__nv_bfloat16, cutlass::uint4b_t>( __nv_bfloat16 const*, cutlass::uint4b_t const*, __nv_bfloat16 const*, __nv_bfloat16 const*, __nv_bfloat16*, int64_t const*, int const*, int, int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, - cudaStream_t); + float*, cudaStream_t); template void launch_moe_gemv_int_symmetric_interleaved_swiglu<__nv_bfloat16, uint8_t>( __nv_bfloat16 const*, uint8_t const*, __nv_bfloat16 const*, __nv_bfloat16 const*, __nv_bfloat16*, int64_t const*, int const*, int, int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, - cudaStream_t); + float*, cudaStream_t); template void launch_moe_gemv_int4_per_channel<__nv_bfloat16>( __nv_bfloat16 const*, uint8_t const*, __nv_bfloat16 const*, __nv_bfloat16 const*, __nv_bfloat16*, diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h index b4dfe1c59f02a..c855106f9c08d 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h @@ -53,6 +53,7 @@ void launch_moe_gemv_int_symmetric_interleaved_swiglu( T const* act, WeightType const* weight, T const* scales, T const* bias, T* out, int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, int64_t expanded_num_rows, int64_t inter_size, int64_t k, int group_size, int sm, cutlass_kernels::ActivationParams activation_params, + float* splitk_partials, cudaStream_t stream); // Launches the int4 per-channel MoE GEMV. diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu index 3185a9a86b231..68e825b5f88ae 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu @@ -83,6 +83,19 @@ inline bool MoeGemvDisabledByEnv() { return disabled; } +inline bool MoeGemvSplitK2SwiGLUDisabledByEnv() { + // Parsed once via ORT's environment helper (consistent parsing/thread-safety across platforms). + static bool const disabled = + onnxruntime::ParseEnvironmentVariableWithDefault("ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU", 0) == 1; + return disabled; +} + +template +constexpr bool MoeGemvSplitK2SwiGLUSupported() { + return std::is_same_v && std::is_same_v && + std::is_same_v; +} + inline bool MoeGemvRejectedByProfiledInterSize(int64_t expanded_num_rows, int64_t inter_size) { return expanded_num_rows > onnxruntime::llm::kernels::moe_gemv::kMaxProfiledExpandedRowsForSmallProblemDim && inter_size < onnxruntime::llm::kernels::moe_gemv::kMinProfiledProblemDimForExpandedRowsAbove4; @@ -153,7 +166,7 @@ bool tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( ScaleBiasType const* biases, T* output, int64_t const* expert_first_token_offset, int num_experts_per_node, int const* permuted_row_to_expert, int64_t expanded_num_rows, int64_t inter_size, int64_t k, int sm, int group_size, - bool disabled, cutlass_kernels::ActivationParams activation_params, cudaStream_t stream) { + bool disabled, cutlass_kernels::ActivationParams activation_params, float* splitk_partials, cudaStream_t stream) { if constexpr ((std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v) && std::is_same_v) { bool const env_disabled = MoeGemvDisabledByEnv(); @@ -176,7 +189,8 @@ bool tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( } onnxruntime::llm::kernels::moe_gemv::launch_moe_gemv_int_symmetric_interleaved_swiglu( input, weights, scales, biases, output, expert_first_token_offset, permuted_row_to_expert, - num_experts_per_node, expanded_num_rows, inter_size, k, group_size, sm, activation_params, stream); + num_experts_per_node, expanded_num_rows, inter_size, k, group_size, sm, activation_params, + splitk_partials, stream); return true; } else { (void)input; @@ -195,6 +209,7 @@ bool tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( (void)group_size; (void)disabled; (void)activation_params; + (void)splitk_partials; (void)stream; return false; } @@ -2169,6 +2184,12 @@ CutlassMoeFCRunner: // (see comment above). Only the gated / has-glu path uses it; zero otherwise. size_t const fc1_result_dedicated_size = (glu_inter_elems > 0) ? fc1_result_size : 0; + static constexpr int kMoeGemvSplitK = 2; + size_t const moe_gemv_splitk_partials_size = is_gated_activation && !MoeGemvSplitK2SwiGLUDisabledByEnv() && + MoeGemvSplitK2SwiGLUSupported() + ? kMoeGemvSplitK * glu_inter_elems * sizeof(float) + : 0; + size_t smoothed_act_size = use_awq ? std::max(permuted_elems, interbuf_elems) * sizeof(T) * 2 : 0; // Extra workspace required by AWQ for smoothing activations @@ -2193,6 +2214,7 @@ CutlassMoeFCRunner: ADD(overlapped_gemm1_gemm2_inputs); ADD(overlapped_gemm1_gemm2_outputs); ADD(fc1_result_dedicated); + ADD(moe_gemv_splitk_partials); ADD_NAME(alpha_scale_ptr_array_fc1, alpha_scale_ptr_array_size); ADD_NAME(alpha_scale_ptr_array_fc2, alpha_scale_ptr_array_size); ADD(fp4_act_scale); @@ -2270,6 +2292,7 @@ void CutlassMoeFCRunner(); + ORT_ENFORCE(!use_splitk_swiglu_gemv || moe_gemv_splitk_partials != nullptr, + "Split-K2 SwiGLU GEMV requires split-K GEMV workspace"); bool const fc1_did_fused_gemv = tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( input, fc1_expert_weights, quant_params.groupwise.group_size > 0 @@ -2466,7 +2493,9 @@ void CutlassMoeFCRunner 1 || use_ampere_activation_fusion || !bias_is_broadcast || MoeGemvRejectedByProfiledInterSize(expanded_num_rows, inter_size), - activation_params, stream); + activation_params, + use_splitk_swiglu_gemv ? moe_gemv_splitk_partials : nullptr, + stream); // Run the GEMM with activation function overridden with `Identity`, we do the activation separately. // Fast path: symmetric INT4/INT8 (per-column or block-wise) MoE GEMV for small expanded-row counts @@ -2805,7 +2834,7 @@ void CutlassMoeFCRunner(fc1_expert_weights), static_cast(fc1_expert_biases), num_valid_tokens_ptr, static_cast(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant, fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size, - num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, nullptr, bias_is_broadcast, stream, + num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, nullptr, nullptr, bias_is_broadcast, stream, MOEParallelismConfig{}, config, activation_params); } @@ -626,6 +626,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_scale_; float const** alpha_scale_ptr_array_fc1_ = nullptr; float const** alpha_scale_ptr_array_fc2_ = nullptr; + float* moe_gemv_splitk_partials_{}; void* smoothed_act_{}; TmaWarpSpecializedGroupedGemmInput tma_ws_grouped_gemm1_input_; diff --git a/onnxruntime/test/python/transformers/profile_qmoe_gemv.py b/onnxruntime/test/python/transformers/profile_qmoe_gemv.py index 0f71409fa2206..407d5e6dd5d78 100644 --- a/onnxruntime/test/python/transformers/profile_qmoe_gemv.py +++ b/onnxruntime/test/python/transformers/profile_qmoe_gemv.py @@ -72,6 +72,16 @@ def main(): action="store_true", help="Run the grouped GEMM fallback by setting ORT_DISABLE_MOE_GEMV=1 before session creation", ) + parser.add_argument( + "--splitk2-swiglu", + action="store_true", + help="Deprecated compatibility flag; split-K2 two-pass FC1 SwiGLU GEMV is enabled by default when supported", + ) + parser.add_argument( + "--disable-splitk2-swiglu", + action="store_true", + help="Disable split-K2 two-pass FC1 SwiGLU GEMV by setting ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1", + ) parser.add_argument( "--nvtx", action="store_true", @@ -101,6 +111,12 @@ def main(): else: os.environ.pop("ORT_DISABLE_MOE_GEMV", None) + if args.disable_splitk2_swiglu: + os.environ["ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU"] = "1" + else: + os.environ.pop("ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU", None) + os.environ.pop("ORT_MOE_GEMV_SPLITK2_SWIGLU", None) + result = run_qmoe_gemv_benchmark(case) if result["has_invalid_output"]: raise RuntimeError("QMoE GEMV profiling produced NaN or Inf output") diff --git a/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh b/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh index c452ee17fec00..f82ebe3c065c8 100755 --- a/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh +++ b/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh @@ -11,6 +11,7 @@ # ./profile_qmoe_gemv.sh --list-cases # ./profile_qmoe_gemv.sh --case m8_top2_fp16_128x256 --warmup 5 --repeat 200 # ./profile_qmoe_gemv.sh --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --warmup 5 --repeat 100 +# ./profile_qmoe_gemv.sh --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --disable-splitk2-swiglu --warmup 5 --repeat 100 # ./profile_qmoe_gemv.sh --batch-size 1 --sequence-length 1 --hidden-size 1024 --intermediate-size 4096 --num-experts 8 --top-k 2 --quant-bits 8 --block-size 128 # CUDA_VISIBLE_DEVICES=1 ./profile_qmoe_gemv.sh -o /tmp/qmoe_gemv # @@ -26,6 +27,7 @@ OUTPUT_NAME="qmoe_gemv_profile" PY="${PYTHON:-python}" EXTRA_ARGS=() LIST_CASES=0 +DISABLE_SPLITK2_SWIGLU=0 while [[ "$#" -gt 0 ]]; do case $1 in @@ -36,6 +38,12 @@ while [[ "$#" -gt 0 ]]; do --list-cases) LIST_CASES=1 ;; + --splitk2-swiglu) + DISABLE_SPLITK2_SWIGLU=0 + ;; + --disable-splitk2-swiglu) + DISABLE_SPLITK2_SWIGLU=1 + ;; --batch-size|--sequence-length|--hidden-size|--intermediate-size|--num-experts|--top-k|--dtype|--quant-bits|--block-size) EXTRA_ARGS+=("$1" "$2") shift @@ -58,7 +66,7 @@ while [[ "$#" -gt 0 ]]; do ;; *) echo "Unknown option: $1" - echo "Usage: $0 [--list-cases] [--case NAME] [--batch-size N] [--sequence-length N] [--hidden-size N] [--intermediate-size N] [--num-experts N] [--top-k N] [--dtype FLOAT16|BFLOAT16] [--quant-bits 4|8] [--block-size 0|32|64|128] [--warmup N] [--repeat N] [--python PYTHON] [-o NAME]" + echo "Usage: $0 [--list-cases] [--case NAME] [--disable-splitk2-swiglu] [--batch-size N] [--sequence-length N] [--hidden-size N] [--intermediate-size N] [--num-experts N] [--top-k N] [--dtype FLOAT16|BFLOAT16] [--quant-bits 4|8] [--block-size 0|32|64|128] [--warmup N] [--repeat N] [--python PYTHON] [-o NAME]" exit 1 ;; esac @@ -98,13 +106,19 @@ fi if [[ "${#EXTRA_ARGS[@]}" -gt 0 ]]; then echo "Custom args: ${EXTRA_ARGS[*]}" fi +if [[ "${DISABLE_SPLITK2_SWIGLU}" -eq 1 ]]; then + echo "Split-K2 SwiGLU: disabled for GEMV mode" +fi profile_one() { local mode="$1" local disable_arg="" + local splitk2_arg="" local base="${OUTPUT_NAME}_${mode}" if [[ "${mode}" == "gemm" ]]; then disable_arg="--disable-gemv" + elif [[ "${DISABLE_SPLITK2_SWIGLU}" -eq 1 ]]; then + splitk2_arg="--disable-splitk2-swiglu" fi echo "" @@ -112,7 +126,8 @@ profile_one() { rm -f "${base}.nsys-rep" "${base}.sqlite" nsys profile -t cuda,nvtx --force-overwrite true -o "${base}" --export=sqlite \ "${PY}" "${SCRIPT_DIR}/profile_qmoe_gemv.py" \ - --case "${CASE}" "${EXTRA_ARGS[@]}" --warmup "${WARMUP}" --repeat "${REPEAT}" --nvtx ${disable_arg} + --case "${CASE}" "${EXTRA_ARGS[@]}" --warmup "${WARMUP}" --repeat "${REPEAT}" --nvtx \ + ${disable_arg} ${splitk2_arg} echo "" echo "---- Kernel results (${mode}) ----" diff --git a/onnxruntime/test/python/transformers/test_qmoe_cuda.py b/onnxruntime/test/python/transformers/test_qmoe_cuda.py index 36d2ff66010ff..ddca2687fd392 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cuda.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cuda.py @@ -2486,6 +2486,7 @@ def run_qmoe_gemv_benchmark(case): "case": case["name"], "block_size": case.get("block_size", 0), "disable_gemv": os.getenv("ORT_DISABLE_MOE_GEMV") == "1", + "disable_splitk2_swiglu": os.getenv("ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU") == "1", "expanded_num_rows": case["batch_size"] * case["sequence_length"] * case["top_k"], "has_invalid_output": bool(torch.isnan(output).any() or torch.isinf(output).any()), "latency_ms": qmoe.last_ort_latency_ms, @@ -2512,6 +2513,16 @@ def test_decode_latency(self): self.assertFalse(result["has_invalid_output"]) print(_QMOE_GEMV_BENCHMARK_RESULT_PREFIX + json.dumps(result, sort_keys=True)) + @unittest.skipIf( + os.getenv("ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU") == "1", + "Unset ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU to run the default split-K2 SwiGLU GEMV benchmark.", + ) + def test_splitk2_swiglu_decode_latency(self): + result = run_qmoe_gemv_benchmark_case("gpt_oss_20b_m1_top4_fp16_2880x2880_e32") + self.assertFalse(result["disable_splitk2_swiglu"]) + self.assertFalse(result["has_invalid_output"]) + print(_QMOE_GEMV_BENCHMARK_RESULT_PREFIX + json.dumps(result, sort_keys=True)) + @unittest.skipIf(True, "Skipping QMoE benchmark tests") class TestQMoESwiGLUBenchmark(unittest.TestCase): From bc4712b4d88324455a8115c23ba8c53faf9ac6e7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 16:09:17 -0700 Subject: [PATCH 2/3] experiment of fp16 accumulation for split-k --- docs/contrib_ops/cuda/moe_qmoe.md | 10 ++- .../contrib_ops/cuda/qmoe_gemv_experiments.md | 57 ++++++++++++- .../contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu | 84 +++++++++++-------- .../contrib_ops/cuda/llm/moe_gemm/moe_gemv.h | 2 +- 4 files changed, 109 insertions(+), 44 deletions(-) diff --git a/docs/contrib_ops/cuda/moe_qmoe.md b/docs/contrib_ops/cuda/moe_qmoe.md index b81c930af5d45..93949d1bf06c1 100644 --- a/docs/contrib_ops/cuda/moe_qmoe.md +++ b/docs/contrib_ops/cuda/moe_qmoe.md @@ -992,10 +992,12 @@ per-column INT4, block-wise INT4/INT8, and interleaved-SwiGLU GEMV kernels. #### Split-K2 SwiGLU GEMV default path The fp16 INT4 interleaved-SwiGLU GEMV path uses a two-pass Split-K2 FC1 kernel by -default for supported decode shapes. The first pass computes two K-split FP32 -partials into QMoE workspace, and the second pass reduces those partials, adds -optional bias, and applies the interleaved SwiGLU epilogue. FC2 stays on the -regular `moe_gemv_kernel` path. +default for supported decode shapes. The first pass computes two K-split +partials into QMoE workspace using the same accumulator type as the normal GEMV +path: fp16 activations use fp16 partials when the fp16-accumulation route is +selected, and the fp32 fallback uses fp32 partials. The second pass reduces those +partials in fp32, adds optional bias, and applies the interleaved SwiGLU +epilogue. FC2 stays on the regular `moe_gemv_kernel` path. Set `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` before process start to force the previous single-kernel FC1 SwiGLU GEMV path for debugging, A/B benchmarking, or diff --git a/docs/contrib_ops/cuda/qmoe_gemv_experiments.md b/docs/contrib_ops/cuda/qmoe_gemv_experiments.md index e2e8c16984f43..d2301775982b9 100644 --- a/docs/contrib_ops/cuda/qmoe_gemv_experiments.md +++ b/docs/contrib_ops/cuda/qmoe_gemv_experiments.md @@ -990,9 +990,11 @@ Every case reported `has_invalid_output=false`. - Scope: FP16 INT4/interleaved-SwiGLU FC1 GEMV path for decode-shaped QMoE. - Implementation: - First pass launches `moe_gemv_splitk_partials_kernel` with `SplitK=2` and - writes FP32 partials into QMoE workspace. + writes accumulator-typed partials into QMoE workspace. This follows the + normal GEMV accumulation policy: fp16 partials for fp16 accumulation, fp32 + partials for the fp32 fallback. - Second pass launches `moe_gemv_splitk_reduce_swiglu_kernel` to reduce the - partials, add optional bias, and apply SwiGLU. + partials in fp32, add optional bias, and apply SwiGLU. - FC2 remains on the existing `moe_gemv_kernel`. - Scratch is allocated only for the supported Split-K2 route. Setting `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` restores the previous single-kernel @@ -1120,6 +1122,53 @@ CUDA-graph model runs were collected with `REPS=10`, `WARMUP=3`, prompt length The default Split-K2 route was faster in all three pairs, averaging `+1.29%` decode throughput and `-1.28%` decode latency versus the opt-out fallback. +### FP16 Accumulation Follow-Up + +After the normal QMoE GEMV path was changed to support fp16 accumulation, the +Split-K2 route was rechecked with `ORT_MOE_GEMV_FP16_ACCUM=1` on the same +GPT-OSS-20B decode shape. Sequential focused-helper runs with `--repeat 100` +showed Split-K2 behind the single-kernel path: + +| Run | Mode | Latency ms/inference | +|-----|------|----------------------| +| R1 | Split-K2 | 0.061761 | +| R1 | Split-K2 disabled | 0.060108 | +| R2 | Split-K2 | 0.062862 | +| R2 | Split-K2 disabled | 0.060989 | +| R3 | Split-K2 | 0.064595 | +| R3 | Split-K2 disabled | 0.060464 | +| Average | Split-K2 | 0.063073 | +| Average | Split-K2 disabled | 0.060520 | + +A short CUDA-graph model-level pair with `REPS=5`, `WARMUP=2`, prompt length +512, and generation length 128 showed the same direction: + +| Mode | Decode latency ms/token | Decode throughput tok/s | +|------|-------------------------|-------------------------| +| Split-K2 | 2.848148 | 351.105318 | +| Split-K2 disabled | 2.816800 | 355.012723 | + +Although the single-kernel path was faster for this GPT-OSS focused helper, +Split-K2 with fp16 accumulation was still faster than the fp32-accumulation +Split-K2 route. The fp16 Split-K2 variant is kept so a future autotuner can +choose it for shapes where the extra K parallelism wins. + +The same focused profiler check was run for Qwen3.6-35B-A3B and Gemma4-26B-A4B +decode-shaped configs with `--repeat 100`: + +| Case | Mode | Latency ms/inference | +|------|------|----------------------| +| Qwen3.6-35B-A3B | fp16 Split-K2 | 0.049207 | +| Qwen3.6-35B-A3B | fp16 Split-K2 disabled | 0.047403 | +| Qwen3.6-35B-A3B | fp32 Split-K2 | 0.052055 | +| Gemma4-26B-A4B | fp16 Split-K2 | 0.053503 | +| Gemma4-26B-A4B | fp16 Split-K2 disabled | 0.050732 | +| Gemma4-26B-A4B | fp32 Split-K2 | 0.059571 | + +Both additional shapes produced valid output. In these focused helper runs, +fp16 Split-K2 again sat between the fp16 single-kernel path and the fp32 Split-K2 +path. + ### Accuracy Smoke A 1000-sample `match_mmlu` smoke was run with the local parallel eval harness on @@ -1133,6 +1182,10 @@ there is no accuracy regression signal from enabling Split-K2 by default. - Enable Split-K2 by default for its supported fp16 INT4 interleaved-SwiGLU GEMV scope. +- Keep the fp16-accumulation Split-K2 variant available. It is slower than the + single-kernel fp16-accumulation path on the GPT-OSS shape, but faster than the + fp32-accumulation Split-K2 route and may be selected by future per-shape + autotuning. - Keep `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` as the fallback and A/B knob. - The 1000-sample MMLU smoke matched the opt-out fallback within noise, so the default flip has an accuracy sanity check in addition to focused-helper valid diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu index f8dc958de3e86..c89ec26ef1552 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu @@ -174,8 +174,8 @@ __device__ __forceinline__ void swiglu_epilogue(void* out, void* tile_acc, void* } } -template -__device__ __forceinline__ void partial_epilogue(float* partial_out, void* tile_acc) { +template +__device__ __forceinline__ void partial_epilogue(PartialT* partial_out, void* tile_acc) { static constexpr int Interleave = Details::kInterleave; static constexpr int ThreadsPerInterleavedTile = Details::kThreadsPerInterleavedTile; static constexpr int WarpSize = Details::kWarpSize; @@ -204,14 +204,14 @@ __device__ __forceinline__ void partial_epilogue(float* partial_out, void* tile_ for (int warp = 0; warp < WarpNum; ++warp) { val += shmem[warp * CtaN * Interleave + col]; } - partial_out[col] = val; + partial_out[col] = static_cast(val); } } template + typename TypeA = typename Details::TypeDetailsA::Type, typename AccT = float, typename PartialT = AccT> __global__ void moe_gemv_splitk_partials_kernel( - TypeA* act, uint8_t* weight, TypeA* scales, float* partials, + TypeA* act, uint8_t* weight, TypeA* scales, PartialT* partials, int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, int64_t weight_expert_stride, int64_t scale_expert_stride, int n, int k, int64_t expanded_num_rows) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) @@ -307,14 +307,14 @@ __global__ void moe_gemv_splitk_partials_kernel( act_iterator.load(tile_a, iter, 0); mma(tile_acc, tile_w_pack2, tile_a); } - partial_epilogue(partials, tile_acc); + partial_epilogue(partials, tile_acc); #endif } template + typename TypeA = typename Details::TypeDetailsA::Type, typename PartialT = TypeA> __global__ void moe_gemv_splitk_reduce_swiglu_kernel( - float const* partials, TypeA* bias, TypeA* out, + PartialT const* partials, TypeA* bias, TypeA* out, int const* permuted_row_to_expert, int num_experts, int64_t const* expert_first_token_offset, int inter_size, int split_k, int64_t expanded_num_rows, cutlass_kernels::ActivationParams activation_params) { @@ -354,8 +354,8 @@ __global__ void moe_gemv_splitk_reduce_swiglu_kernel( float linear = 0.f; for (int split = 0; split < split_k; ++split) { int64_t const base = static_cast(split) * split_stride + row_base; - gate += partials[base + gate_idx]; - linear += partials[base + linear_idx]; + gate += static_cast(partials[base + gate_idx]); + linear += static_cast(partials[base + linear_idx]); } if constexpr (EnableBias) { @@ -525,19 +525,20 @@ static void launch_moe_gemv_interleaved_swiglu( } } -template +template static void launch_moe_gemv_splitk_twopass_swiglu( TypeA* act, uint8_t* weight, TypeA* scales, TypeA* bias, TypeA* out, int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, int64_t expanded_num_rows, int64_t inter_size, int64_t k, - cutlass_kernels::ActivationParams activation_params, float* partials, cudaStream_t stream) { + cutlass_kernels::ActivationParams activation_params, PartialT* partials, cudaStream_t stream) { static constexpr int StepK = Details::kStepK; static constexpr int CtaK = StepK * Threads; int64_t const n = inter_size * 2; int64_t const interleaved_k = k * Details::kInterleave; int const num_iters = static_cast((interleaved_k + CtaK - 1) / CtaK); if (partials == nullptr || num_iters < 2) { - launch_moe_gemv_interleaved_swiglu( + launch_moe_gemv_interleaved_swiglu( act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, activation_params, stream); return; @@ -548,7 +549,7 @@ static void launch_moe_gemv_splitk_twopass_swiglu( dim3 grid1(static_cast(expanded_num_rows), static_cast(n / (CtaN * Details::kInterleave)), SplitK); dim3 block1(Threads); - moe_gemv_splitk_partials_kernel + moe_gemv_splitk_partials_kernel <<>>( act, weight, scales, partials, expert_first_token_offset, permuted_row_to_expert, num_experts, weight_expert_stride, scale_expert_stride, static_cast(n), static_cast(k), expanded_num_rows); @@ -558,36 +559,36 @@ static void launch_moe_gemv_splitk_twopass_swiglu( static_cast((inter_size + kReduceThreads - 1) / kReduceThreads)); dim3 block2(kReduceThreads); if (bias != nullptr) { - moe_gemv_splitk_reduce_swiglu_kernel<<>>( + moe_gemv_splitk_reduce_swiglu_kernel<<>>( partials, bias, out, permuted_row_to_expert, num_experts, expert_first_token_offset, static_cast(inter_size), SplitK, expanded_num_rows, activation_params); } else { - moe_gemv_splitk_reduce_swiglu_kernel<<>>( + moe_gemv_splitk_reduce_swiglu_kernel<<>>( partials, bias, out, permuted_row_to_expert, num_experts, expert_first_token_offset, static_cast(inter_size), SplitK, expanded_num_rows, activation_params); } } -template +template static void dispatch_moe_gemv_splitk_twopass_swiglu_group_size( TypeA* act, uint8_t* weight, TypeA* scales, TypeA* bias, TypeA* out, int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, int64_t expanded_num_rows, int64_t inter_size, int64_t k, int group_size, - cutlass_kernels::ActivationParams activation_params, float* partials, cudaStream_t stream) { + cutlass_kernels::ActivationParams activation_params, PartialT* partials, cudaStream_t stream) { if (group_size <= 0) { - launch_moe_gemv_splitk_twopass_swiglu( + launch_moe_gemv_splitk_twopass_swiglu( act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, activation_params, partials, stream); } else if (group_size == 32) { - launch_moe_gemv_splitk_twopass_swiglu( + launch_moe_gemv_splitk_twopass_swiglu( act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, activation_params, partials, stream); } else if (group_size == 64) { - launch_moe_gemv_splitk_twopass_swiglu( + launch_moe_gemv_splitk_twopass_swiglu( act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, activation_params, partials, stream); } else if (group_size == 128) { - launch_moe_gemv_splitk_twopass_swiglu( + launch_moe_gemv_splitk_twopass_swiglu( act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, activation_params, partials, stream); } else { @@ -797,25 +798,34 @@ void launch_moe_gemv_int_symmetric_interleaved_swiglu( T const* act, WeightType const* weight, T const* scales, T const* bias, T* out, int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, int64_t expanded_num_rows, int64_t inter_size, int64_t k, int group_size, int sm, - cutlass_kernels::ActivationParams activation_params, float* splitk_partials, cudaStream_t stream) { + cutlass_kernels::ActivationParams activation_params, void* splitk_partials, cudaStream_t stream) { ORT_UNUSED_PARAMETER(sm); using Details = typename DetailsForTAndWeight::Details; using TypeA = typename DetailsForTAndWeight::TypeA; + // Accumulate in fp32 by default (see launch_moe_gemv_int_symmetric for the policy). + bool const use_fp32_accum = !std::is_same_v || !MoeGemvUseFp16Accum(); if (splitk_partials != nullptr) { if constexpr (std::is_same_v) { - fiv::dispatch_moe_gemv_splitk_twopass_swiglu_group_size( - const_cast(reinterpret_cast(act)), - const_cast(reinterpret_cast(weight)), - const_cast(reinterpret_cast(scales)), - const_cast(reinterpret_cast(bias)), - reinterpret_cast(out), - expert_first_token_offset, permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, group_size, - activation_params, splitk_partials, stream); + auto launch_splitk = [&](auto acc_tag) { + using AccT = typename decltype(acc_tag)::type; + using PartialT = AccT; + fiv::dispatch_moe_gemv_splitk_twopass_swiglu_group_size( + const_cast(reinterpret_cast(act)), + const_cast(reinterpret_cast(weight)), + const_cast(reinterpret_cast(scales)), + const_cast(reinterpret_cast(bias)), + reinterpret_cast(out), + expert_first_token_offset, permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, group_size, + activation_params, reinterpret_cast(splitk_partials), stream); + }; + if (use_fp32_accum) { + launch_splitk(TypeTag{}); + } else { + launch_splitk(TypeTag{}); + } return; } } - // Accumulate in fp32 by default (see launch_moe_gemv_int_symmetric for the policy). - bool const use_fp32_accum = !std::is_same_v || !MoeGemvUseFp16Accum(); auto launch = [&](auto acc_tag) { using AccT = typename decltype(acc_tag)::type; fiv::dispatch_moe_gemv_interleaved_swiglu_group_size( @@ -864,10 +874,10 @@ template void launch_moe_gemv_int_symmetric( int64_t, int64_t, int64_t, int, int, cudaStream_t); template void launch_moe_gemv_int_symmetric_interleaved_swiglu( half const*, cutlass::uint4b_t const*, half const*, half const*, half*, int64_t const*, int const*, int, - int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, float*, cudaStream_t); + int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, void*, cudaStream_t); template void launch_moe_gemv_int_symmetric_interleaved_swiglu( half const*, uint8_t const*, half const*, half const*, half*, int64_t const*, int const*, int, - int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, float*, cudaStream_t); + int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, void*, cudaStream_t); template void launch_moe_gemv_int4_per_channel(half const*, uint8_t const*, half const*, half const*, half*, int64_t const*, int const*, int, int64_t, int64_t, int64_t, @@ -886,11 +896,11 @@ template void launch_moe_gemv_int_symmetric<__nv_bfloat16, uint8_t>( template void launch_moe_gemv_int_symmetric_interleaved_swiglu<__nv_bfloat16, cutlass::uint4b_t>( __nv_bfloat16 const*, cutlass::uint4b_t const*, __nv_bfloat16 const*, __nv_bfloat16 const*, __nv_bfloat16*, int64_t const*, int const*, int, int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, - float*, cudaStream_t); + void*, cudaStream_t); template void launch_moe_gemv_int_symmetric_interleaved_swiglu<__nv_bfloat16, uint8_t>( __nv_bfloat16 const*, uint8_t const*, __nv_bfloat16 const*, __nv_bfloat16 const*, __nv_bfloat16*, int64_t const*, int const*, int, int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, - float*, cudaStream_t); + void*, cudaStream_t); template void launch_moe_gemv_int4_per_channel<__nv_bfloat16>( __nv_bfloat16 const*, uint8_t const*, __nv_bfloat16 const*, __nv_bfloat16 const*, __nv_bfloat16*, diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h index c855106f9c08d..229f9839e62f1 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h @@ -53,7 +53,7 @@ void launch_moe_gemv_int_symmetric_interleaved_swiglu( T const* act, WeightType const* weight, T const* scales, T const* bias, T* out, int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, int64_t expanded_num_rows, int64_t inter_size, int64_t k, int group_size, int sm, cutlass_kernels::ActivationParams activation_params, - float* splitk_partials, + void* splitk_partials, cudaStream_t stream); // Launches the int4 per-channel MoE GEMV. From 64db1ce77a5c7552cdb0d6b30b00b5ff170e5ce6 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jun 2026 16:57:58 -0700 Subject: [PATCH 3/3] env vars for fp32_accum and splik routing --- docs/contrib_ops/cuda/moe_qmoe.md | 18 ++++----- .../contrib_ops/cuda/qmoe_gemv_experiments.md | 32 ++++++++-------- .../contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu | 15 ++++++-- .../cuda/llm/moe_gemm/moe_kernels.cu | 19 +++++++--- .../python/transformers/profile_qmoe_gemv.py | 15 ++++---- .../python/transformers/profile_qmoe_gemv.sh | 37 ++++++++++--------- .../python/transformers/test_qmoe_cuda.py | 17 +++------ 7 files changed, 82 insertions(+), 71 deletions(-) diff --git a/docs/contrib_ops/cuda/moe_qmoe.md b/docs/contrib_ops/cuda/moe_qmoe.md index 28c0c6eb3d055..8cd698ac4ef2a 100644 --- a/docs/contrib_ops/cuda/moe_qmoe.md +++ b/docs/contrib_ops/cuda/moe_qmoe.md @@ -1006,10 +1006,10 @@ accuracy at 0.8260 for both modes. #### Split-K2 SwiGLU GEMV route The fp16 INT4 interleaved-SwiGLU GEMV path can use a two-pass Split-K2 FC1 kernel -for supported decode shapes. With the default fp16 accumulation policy, the -single-kernel FC1 SwiGLU path is used unless `ORT_MOE_GEMV_SPLITK2_SWIGLU=1` -forces Split-K2. With `ORT_MOE_GEMV_FP32_ACCUM=1`, Split-K2 is used by default -unless `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` disables it. +for supported decode shapes. `ORT_MOE_GEMV_FP32_ACCUM=1` enables fp32 +accumulation, and `ORT_MOE_GEMV_SPLITK2_SWIGLU=1` enables Split-K2. Both default +to `0`, so the default route is fp16 accumulation with the single-kernel FC1 +SwiGLU path. The first pass computes two K-split partials into QMoE workspace using the same accumulator type as the normal GEMV @@ -1018,19 +1018,19 @@ selected, and the fp32 fallback uses fp32 partials. The second pass reduces thos partials in fp32, adds optional bias, and applies the interleaved SwiGLU epilogue. FC2 stays on the regular `moe_gemv_kernel` path. -Use `ORT_MOE_GEMV_SPLITK2_SWIGLU=1` and -`ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` before process start for A/B -benchmarking or bisecting numerical differences. On GPT-OSS-20B, Split-K2 +Use the two binary knobs before process start for A/B benchmarking or bisecting +numerical differences. The focused profiler exposes the same controls as +`--fp32-accum` and `--splitk2-swiglu`. On GPT-OSS-20B, Split-K2 reduced FC1 kernel work from about 21.42 us to 19.98 us in the fp32-accumulation route and improved repeated CUDA-graph decode throughput by about 0.9% to 1.6% -with valid focused-helper output. A 1000-sample MMLU smoke matched the opt-out +with valid focused-helper output. A 1000-sample MMLU smoke matched the non-Split-K fallback within noise. A future autotuner can replace this hand-selected routing with per-shape route selection. ```bash onnxruntime/test/python/transformers/profile_qmoe_gemv.py \ --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 \ - --splitk2-swiglu --warmup 5 --repeat 100 --nvtx + --fp32-accum --splitk2-swiglu --warmup 5 --repeat 100 --nvtx ``` #### Experiments rejected after profiling diff --git a/docs/contrib_ops/cuda/qmoe_gemv_experiments.md b/docs/contrib_ops/cuda/qmoe_gemv_experiments.md index fa3aad40a4a94..161fb88621a39 100644 --- a/docs/contrib_ops/cuda/qmoe_gemv_experiments.md +++ b/docs/contrib_ops/cuda/qmoe_gemv_experiments.md @@ -985,8 +985,8 @@ Every case reported `has_invalid_output=false`. - Code commit: `f1d6718be719c1237be392c0389874b6a8926a3c` (`Experiment QMoE split-K SwiGLU GEMV`). -- Added default Split-K2 route with opt-out env knob: - `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1`. +- Added Split-K2 route with opt-in env knob: + `ORT_MOE_GEMV_SPLITK2_SWIGLU=1`. - Scope: FP16 INT4/interleaved-SwiGLU FC1 GEMV path for decode-shaped QMoE. - Implementation: - First pass launches `moe_gemv_splitk_partials_kernel` with `SplitK=2` and @@ -996,8 +996,8 @@ Every case reported `has_invalid_output=false`. - Second pass launches `moe_gemv_splitk_reduce_swiglu_kernel` to reduce the partials in fp32, add optional bias, and apply SwiGLU. - FC2 remains on the existing `moe_gemv_kernel`. - - Scratch is allocated only for the supported Split-K2 route. Setting - `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` restores the previous single-kernel + - Scratch is allocated only for the supported Split-K2 route. Leaving + `ORT_MOE_GEMV_SPLITK2_SWIGLU` unset or setting it to `0` keeps the previous single-kernel FC1 SwiGLU GEMV path. ### Repro Notes @@ -1034,7 +1034,7 @@ Both modes reported `has_invalid_output=false`. | Mode | Env | Latency ms | |------|-----|------------| -| Baseline | `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` | 0.072344 | +| Baseline | unset | 0.072344 | | Split-K2 | none | 0.073816 | The short helper was slightly slower with split-K2, so Nsight was required to @@ -1054,7 +1054,7 @@ Command shape: -o /tmp/qmoe_gptoss_splitk_final --export=sqlite \ ~/onnxruntime/.venv_cu130/bin/python \ ~/onnxruntime/onnxruntime/test/python/transformers/profile_qmoe_gemv.py \ - --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --warmup 3 --repeat 30 --nvtx + --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --warmup 3 --repeat 30 --nvtx --splitk2-swiglu ``` Parsed with `parse_nsys.py --nvtx-range benchmark --pattern '%'`. @@ -1089,7 +1089,7 @@ ORT_FORCE_DETERMINISTIC_MOE=1 \ bash scripts/bench_gpt_oss_ort_decode.sh ``` -Baseline additionally set `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1`. +Baseline left `ORT_MOE_GEMV_SPLITK2_SWIGLU` unset. | Run | Mode | Decode latency ms/token | Decode throughput tok/s | |-----|------|-------------------------|-------------------------| @@ -1103,8 +1103,7 @@ pair showed about `+1.6%`. Since the focused helper reported valid output and the model-level gain repeated in the same direction, even this modest gain is worth enabling for GPT-OSS-20B decode while keeping an opt-out for A/B checks. -After flipping Split-K2 to the default and adding -`ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` as the opt-out, three more paired +After testing Split-K2 as the selected route, three more paired CUDA-graph model runs were collected with `REPS=10`, `WARMUP=3`, prompt length 512, and generation length 128: @@ -1174,22 +1173,21 @@ fp16 Split-K2 again sat between the fp16 single-kernel path and the A 1000-sample `match_mmlu` smoke was run with the local parallel eval harness on all eight H200 GPUs, using the same GPT-OSS-20B INT4 QMoE model package and the current ORT build package. The default Split-K2 run scored `0.8380` pooled -accuracy; the opt-out fallback with `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` +accuracy; the non-Split-K fallback with `ORT_MOE_GEMV_SPLITK2_SWIGLU` unset scored `0.8350`. The small positive difference is within smoke-test noise, and -there is no accuracy regression signal from enabling Split-K2 by default. +there is no accuracy regression signal from enabling Split-K2. ### Decision -- Enable Split-K2 by default for its supported fp16 INT4 interleaved-SwiGLU GEMV - scope when `ORT_MOE_GEMV_FP32_ACCUM=1` selects fp32 accumulation. +- Keep Split-K2 available for its supported fp16 INT4 interleaved-SwiGLU GEMV + scope when `ORT_MOE_GEMV_SPLITK2_SWIGLU=1` enables it. - Keep the fp16-accumulation Split-K2 variant available. It is slower than the single-kernel fp16-accumulation path on the GPT-OSS shape, but faster than the fp32-accumulation Split-K2 route and may be selected by future per-shape autotuning. -- With default fp16 accumulation, use the single-kernel FC1 SwiGLU path unless - `ORT_MOE_GEMV_SPLITK2_SWIGLU=1` forces Split-K2. Keep - `ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1` as the fp32-accumulation fallback and - A/B knob. +- Use two binary route controls: `ORT_MOE_GEMV_FP32_ACCUM=1` enables fp32 + accumulation, and `ORT_MOE_GEMV_SPLITK2_SWIGLU=1` enables Split-K2. Both + default to `0`. - The 1000-sample MMLU smoke matched the opt-out fallback within noise, so the default flip has an accuracy sanity check in addition to focused-helper valid output. diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu index 0b62c0f830bca..996f59170e016 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu @@ -677,10 +677,17 @@ inline bool MoeGemvUseFp32Accum() { return enabled; } -inline bool MoeGemvForceSplitK2SwiGLU() { +inline bool MoeGemvUseSplitK2SwiGLU() { // Parsed once via ORT's environment helper (consistent parsing/thread-safety across platforms). - static bool const enabled = - onnxruntime::ParseEnvironmentVariableWithDefault("ORT_MOE_GEMV_SPLITK2_SWIGLU", 0) == 1; + static bool const enabled = [] { + auto const value = onnxruntime::ParseEnvironmentVariable("ORT_MOE_GEMV_SPLITK2_SWIGLU"); + if (!value.has_value()) { + return false; + } + ORT_ENFORCE(*value == 0 || *value == 1, + "ORT_MOE_GEMV_SPLITK2_SWIGLU must be 0 or 1, but got ", *value); + return *value == 1; + }(); return enabled; } @@ -810,7 +817,7 @@ void launch_moe_gemv_int_symmetric_interleaved_swiglu( using TypeA = typename DetailsForTAndWeight::TypeA; // Accumulation policy matches launch_moe_gemv_int_symmetric. bool const use_fp32_accum = !std::is_same_v || MoeGemvUseFp32Accum(); - if (splitk_partials != nullptr && (use_fp32_accum || MoeGemvForceSplitK2SwiGLU())) { + if (splitk_partials != nullptr && MoeGemvUseSplitK2SwiGLU()) { if constexpr (std::is_same_v) { auto launch_splitk = [&](auto acc_tag) { using AccT = typename decltype(acc_tag)::type; diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu index 68e825b5f88ae..6d708b3fc586a 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu @@ -83,11 +83,18 @@ inline bool MoeGemvDisabledByEnv() { return disabled; } -inline bool MoeGemvSplitK2SwiGLUDisabledByEnv() { +inline bool MoeGemvSplitK2SwiGLUEnabledByEnv() { // Parsed once via ORT's environment helper (consistent parsing/thread-safety across platforms). - static bool const disabled = - onnxruntime::ParseEnvironmentVariableWithDefault("ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU", 0) == 1; - return disabled; + static bool const enabled = [] { + auto const value = onnxruntime::ParseEnvironmentVariable("ORT_MOE_GEMV_SPLITK2_SWIGLU"); + if (!value.has_value()) { + return false; + } + ORT_ENFORCE(*value == 0 || *value == 1, + "ORT_MOE_GEMV_SPLITK2_SWIGLU must be 0 or 1, but got ", *value); + return *value == 1; + }(); + return enabled; } template @@ -2185,7 +2192,7 @@ CutlassMoeFCRunner: size_t const fc1_result_dedicated_size = (glu_inter_elems > 0) ? fc1_result_size : 0; static constexpr int kMoeGemvSplitK = 2; - size_t const moe_gemv_splitk_partials_size = is_gated_activation && !MoeGemvSplitK2SwiGLUDisabledByEnv() && + size_t const moe_gemv_splitk_partials_size = is_gated_activation && MoeGemvSplitK2SwiGLUEnabledByEnv() && MoeGemvSplitK2SwiGLUSupported() ? kMoeGemvSplitK * glu_inter_elems * sizeof(float) : 0; @@ -2476,7 +2483,7 @@ void CutlassMoeFCRunner(); ORT_ENFORCE(!use_splitk_swiglu_gemv || moe_gemv_splitk_partials != nullptr, "Split-K2 SwiGLU GEMV requires split-K GEMV workspace"); diff --git a/onnxruntime/test/python/transformers/profile_qmoe_gemv.py b/onnxruntime/test/python/transformers/profile_qmoe_gemv.py index 3f8dce98ce612..c674b44831032 100644 --- a/onnxruntime/test/python/transformers/profile_qmoe_gemv.py +++ b/onnxruntime/test/python/transformers/profile_qmoe_gemv.py @@ -75,12 +75,12 @@ def main(): parser.add_argument( "--splitk2-swiglu", action="store_true", - help="Force split-K2 two-pass FC1 SwiGLU GEMV by setting ORT_MOE_GEMV_SPLITK2_SWIGLU=1", + help="Enable split-K2 two-pass FC1 SwiGLU GEMV by setting ORT_MOE_GEMV_SPLITK2_SWIGLU=1", ) parser.add_argument( - "--disable-splitk2-swiglu", + "--fp32-accum", action="store_true", - help="Disable split-K2 two-pass FC1 SwiGLU GEMV by setting ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU=1", + help="Enable fp32 GEMV accumulation by setting ORT_MOE_GEMV_FP32_ACCUM=1", ) parser.add_argument( "--nvtx", @@ -111,11 +111,12 @@ def main(): else: os.environ.pop("ORT_DISABLE_MOE_GEMV", None) - if args.disable_splitk2_swiglu: - os.environ["ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU"] = "1" + if args.fp32_accum: + os.environ["ORT_MOE_GEMV_FP32_ACCUM"] = "1" else: - os.environ.pop("ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU", None) - if args.splitk2_swiglu and not args.disable_splitk2_swiglu: + os.environ.pop("ORT_MOE_GEMV_FP32_ACCUM", None) + + if args.splitk2_swiglu: os.environ["ORT_MOE_GEMV_SPLITK2_SWIGLU"] = "1" else: os.environ.pop("ORT_MOE_GEMV_SPLITK2_SWIGLU", None) diff --git a/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh b/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh index 979cf683ef13d..ce1473d83bcff 100755 --- a/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh +++ b/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh @@ -11,7 +11,8 @@ # ./profile_qmoe_gemv.sh --list-cases # ./profile_qmoe_gemv.sh --case m8_top2_fp16_128x256 --warmup 5 --repeat 200 # ./profile_qmoe_gemv.sh --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --warmup 5 --repeat 100 -# ./profile_qmoe_gemv.sh --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --disable-splitk2-swiglu --warmup 5 --repeat 100 +# ./profile_qmoe_gemv.sh --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --splitk2-swiglu --warmup 5 --repeat 100 +# ./profile_qmoe_gemv.sh --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --fp32-accum --splitk2-swiglu --warmup 5 --repeat 100 # ./profile_qmoe_gemv.sh --batch-size 1 --sequence-length 1 --hidden-size 1024 --intermediate-size 4096 --num-experts 8 --top-k 2 --quant-bits 8 --block-size 128 # CUDA_VISIBLE_DEVICES=1 ./profile_qmoe_gemv.sh -o /tmp/qmoe_gemv # @@ -27,8 +28,8 @@ OUTPUT_NAME="qmoe_gemv_profile" PY="${PYTHON:-python}" EXTRA_ARGS=() LIST_CASES=0 -DISABLE_SPLITK2_SWIGLU=0 -FORCE_SPLITK2_SWIGLU=0 +FP32_ACCUM=0 +SPLITK2_SWIGLU=0 while [[ "$#" -gt 0 ]]; do case $1 in @@ -40,12 +41,10 @@ while [[ "$#" -gt 0 ]]; do LIST_CASES=1 ;; --splitk2-swiglu) - FORCE_SPLITK2_SWIGLU=1 - DISABLE_SPLITK2_SWIGLU=0 + SPLITK2_SWIGLU=1 ;; - --disable-splitk2-swiglu) - DISABLE_SPLITK2_SWIGLU=1 - FORCE_SPLITK2_SWIGLU=0 + --fp32-accum) + FP32_ACCUM=1 ;; --batch-size|--sequence-length|--hidden-size|--intermediate-size|--num-experts|--top-k|--dtype|--quant-bits|--block-size) EXTRA_ARGS+=("$1" "$2") @@ -69,7 +68,7 @@ while [[ "$#" -gt 0 ]]; do ;; *) echo "Unknown option: $1" - echo "Usage: $0 [--list-cases] [--case NAME] [--disable-splitk2-swiglu] [--batch-size N] [--sequence-length N] [--hidden-size N] [--intermediate-size N] [--num-experts N] [--top-k N] [--dtype FLOAT16|BFLOAT16] [--quant-bits 4|8] [--block-size 0|32|64|128] [--warmup N] [--repeat N] [--python PYTHON] [-o NAME]" + echo "Usage: $0 [--list-cases] [--case NAME] [--fp32-accum] [--splitk2-swiglu] [--batch-size N] [--sequence-length N] [--hidden-size N] [--intermediate-size N] [--num-experts N] [--top-k N] [--dtype FLOAT16|BFLOAT16] [--quant-bits 4|8] [--block-size 0|32|64|128] [--warmup N] [--repeat N] [--python PYTHON] [-o NAME]" exit 1 ;; esac @@ -109,22 +108,26 @@ fi if [[ "${#EXTRA_ARGS[@]}" -gt 0 ]]; then echo "Custom args: ${EXTRA_ARGS[*]}" fi -if [[ "${DISABLE_SPLITK2_SWIGLU}" -eq 1 ]]; then - echo "Split-K2 SwiGLU: disabled for GEMV mode" -elif [[ "${FORCE_SPLITK2_SWIGLU}" -eq 1 ]]; then - echo "Split-K2 SwiGLU: forced for GEMV mode" +if [[ "${FP32_ACCUM}" -eq 1 ]]; then + echo "GEMV accumulation: fp32" +fi +if [[ "${SPLITK2_SWIGLU}" -eq 1 ]]; then + echo "Split-K2 SwiGLU: enabled for GEMV mode" fi profile_one() { local mode="$1" local disable_arg="" + local fp32_accum_arg="" local splitk2_arg="" local base="${OUTPUT_NAME}_${mode}" if [[ "${mode}" == "gemm" ]]; then disable_arg="--disable-gemv" - elif [[ "${DISABLE_SPLITK2_SWIGLU}" -eq 1 ]]; then - splitk2_arg="--disable-splitk2-swiglu" - elif [[ "${FORCE_SPLITK2_SWIGLU}" -eq 1 ]]; then + fi + if [[ "${FP32_ACCUM}" -eq 1 ]]; then + fp32_accum_arg="--fp32-accum" + fi + if [[ "${SPLITK2_SWIGLU}" -eq 1 ]]; then splitk2_arg="--splitk2-swiglu" fi @@ -134,7 +137,7 @@ profile_one() { nsys profile -t cuda,nvtx --force-overwrite true -o "${base}" --export=sqlite \ "${PY}" "${SCRIPT_DIR}/profile_qmoe_gemv.py" \ --case "${CASE}" "${EXTRA_ARGS[@]}" --warmup "${WARMUP}" --repeat "${REPEAT}" --nvtx \ - ${disable_arg} ${splitk2_arg} + ${disable_arg} ${fp32_accum_arg} ${splitk2_arg} echo "" echo "---- Kernel results (${mode}) ----" diff --git a/onnxruntime/test/python/transformers/test_qmoe_cuda.py b/onnxruntime/test/python/transformers/test_qmoe_cuda.py index 17e4ded3273b3..aa7c8f3db8470 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cuda.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cuda.py @@ -2486,8 +2486,8 @@ def run_qmoe_gemv_benchmark(case): "case": case["name"], "block_size": case.get("block_size", 0), "disable_gemv": os.getenv("ORT_DISABLE_MOE_GEMV") == "1", - "disable_splitk2_swiglu": os.getenv("ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU") == "1", - "force_splitk2_swiglu": os.getenv("ORT_MOE_GEMV_SPLITK2_SWIGLU") == "1", + "fp32_accum": os.getenv("ORT_MOE_GEMV_FP32_ACCUM", "0"), + "splitk2_swiglu": os.getenv("ORT_MOE_GEMV_SPLITK2_SWIGLU", "0"), "expanded_num_rows": case["batch_size"] * case["sequence_length"] * case["top_k"], "has_invalid_output": bool(torch.isnan(output).any() or torch.isinf(output).any()), "latency_ms": qmoe.last_ort_latency_ms, @@ -2514,24 +2514,19 @@ def test_decode_latency(self): self.assertFalse(result["has_invalid_output"]) print(_QMOE_GEMV_BENCHMARK_RESULT_PREFIX + json.dumps(result, sort_keys=True)) - @unittest.skipIf( - os.getenv("ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU") == "1", - "Unset ORT_DISABLE_MOE_GEMV_SPLITK2_SWIGLU to run the default split-K2 SwiGLU GEMV benchmark.", - ) def test_splitk2_swiglu_decode_latency(self): - previous_force_splitk2 = os.environ.get("ORT_MOE_GEMV_SPLITK2_SWIGLU") + previous_splitk2 = os.environ.get("ORT_MOE_GEMV_SPLITK2_SWIGLU") os.environ["ORT_MOE_GEMV_SPLITK2_SWIGLU"] = "1" try: result = run_qmoe_gemv_benchmark_case("gpt_oss_20b_m1_top4_fp16_2880x2880_e32") - self.assertFalse(result["disable_splitk2_swiglu"]) - self.assertTrue(result["force_splitk2_swiglu"]) + self.assertEqual(result["splitk2_swiglu"], "1") self.assertFalse(result["has_invalid_output"]) print(_QMOE_GEMV_BENCHMARK_RESULT_PREFIX + json.dumps(result, sort_keys=True)) finally: - if previous_force_splitk2 is None: + if previous_splitk2 is None: os.environ.pop("ORT_MOE_GEMV_SPLITK2_SWIGLU", None) else: - os.environ["ORT_MOE_GEMV_SPLITK2_SWIGLU"] = previous_force_splitk2 + os.environ["ORT_MOE_GEMV_SPLITK2_SWIGLU"] = previous_splitk2 @unittest.skipIf(True, "Skipping QMoE benchmark tests")