diff --git a/docs/contrib_ops/cuda/moe_qmoe.md b/docs/contrib_ops/cuda/moe_qmoe.md index fbddd59f8bebd..8cd698ac4ef2a 100644 --- a/docs/contrib_ops/cuda/moe_qmoe.md +++ b/docs/contrib_ops/cuda/moe_qmoe.md @@ -1003,6 +1003,36 @@ CUDA-graph decode run, default fp16 accumulation reached 386.26 tok/s versus 353.70 tok/s with the fp32 fallback. A 1000-sample MMLU smoke test matched pooled accuracy at 0.8260 for both modes. +#### Split-K2 SwiGLU GEMV route + +The fp16 INT4 interleaved-SwiGLU GEMV path can use a two-pass Split-K2 FC1 kernel +for supported decode shapes. `ORT_MOE_GEMV_FP32_ACCUM=1` enables fp32 +accumulation, and `ORT_MOE_GEMV_SPLITK2_SWIGLU=1` enables Split-K2. Both default +to `0`, so the default route is fp16 accumulation with the single-kernel FC1 +SwiGLU path. + +The first pass computes two K-split +partials into QMoE workspace using the same accumulator type as the normal GEMV +path: fp16 activations use fp16 partials when the fp16-accumulation route is +selected, and the fp32 fallback uses fp32 partials. The second pass reduces those +partials in fp32, adds optional bias, and applies the interleaved SwiGLU +epilogue. FC2 stays on the regular `moe_gemv_kernel` path. + +Use the two binary knobs before process start for A/B benchmarking or bisecting +numerical differences. The focused profiler exposes the same controls as +`--fp32-accum` and `--splitk2-swiglu`. On GPT-OSS-20B, Split-K2 +reduced FC1 kernel work from about 21.42 us to 19.98 us in the fp32-accumulation +route and improved repeated CUDA-graph decode throughput by about 0.9% to 1.6% +with valid focused-helper output. A 1000-sample MMLU smoke matched the non-Split-K +fallback within noise. A future autotuner can replace this hand-selected routing +with per-shape route selection. 
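The two-pass flow described in this section can be sketched in plain Python. This is a minimal reference model, not the CUDA implementation: weights are assumed to be already dequantized floats, `chunk` is a hypothetical stand-in for the per-iteration K tile, Python floats replace the fp16/fp32 accumulator types, and the default `alpha`, `beta`, and `limit` values are illustrative placeholders. The striding over K chunks and the gate/linear epilogue follow the kernels added in this change.

```python
import math


def splitk_partials(act, weight, split_k=2, chunk=4):
    """First pass: split `split_id` accumulates every split_k-th K chunk.

    act:    list of K activations.
    weight: N rows of K weights (dequantized floats in this sketch).
    Returns split_k partial vectors, each of length N.
    """
    k = len(act)
    num_chunks = (k + chunk - 1) // chunk
    partials = [[0.0] * len(weight) for _ in range(split_k)]
    for split_id in range(split_k):
        # Same striding as the kernel loop: iter = split_id; iter += SplitK.
        for it in range(split_id, num_chunks, split_k):
            lo, hi = it * chunk, min((it + 1) * chunk, k)
            for n, row in enumerate(weight):
                partials[split_id][n] += sum(row[i] * act[i] for i in range(lo, hi))
    return partials


def reduce_swiglu(partials, bias=None, alpha=1.0, beta=0.0, limit=math.inf):
    """Second pass: sum partials, add optional bias, apply interleaved SwiGLU."""
    n = len(partials[0])
    out = []
    for col in range(n // 2):
        # Interleaved layout: even index is the gate, odd index is the linear term.
        gate_idx, linear_idx = 2 * col, 2 * col + 1
        gate = sum(p[gate_idx] for p in partials)
        linear = sum(p[linear_idx] for p in partials)
        if bias is not None:
            gate += bias[gate_idx]
            linear += bias[linear_idx]
        if math.isfinite(limit):
            gate = min(gate, limit)
            linear = min(max(linear, -limit), limit)
        linear += beta
        sigmoid = 1.0 / (1.0 + math.exp(-alpha * gate))
        out.append(gate * sigmoid * linear)
    return out


def single_pass_swiglu(act, weight, **kwargs):
    """Reference: full dot products followed by the same epilogue."""
    full = [[sum(w * a for w, a in zip(row, act)) for row in weight]]
    return reduce_swiglu(full, **kwargs)


if __name__ == "__main__":
    act = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75]
    weight = [[0.1 * (r + c) for c in range(6)] for r in range(4)]
    print(reduce_swiglu(splitk_partials(act, weight)))
    print(single_pass_swiglu(act, weight))
```

Up to floating-point summation order, the two-pass result should agree with the single-pass reference. The real route differs in that partials are stored in the accumulator type and only the reduction is done in fp32.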
+ +```bash +onnxruntime/test/python/transformers/profile_qmoe_gemv.py \ + --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 \ + --fp32-accum --splitk2-swiglu --warmup 5 --repeat 100 --nvtx +``` + #### Experiments rejected after profiling | Experiment | Why it was rejected | diff --git a/docs/contrib_ops/cuda/qmoe_gemv_experiments.md b/docs/contrib_ops/cuda/qmoe_gemv_experiments.md index d2165640b0051..161fb88621a39 100644 --- a/docs/contrib_ops/cuda/qmoe_gemv_experiments.md +++ b/docs/contrib_ops/cuda/qmoe_gemv_experiments.md @@ -979,6 +979,224 @@ Every case reported `has_invalid_output=false`. - Per-column INT8 W8A16 decode shapes route to GEMV for both FP16 and BF16 and beat the grouped-GEMM fallback at every profiled shape. +## 2026-06-19: Split-K2 Two-Pass SwiGLU GEMV Experiment + +### Change Under Test + +- Code commit: `f1d6718be719c1237be392c0389874b6a8926a3c` + (`Experiment QMoE split-K SwiGLU GEMV`). +- Added Split-K2 route with opt-in env knob: + `ORT_MOE_GEMV_SPLITK2_SWIGLU=1`. +- Scope: FP16 INT4/interleaved-SwiGLU FC1 GEMV path for decode-shaped QMoE. +- Implementation: + - First pass launches `moe_gemv_splitk_partials_kernel` with `SplitK=2` and + writes accumulator-typed partials into QMoE workspace. This follows the + normal GEMV accumulation policy: fp16 partials for fp16 accumulation, fp32 + partials for the fp32 fallback. + - Second pass launches `moe_gemv_splitk_reduce_swiglu_kernel` to reduce the + partials in fp32, add optional bias, and apply SwiGLU. + - FC2 remains on the existing `moe_gemv_kernel`. + - Scratch is allocated only for the supported Split-K2 route. Leaving + `ORT_MOE_GEMV_SPLITK2_SWIGLU` unset or setting it to `0` keeps the previous single-kernel + FC1 SwiGLU GEMV path. + +### Repro Notes + +- Build: `cmake --build build/cu130/Release --target onnxruntime_providers_cuda --parallel $(nproc)`. 
+- Important provider sync: Python tests importing from + `build/cu130/Release/onnxruntime` load + `build/cu130/Release/onnxruntime/capi/libonnxruntime_providers_cuda.so`, not + only the top-level `build/cu130/Release/libonnxruntime_providers_cuda.so` or + the venv copy. Sync all relevant copies before measuring: + + ```bash + cp build/cu130/Release/libonnxruntime_providers_cuda.so \ + build/cu130/Release/onnxruntime/capi/libonnxruntime_providers_cuda.so + cp build/cu130/Release/libonnxruntime_providers_cuda.so \ + .venv_cu130/lib/python3.14/site-packages/onnxruntime/capi/libonnxruntime_providers_cuda.so + ``` + +- Focused QMoE helper: + + ```bash + cd ~ + CUDA_VISIBLE_DEVICES=1 \ + LD_LIBRARY_PATH=~/onnxruntime/build/cu130/Release:~/cuda13.0/lib64:~/cudnn9.19_cuda13/lib:~/cudnn9.19_cuda13/lib64:${LD_LIBRARY_PATH:-} \ + PYTHONPATH=~/onnxruntime/build/cu130/Release:~/onnxruntime/onnxruntime/test/python/transformers \ + ~/onnxruntime/.venv_cu130/bin/python \ + ~/onnxruntime/onnxruntime/test/python/transformers/profile_qmoe_gemv.py \ + --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --warmup 3 --repeat 20 + ``` + +### Focused QMoE Smoke + +Both modes reported `has_invalid_output=false`. + +| Mode | Env | Latency ms | +|------|-----|------------| +| Baseline | unset | 0.072344 | +| Split-K2 | none | 0.073816 | + +The short helper was slightly slower with split-K2, so Nsight was required to +confirm route selection and isolate kernel time. 
+ +### Nsight Systems Kernel Results + +Artifacts: + +- Baseline: `/tmp/qmoe_gptoss_baseline_final.{nsys-rep,sqlite}` +- Split-K2: `/tmp/qmoe_gptoss_splitk_final.{nsys-rep,sqlite}` + +Command shape: + +```bash +~/cuda13.0/bin/nsys profile -t cuda,nvtx --force-overwrite true \ + -o /tmp/qmoe_gptoss_splitk_final --export=sqlite \ + ~/onnxruntime/.venv_cu130/bin/python \ + ~/onnxruntime/onnxruntime/test/python/transformers/profile_qmoe_gemv.py \ + --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --warmup 3 --repeat 30 --nvtx --splitk2-swiglu +``` + +Parsed with `parse_nsys.py --nvtx-range benchmark --pattern '%'`. + +| Mode | Kernel | Calls | Avg us | +|------|--------|-------|--------| +| Baseline | `moe_gemv_interleaved_swiglu_kernel` | 30 | 21.42 | +| Baseline | `moe_gemv_kernel` | 30 | 12.13 | +| Split-K2 | `moe_gemv_splitk_partials_kernel` | 30 | 17.59 | +| Split-K2 | `moe_gemv_splitk_reduce_swiglu_kernel` | 30 | 2.39 | +| Split-K2 | `moe_gemv_kernel` | 30 | 12.22 | + +Split-K2 reduced FC1 kernel work from about `21.42 us` to `17.59 + 2.39 = +19.98 us`, a net FC1 reduction of about `1.44 us` per QMoE invocation. End-to-end +under Nsight was effectively tied: + +| Mode | Helper latency ms | +|------|-------------------| +| Baseline | 0.079855 | +| Split-K2 | 0.079728 | + +### Model-Level Decode Benchmark With CUDA Graph + +Decode was also measured at model level under CUDA graph. Both runs used +the GPT-OSS-20B INT4 QMoE model package, CUDA graph enabled, XQA enabled, and +deterministic MoE tactic selection: + +```bash +MODEL=models/gpt-oss-20b/variants/cuda_int4_int4_qmoe_rtn_matmul_only \ +GPU=0 PROMPT_LEN=512 GEN_LEN=128 REPS=10 WARMUP=3 CUDA_GRAPH=1 XQA=1 SYNC_LIB=1 \ +ORT_FORCE_DETERMINISTIC_MOE=1 \ +bash scripts/bench_gpt_oss_ort_decode.sh +``` + +Baseline left `ORT_MOE_GEMV_SPLITK2_SWIGLU` unset.
+ +| Run | Mode | Decode latency ms/token | Decode throughput tok/s | +|-----|------|-------------------------|-------------------------| +| R1, `REPS=5`, `WARMUP=2` | Baseline | 2.869450 | 348.498901 | +| R1, `REPS=5`, `WARMUP=2` | Split-K2 | 2.823800 | 354.132707 | +| R2, `REPS=10`, `WARMUP=3` | Baseline | 2.865840 | 348.937861 | +| R2, `REPS=10`, `WARMUP=3` | Split-K2 | 2.839335 | 352.195107 | + +The longer CUDA-graph pair showed about `+0.9%` decode throughput. The shorter +pair showed about `+1.6%`. Since the focused helper reported valid output and +the model-level gain repeated in the same direction, even this modest gain is +worth enabling for GPT-OSS-20B decode while keeping an opt-out for A/B checks. + +After testing Split-K2 as the selected route, three more paired +CUDA-graph model runs were collected with `REPS=10`, `WARMUP=3`, prompt length +512, and generation length 128: + +| Run | Mode | Decode latency ms/token | Decode throughput tok/s | +|-----|------|-------------------------|-------------------------| +| R3 | Default Split-K2 | 3.017252 | 331.427448 | +| R3 | Split-K2 disabled | 3.055736 | 327.253380 | +| R4 | Default Split-K2 | 3.006739 | 332.586260 | +| R4 | Split-K2 disabled | 3.047570 | 328.130314 | +| R5 | Default Split-K2 | 3.009466 | 332.284898 | +| R5 | Split-K2 disabled | 3.047015 | 328.190090 | +| Average | Default Split-K2 | 3.011152 | 332.099536 | +| Average | Split-K2 disabled | 3.050107 | 327.857928 | + +The default Split-K2 route was faster in all three pairs, averaging `+1.29%` +decode throughput and `-1.28%` decode latency versus the opt-out fallback. + +### FP16 Accumulation Follow-Up + +After the normal QMoE GEMV path changed to use fp16 accumulation by default, the +Split-K2 route was rechecked on the same GPT-OSS-20B decode shape. 
Sequential +focused-helper runs with `--repeat 100` showed Split-K2 behind the single-kernel +path: + +| Run | Mode | Latency ms/inference | +|-----|------|----------------------| +| R1 | Split-K2 | 0.061761 | +| R1 | Split-K2 disabled | 0.060108 | +| R2 | Split-K2 | 0.062862 | +| R2 | Split-K2 disabled | 0.060989 | +| R3 | Split-K2 | 0.064595 | +| R3 | Split-K2 disabled | 0.060464 | +| Average | Split-K2 | 0.063073 | +| Average | Split-K2 disabled | 0.060520 | + +A short CUDA-graph model-level pair with `REPS=5`, `WARMUP=2`, prompt length +512, and generation length 128 showed the same direction: + +| Mode | Decode latency ms/token | Decode throughput tok/s | +|------|-------------------------|-------------------------| +| Split-K2 | 2.848148 | 351.105318 | +| Split-K2 disabled | 2.816800 | 355.012723 | + +Although the single-kernel path was faster for this GPT-OSS focused helper, +Split-K2 with default fp16 accumulation was still faster than the +`ORT_MOE_GEMV_FP32_ACCUM=1` Split-K2 route. The fp16 Split-K2 variant is kept so +a future autotuner can choose it for shapes where the extra K parallelism wins. + +The same focused profiler check was run for Qwen3.6-35B-A3B and Gemma4-26B-A4B +decode-shaped configs with `--repeat 100`: + +| Case | Mode | Latency ms/inference | +|------|------|----------------------| +| Qwen3.6-35B-A3B | fp16 Split-K2 | 0.049207 | +| Qwen3.6-35B-A3B | fp16 Split-K2 disabled | 0.047403 | +| Qwen3.6-35B-A3B | fp32 Split-K2 | 0.052055 | +| Gemma4-26B-A4B | fp16 Split-K2 | 0.053503 | +| Gemma4-26B-A4B | fp16 Split-K2 disabled | 0.050732 | +| Gemma4-26B-A4B | fp32 Split-K2 | 0.059571 | + +Both additional shapes produced valid output. In these focused helper runs, +fp16 Split-K2 again sat between the fp16 single-kernel path and the +`ORT_MOE_GEMV_FP32_ACCUM=1` Split-K2 path. 
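To put the focused-helper tables above on a common scale, the relative cost of fp16 Split-K2 against the fp16 single-kernel path can be computed per shape. The sketch below only re-derives percentages from the latencies already reported; the dictionary layout and the function names are illustrative and not part of the profiler.

```python
# Focused-helper latencies (ms/inference) copied from the tables above.
# The GPT-OSS-20B entry uses the three-run averages.
LATENCY_MS = {
    "GPT-OSS-20B": {"splitk2": 0.063073, "single_kernel": 0.060520},
    "Qwen3.6-35B-A3B": {"splitk2": 0.049207, "single_kernel": 0.047403},
    "Gemma4-26B-A4B": {"splitk2": 0.053503, "single_kernel": 0.050732},
}


def slowdown_percent(candidate_ms, baseline_ms):
    """Positive values mean the candidate is slower than the baseline."""
    return (candidate_ms / baseline_ms - 1.0) * 100.0


def summarize(table):
    """Return the Split-K2 slowdown in percent for each case, rounded to 2 places."""
    return {
        case: round(slowdown_percent(v["splitk2"], v["single_kernel"]), 2)
        for case, v in table.items()
    }


if __name__ == "__main__":
    for case, pct in summarize(LATENCY_MS).items():
        print(f"{case}: fp16 Split-K2 is {pct:+.2f}% vs fp16 single-kernel")
```

On these three decode shapes the fp16 Split-K2 route comes out roughly 4% to 5.5% slower than the single-kernel path in the focused helper, which is the gap a per-shape autotuner would need to beat on other shapes.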
+ +### Accuracy Smoke + +A 1000-sample `match_mmlu` smoke was run with the local parallel eval harness on +all eight H200 GPUs, using the same GPT-OSS-20B INT4 QMoE model package and the +current ORT build package. The Split-K2 run scored `0.8380` pooled +accuracy; the non-Split-K fallback with `ORT_MOE_GEMV_SPLITK2_SWIGLU` unset +scored `0.8350`. The small positive difference is within smoke-test noise, and +there is no accuracy regression signal from enabling Split-K2. + +### Decision + +- Keep Split-K2 available for its supported fp16 INT4 interleaved-SwiGLU GEMV + scope when `ORT_MOE_GEMV_SPLITK2_SWIGLU=1` enables it. +- Keep the fp16-accumulation Split-K2 variant available. It is slower than the + single-kernel fp16-accumulation path on the GPT-OSS shape, but faster than the + fp32-accumulation Split-K2 route and may be selected by future per-shape + autotuning. +- Use two binary route controls: `ORT_MOE_GEMV_FP32_ACCUM=1` enables fp32 + accumulation, and `ORT_MOE_GEMV_SPLITK2_SWIGLU=1` enables Split-K2. Both + default to `0`. +- The 1000-sample MMLU smoke matched the non-Split-K fallback within noise, so + enabling Split-K2 has an accuracy sanity check in addition to valid + focused-helper output. +- Future work: + - Add per-shape autotune so route selection is data-driven instead of a fixed + default. + - Try a launch-fused or cooperative reduction strategy to keep the + FC1 parallelism benefit without the extra reduce launch.
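The two binary route controls listed in the Decision can be modeled as a small selection function. This is a hedged sketch of the documented behavior for fp16 activations only: the function names and the returned labels are hypothetical, the handling of an empty value is an assumption, and the real dispatch also falls back to the single-kernel path when the shape is unsupported or no Split-K workspace is available. The workspace formula mirrors the `kMoeGemvSplitK * glu_inter_elems * sizeof(float)` sizing in this change, with `glu_inter_elems` treated as an opaque element count.

```python
import os


def parse_binary_knob(name, environ=None):
    """Return True only for "1"; unset means False; anything else is rejected."""
    environ = os.environ if environ is None else environ
    raw = environ.get(name)
    if raw is None:
        return False
    if raw not in ("0", "1"):
        raise ValueError(f"{name} must be 0 or 1, but got {raw!r}")
    return raw == "1"


def select_fc1_route(environ=None):
    """Map the two knobs to an accumulator type and an FC1 kernel route."""
    fp32 = parse_binary_knob("ORT_MOE_GEMV_FP32_ACCUM", environ)
    splitk2 = parse_binary_knob("ORT_MOE_GEMV_SPLITK2_SWIGLU", environ)
    return {
        "accumulator": "fp32" if fp32 else "fp16",
        "fc1": "splitk2_two_pass" if splitk2 else "single_kernel",
    }


def splitk2_workspace_bytes(glu_inter_elems, enabled, split_k=2, elem_bytes=4):
    """Scratch is reserved only when the Split-K2 route is enabled."""
    return split_k * glu_inter_elems * elem_bytes if enabled else 0


if __name__ == "__main__":
    print(select_fc1_route({}))
    print(select_fc1_route({"ORT_MOE_GEMV_SPLITK2_SWIGLU": "1"}))
    print(splitk2_workspace_bytes(glu_inter_elems=4096, enabled=True))
```

Passing a dictionary instead of reading `os.environ` keeps the sketch easy to exercise for all four knob combinations. In the real process the values are parsed once and cached, so they must be set before the first QMoE launch.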
+ ## 2026-06-19 FP16 Accumulation Default: SM90, GPT-OSS Decode Shape ### Setup diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu index b4cc262b4f1ee..996f59170e016 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.cu @@ -174,6 +174,205 @@ __device__ __forceinline__ void swiglu_epilogue(void* out, void* tile_acc, void* } } +template +__device__ __forceinline__ void partial_epilogue(PartialT* partial_out, void* tile_acc) { + static constexpr int Interleave = Details::kInterleave; + static constexpr int ThreadsPerInterleavedTile = Details::kThreadsPerInterleavedTile; + static constexpr int WarpSize = Details::kWarpSize; + static constexpr int WarpNum = Threads / WarpSize; + static_assert(CtaM == 1); + static_assert(Threads % WarpSize == 0); + + __shared__ float shmem[CtaM * CtaN * Interleave * WarpNum]; + int tid = threadIdx.x; + int warp_id = tid / WarpSize; + int lane_id = tid % WarpSize; +#pragma unroll + for (int n = 0; n < CtaN; ++n) { + float v = static_cast(reinterpret_cast(tile_acc)[n]); + v = warp_reduce_sum(v); + if (lane_id < Interleave * ThreadsPerInterleavedTile && lane_id % ThreadsPerInterleavedTile == 0) { + shmem[warp_id * CtaN * Interleave + n * Interleave + lane_id / ThreadsPerInterleavedTile] = v; + } + } + __syncthreads(); + +#pragma unroll + for (int col = tid; col < CtaN * Interleave; col += Threads) { + float val = 0.f; +#pragma unroll + for (int warp = 0; warp < WarpNum; ++warp) { + val += shmem[warp * CtaN * Interleave + col]; + } + partial_out[col] = static_cast(val); + } +} + +template +__global__ void moe_gemv_splitk_partials_kernel( + TypeA* act, uint8_t* weight, TypeA* scales, PartialT* partials, + int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, + int64_t weight_expert_stride, int64_t scale_expert_stride, int n, int k, int64_t expanded_num_rows) { 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) + using AccessTypeA = typename Details::AccessTypeA; + using AccessTypeW = typename Details::AccessTypeW; + + static constexpr bool Mandatory = true; + static constexpr int CtaM = 1; + static constexpr int StepK = Details::kStepK; + static constexpr int CtaK = StepK * Threads; + static_assert(CtaN % 2 == 0); + if constexpr (GroupSize != 0) { + static_assert((CtaK / Details::kInterleave) % GroupSize == 0); + } + + int const row = blockIdx.x; + int expert = permuted_row_to_expert != nullptr ? permuted_row_to_expert[row] : 0; +#pragma unroll 1 + for (int e = 0; e < num_experts && permuted_row_to_expert == nullptr; ++e) { + if (row >= static_cast(expert_first_token_offset[e + 1])) { + expert = e + 1; + continue; + } + break; + } + if (expert < 0 || expert >= num_experts) { + return; + } + + weight += expert * weight_expert_stride; + scales += static_cast(expert) * scale_expert_stride; + + int const origin_k = k; + int const interleaved_k = k * Details::kInterleave; + int const tile_id_m = row; + int const tile_id_n = blockIdx.y; + int const split_id = blockIdx.z; + int const tid = threadIdx.x; + int const offset_m = tile_id_m * CtaM; + int const interleaved_offset_n = tile_id_n * CtaN; + int const real_offset_n = interleaved_offset_n * Details::kInterleave + + ((tid * StepK / Details::LayoutDetails::kTileSize) % Details::kInterleave); + int const real_offset_k = + (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * + Details::LayoutDetails::kTileSize + + ((tid * StepK) % Details::LayoutDetails::kTileSize); + + GMemIterator act_iterator( + act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); + GMemIterator weight_iterator( + weight, (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, + CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW); + GMemIterator scales_iterator( + scales, (GroupSize != 0 ? 
real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + + partials += (static_cast(split_id) * expanded_num_rows + offset_m) * n + + tile_id_n * CtaN * Details::kInterleave; + + AccT tile_acc[CtaM * CtaN]; + fill(tile_acc, static_cast(0.f)); + + TypeA vec_scale[CtaN]; + if constexpr (GroupSize == 0) { +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + scales_iterator.load(vec_scale + i, 0, i); + } + } + + int const num_iters = (interleaved_k + CtaK - 1) / CtaK; + for (int iter = split_id; iter < num_iters; iter += SplitK) { + int const idx_k = iter * CtaK + tid * StepK; + if (idx_k >= interleaved_k) { + continue; + } + TypeA tile_a[StepK]; + TypeA tile_w[StepK]; + TypeA tile_w_pack2[CtaN * StepK]; + uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; + if constexpr (GroupSize != 0) { +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + scales_iterator.load(vec_scale + i, iter, i); + } + } +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + weight_iterator.load(tile_w_quantized, iter, i); + dequantize(tile_w, tile_w_quantized, vec_scale + i, nullptr, 1.0f); + pack_to_vec2(tile_w_pack2, tile_w, i); + } + act_iterator.load(tile_a, iter, 0); + mma(tile_acc, tile_w_pack2, tile_a); + } + partial_epilogue(partials, tile_acc); +#endif +} + +template +__global__ void moe_gemv_splitk_reduce_swiglu_kernel( + PartialT const* partials, TypeA* bias, TypeA* out, + int const* permuted_row_to_expert, int num_experts, int64_t const* expert_first_token_offset, + int inter_size, int split_k, int64_t expanded_num_rows, + cutlass_kernels::ActivationParams activation_params) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) + int const row = blockIdx.x; + int const col = blockIdx.y * Threads + threadIdx.x; + if (col >= inter_size) { + return; + } + + int expert = permuted_row_to_expert != nullptr ? 
permuted_row_to_expert[row] : 0; +#pragma unroll 1 + for (int e = 0; e < num_experts && permuted_row_to_expert == nullptr; ++e) { + if (row >= static_cast(expert_first_token_offset[e + 1])) { + expert = e + 1; + continue; + } + break; + } + if (expert < 0 || expert >= num_experts) { + return; + } + + float const* alpha = activation_params.swiglu_alpha; + float const* beta = activation_params.swiglu_beta; + float const* limit = activation_params.swiglu_limit; + float const act_alpha = alpha ? alpha[expert] : activation_params.alpha; + float const act_beta = beta ? beta[expert] : activation_params.beta; + float const act_limit = limit ? limit[expert] : activation_params.limit; + + int const n = inter_size * 2; + int const gate_idx = col * 2; + int const linear_idx = gate_idx + 1; + int64_t const row_base = static_cast(row) * n; + int64_t const split_stride = expanded_num_rows * n; + float gate = 0.f; + float linear = 0.f; + for (int split = 0; split < split_k; ++split) { + int64_t const base = static_cast(split) * split_stride + row_base; + gate += static_cast(partials[base + gate_idx]); + linear += static_cast(partials[base + linear_idx]); + } + + if constexpr (EnableBias) { + bias += static_cast(expert) * n; + gate += static_cast(bias[gate_idx]); + linear += static_cast(bias[linear_idx]); + } + if (isfinite(act_limit)) { + gate = fminf(gate, act_limit); + linear = fminf(fmaxf(linear, -act_limit), act_limit); + } + linear += act_beta; + float const sigmoid = 1.0f / (1.0f + expf(-act_alpha * gate)); + out[static_cast(row) * inter_size + col] = static_cast(gate * sigmoid * linear); +#endif +} + template __global__ void moe_gemv_interleaved_swiglu_kernel( @@ -326,6 +525,77 @@ static void launch_moe_gemv_interleaved_swiglu( } } +template +static void launch_moe_gemv_splitk_twopass_swiglu( + TypeA* act, uint8_t* weight, TypeA* scales, TypeA* bias, TypeA* out, + int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, + int64_t 
expanded_num_rows, int64_t inter_size, int64_t k, + cutlass_kernels::ActivationParams activation_params, PartialT* partials, cudaStream_t stream) { + static constexpr int StepK = Details::kStepK; + static constexpr int CtaK = StepK * Threads; + int64_t const n = inter_size * 2; + int64_t const interleaved_k = k * Details::kInterleave; + int const num_iters = static_cast((interleaved_k + CtaK - 1) / CtaK); + if (partials == nullptr || num_iters < 2) { + launch_moe_gemv_interleaved_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, stream); + return; + } + + int64_t const weight_expert_stride = n * k / Details::kElemsPerByteW; + int64_t const scale_expert_stride = GroupSize == 0 ? n : ((k + GroupSize - 1) / GroupSize) * n; + dim3 grid1(static_cast(expanded_num_rows), + static_cast(n / (CtaN * Details::kInterleave)), SplitK); + dim3 block1(Threads); + moe_gemv_splitk_partials_kernel + <<>>( + act, weight, scales, partials, expert_first_token_offset, permuted_row_to_expert, num_experts, + weight_expert_stride, scale_expert_stride, static_cast(n), static_cast(k), expanded_num_rows); + + static constexpr int kReduceThreads = 256; + dim3 grid2(static_cast(expanded_num_rows), + static_cast((inter_size + kReduceThreads - 1) / kReduceThreads)); + dim3 block2(kReduceThreads); + if (bias != nullptr) { + moe_gemv_splitk_reduce_swiglu_kernel<<>>( + partials, bias, out, permuted_row_to_expert, num_experts, expert_first_token_offset, + static_cast(inter_size), SplitK, expanded_num_rows, activation_params); + } else { + moe_gemv_splitk_reduce_swiglu_kernel<<>>( + partials, bias, out, permuted_row_to_expert, num_experts, expert_first_token_offset, + static_cast(inter_size), SplitK, expanded_num_rows, activation_params); + } +} + +template +static void dispatch_moe_gemv_splitk_twopass_swiglu_group_size( + TypeA* act, uint8_t* weight, TypeA* scales, TypeA* bias, TypeA* out, + 
int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, + int64_t expanded_num_rows, int64_t inter_size, int64_t k, int group_size, + cutlass_kernels::ActivationParams activation_params, PartialT* partials, cudaStream_t stream) { + if (group_size <= 0) { + launch_moe_gemv_splitk_twopass_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, partials, stream); + } else if (group_size == 32) { + launch_moe_gemv_splitk_twopass_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, partials, stream); + } else if (group_size == 64) { + launch_moe_gemv_splitk_twopass_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, partials, stream); + } else if (group_size == 128) { + launch_moe_gemv_splitk_twopass_swiglu( + act, weight, scales, bias, out, expert_first_token_offset, permuted_row_to_expert, num_experts, + expanded_num_rows, inter_size, k, activation_params, partials, stream); + } else { + ORT_THROW("unsupported MoE GEMV split-K group_size: ", group_size); + } +} + template static void dispatch_moe_gemv_group_size(TypeA* act, uint8_t* weight, TypeA* scales, TypeA* bias, TypeA* out, int64_t const* expert_first_token_offset, @@ -407,6 +677,20 @@ inline bool MoeGemvUseFp32Accum() { return enabled; } +inline bool MoeGemvUseSplitK2SwiGLU() { + // Parsed once via ORT's environment helper (consistent parsing/thread-safety across platforms). 
+ static bool const enabled = [] { + auto const value = onnxruntime::ParseEnvironmentVariable("ORT_MOE_GEMV_SPLITK2_SWIGLU"); + if (!value.has_value()) { + return false; + } + ORT_ENFORCE(*value == 0 || *value == 1, + "ORT_MOE_GEMV_SPLITK2_SWIGLU must be 0 or 1, but got ", *value); + return *value == 1; + }(); + return enabled; +} + bool is_moe_gemv_supported(int sm, int64_t expanded_num_rows, int64_t n, int64_t k, int weight_bits, int group_size) { if (sm < 80) { @@ -527,12 +811,34 @@ void launch_moe_gemv_int_symmetric_interleaved_swiglu( T const* act, WeightType const* weight, T const* scales, T const* bias, T* out, int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, int64_t expanded_num_rows, int64_t inter_size, int64_t k, int group_size, int sm, - cutlass_kernels::ActivationParams activation_params, cudaStream_t stream) { + cutlass_kernels::ActivationParams activation_params, void* splitk_partials, cudaStream_t stream) { ORT_UNUSED_PARAMETER(sm); using Details = typename DetailsForTAndWeight::Details; using TypeA = typename DetailsForTAndWeight::TypeA; // Accumulation policy matches launch_moe_gemv_int_symmetric. 
bool const use_fp32_accum = !std::is_same_v || MoeGemvUseFp32Accum(); + if (splitk_partials != nullptr && MoeGemvUseSplitK2SwiGLU()) { + if constexpr (std::is_same_v) { + auto launch_splitk = [&](auto acc_tag) { + using AccT = typename decltype(acc_tag)::type; + using PartialT = AccT; + fiv::dispatch_moe_gemv_splitk_twopass_swiglu_group_size( + const_cast(reinterpret_cast(act)), + const_cast(reinterpret_cast(weight)), + const_cast(reinterpret_cast(scales)), + const_cast(reinterpret_cast(bias)), + reinterpret_cast(out), + expert_first_token_offset, permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, group_size, + activation_params, reinterpret_cast(splitk_partials), stream); + }; + if (use_fp32_accum) { + launch_splitk(TypeTag{}); + } else { + launch_splitk(TypeTag{}); + } + return; + } + } auto launch = [&](auto acc_tag) { using AccT = typename decltype(acc_tag)::type; fiv::dispatch_moe_gemv_interleaved_swiglu_group_size( @@ -569,7 +875,8 @@ void launch_moe_gemv_int4_per_channel_interleaved_swiglu( cutlass_kernels::ActivationParams activation_params, cudaStream_t stream) { launch_moe_gemv_int_symmetric_interleaved_swiglu( act, reinterpret_cast(weight), scales, bias, out, expert_first_token_offset, - permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, 0, sm, activation_params, stream); + permuted_row_to_expert, num_experts, expanded_num_rows, inter_size, k, 0, sm, activation_params, + nullptr, stream); } template void launch_moe_gemv_int_symmetric( @@ -580,10 +887,10 @@ template void launch_moe_gemv_int_symmetric( int64_t, int64_t, int64_t, int, int, cudaStream_t); template void launch_moe_gemv_int_symmetric_interleaved_swiglu( half const*, cutlass::uint4b_t const*, half const*, half const*, half*, int64_t const*, int const*, int, - int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, cudaStream_t); + int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, void*, cudaStream_t); template void 
launch_moe_gemv_int_symmetric_interleaved_swiglu( half const*, uint8_t const*, half const*, half const*, half*, int64_t const*, int const*, int, - int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, cudaStream_t); + int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, void*, cudaStream_t); template void launch_moe_gemv_int4_per_channel(half const*, uint8_t const*, half const*, half const*, half*, int64_t const*, int const*, int, int64_t, int64_t, int64_t, @@ -602,11 +909,11 @@ template void launch_moe_gemv_int_symmetric<__nv_bfloat16, uint8_t>( template void launch_moe_gemv_int_symmetric_interleaved_swiglu<__nv_bfloat16, cutlass::uint4b_t>( __nv_bfloat16 const*, cutlass::uint4b_t const*, __nv_bfloat16 const*, __nv_bfloat16 const*, __nv_bfloat16*, int64_t const*, int const*, int, int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, - cudaStream_t); + void*, cudaStream_t); template void launch_moe_gemv_int_symmetric_interleaved_swiglu<__nv_bfloat16, uint8_t>( __nv_bfloat16 const*, uint8_t const*, __nv_bfloat16 const*, __nv_bfloat16 const*, __nv_bfloat16*, int64_t const*, int const*, int, int64_t, int64_t, int64_t, int, int, cutlass_kernels::ActivationParams, - cudaStream_t); + void*, cudaStream_t); template void launch_moe_gemv_int4_per_channel<__nv_bfloat16>( __nv_bfloat16 const*, uint8_t const*, __nv_bfloat16 const*, __nv_bfloat16 const*, __nv_bfloat16*, diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h index b4dfe1c59f02a..229f9839e62f1 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemv.h @@ -53,6 +53,7 @@ void launch_moe_gemv_int_symmetric_interleaved_swiglu( T const* act, WeightType const* weight, T const* scales, T const* bias, T* out, int64_t const* expert_first_token_offset, int const* permuted_row_to_expert, int num_experts, int64_t expanded_num_rows, int64_t 
inter_size, int64_t k, int group_size, int sm, cutlass_kernels::ActivationParams activation_params, + void* splitk_partials, cudaStream_t stream); // Launches the int4 per-channel MoE GEMV. diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu index 3185a9a86b231..6d708b3fc586a 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu @@ -83,6 +83,26 @@ inline bool MoeGemvDisabledByEnv() { return disabled; } +inline bool MoeGemvSplitK2SwiGLUEnabledByEnv() { + // Parsed once via ORT's environment helper (consistent parsing/thread-safety across platforms). + static bool const enabled = [] { + auto const value = onnxruntime::ParseEnvironmentVariable("ORT_MOE_GEMV_SPLITK2_SWIGLU"); + if (!value.has_value()) { + return false; + } + ORT_ENFORCE(*value == 0 || *value == 1, + "ORT_MOE_GEMV_SPLITK2_SWIGLU must be 0 or 1, but got ", *value); + return *value == 1; + }(); + return enabled; +} + +template +constexpr bool MoeGemvSplitK2SwiGLUSupported() { + return std::is_same_v && std::is_same_v && + std::is_same_v; +} + inline bool MoeGemvRejectedByProfiledInterSize(int64_t expanded_num_rows, int64_t inter_size) { return expanded_num_rows > onnxruntime::llm::kernels::moe_gemv::kMaxProfiledExpandedRowsForSmallProblemDim && inter_size < onnxruntime::llm::kernels::moe_gemv::kMinProfiledProblemDimForExpandedRowsAbove4; @@ -153,7 +173,7 @@ bool tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( ScaleBiasType const* biases, T* output, int64_t const* expert_first_token_offset, int num_experts_per_node, int const* permuted_row_to_expert, int64_t expanded_num_rows, int64_t inter_size, int64_t k, int sm, int group_size, - bool disabled, cutlass_kernels::ActivationParams activation_params, cudaStream_t stream) { + bool disabled, cutlass_kernels::ActivationParams activation_params, float* splitk_partials, cudaStream_t stream) { if constexpr 
((std::is_same_v || std::is_same_v) && (std::is_same_v || std::is_same_v) && std::is_same_v) { bool const env_disabled = MoeGemvDisabledByEnv(); @@ -176,7 +196,8 @@ bool tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( } onnxruntime::llm::kernels::moe_gemv::launch_moe_gemv_int_symmetric_interleaved_swiglu( input, weights, scales, biases, output, expert_first_token_offset, permuted_row_to_expert, - num_experts_per_node, expanded_num_rows, inter_size, k, group_size, sm, activation_params, stream); + num_experts_per_node, expanded_num_rows, inter_size, k, group_size, sm, activation_params, + splitk_partials, stream); return true; } else { (void)input; @@ -195,6 +216,7 @@ bool tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( (void)group_size; (void)disabled; (void)activation_params; + (void)splitk_partials; (void)stream; return false; } @@ -2169,6 +2191,12 @@ CutlassMoeFCRunner: // (see comment above). Only the gated / has-glu path uses it; zero otherwise. size_t const fc1_result_dedicated_size = (glu_inter_elems > 0) ? fc1_result_size : 0; + static constexpr int kMoeGemvSplitK = 2; + size_t const moe_gemv_splitk_partials_size = is_gated_activation && MoeGemvSplitK2SwiGLUEnabledByEnv() && + MoeGemvSplitK2SwiGLUSupported() + ? kMoeGemvSplitK * glu_inter_elems * sizeof(float) + : 0; + size_t smoothed_act_size = use_awq ? 
std::max(permuted_elems, interbuf_elems) * sizeof(T) * 2 : 0; // Extra workspace required by AWQ for smoothing activations @@ -2193,6 +2221,7 @@ CutlassMoeFCRunner: ADD(overlapped_gemm1_gemm2_inputs); ADD(overlapped_gemm1_gemm2_outputs); ADD(fc1_result_dedicated); + ADD(moe_gemv_splitk_partials); ADD_NAME(alpha_scale_ptr_array_fc1, alpha_scale_ptr_array_size); ADD_NAME(alpha_scale_ptr_array_fc2, alpha_scale_ptr_array_size); ADD(fp4_act_scale); @@ -2270,6 +2299,7 @@ void CutlassMoeFCRunner(); + ORT_ENFORCE(!use_splitk_swiglu_gemv || moe_gemv_splitk_partials != nullptr, + "Split-K2 SwiGLU GEMV requires split-K GEMV workspace"); bool const fc1_did_fused_gemv = tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( input, fc1_expert_weights, quant_params.groupwise.group_size > 0 @@ -2466,7 +2500,9 @@ void CutlassMoeFCRunner 1 || use_ampere_activation_fusion || !bias_is_broadcast || MoeGemvRejectedByProfiledInterSize(expanded_num_rows, inter_size), - activation_params, stream); + activation_params, + use_splitk_swiglu_gemv ? moe_gemv_splitk_partials : nullptr, + stream); // Run the GEMM with activation function overridden with `Identity`, we do the activation separately. 
// Fast path: symmetric INT4/INT8 (per-column or block-wise) MoE GEMV for small expanded-row counts @@ -2805,7 +2841,7 @@ void CutlassMoeFCRunner(fc1_expert_weights), static_cast(fc1_expert_biases), num_valid_tokens_ptr, static_cast(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant, fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size, - num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, nullptr, bias_is_broadcast, stream, + num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, nullptr, nullptr, bias_is_broadcast, stream, MOEParallelismConfig{}, config, activation_params); } @@ -626,6 +626,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_scale_; float const** alpha_scale_ptr_array_fc1_ = nullptr; float const** alpha_scale_ptr_array_fc2_ = nullptr; + float* moe_gemv_splitk_partials_{}; void* smoothed_act_{}; TmaWarpSpecializedGroupedGemmInput tma_ws_grouped_gemm1_input_; diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index ae000168f7004..d5330b5a361b5 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -375,4 +375,3 @@ MlasConvSGemmRoute(const MLAS_CONV_PARAMETERS* Parameters) { : MlasConvSGemmRouteDirect; } } - diff --git a/onnxruntime/test/python/transformers/profile_qmoe_gemv.py b/onnxruntime/test/python/transformers/profile_qmoe_gemv.py index 0f71409fa2206..c674b44831032 100644 --- a/onnxruntime/test/python/transformers/profile_qmoe_gemv.py +++ b/onnxruntime/test/python/transformers/profile_qmoe_gemv.py @@ -72,6 +72,16 @@ def main(): action="store_true", help="Run the grouped GEMM fallback by setting ORT_DISABLE_MOE_GEMV=1 before session creation", ) + parser.add_argument( + "--splitk2-swiglu", + action="store_true", + help="Enable split-K2 two-pass FC1 SwiGLU GEMV by 
setting ORT_MOE_GEMV_SPLITK2_SWIGLU=1", + ) + parser.add_argument( + "--fp32-accum", + action="store_true", + help="Enable fp32 GEMV accumulation by setting ORT_MOE_GEMV_FP32_ACCUM=1", + ) parser.add_argument( "--nvtx", action="store_true", @@ -101,6 +111,16 @@ def main(): else: os.environ.pop("ORT_DISABLE_MOE_GEMV", None) + if args.fp32_accum: + os.environ["ORT_MOE_GEMV_FP32_ACCUM"] = "1" + else: + os.environ.pop("ORT_MOE_GEMV_FP32_ACCUM", None) + + if args.splitk2_swiglu: + os.environ["ORT_MOE_GEMV_SPLITK2_SWIGLU"] = "1" + else: + os.environ.pop("ORT_MOE_GEMV_SPLITK2_SWIGLU", None) + result = run_qmoe_gemv_benchmark(case) if result["has_invalid_output"]: raise RuntimeError("QMoE GEMV profiling produced NaN or Inf output") diff --git a/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh b/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh index c452ee17fec00..ce1473d83bcff 100755 --- a/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh +++ b/onnxruntime/test/python/transformers/profile_qmoe_gemv.sh @@ -11,6 +11,8 @@ # ./profile_qmoe_gemv.sh --list-cases # ./profile_qmoe_gemv.sh --case m8_top2_fp16_128x256 --warmup 5 --repeat 200 # ./profile_qmoe_gemv.sh --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --warmup 5 --repeat 100 +# ./profile_qmoe_gemv.sh --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --splitk2-swiglu --warmup 5 --repeat 100 +# ./profile_qmoe_gemv.sh --case gpt_oss_20b_m1_top4_fp16_2880x2880_e32 --fp32-accum --splitk2-swiglu --warmup 5 --repeat 100 # ./profile_qmoe_gemv.sh --batch-size 1 --sequence-length 1 --hidden-size 1024 --intermediate-size 4096 --num-experts 8 --top-k 2 --quant-bits 8 --block-size 128 # CUDA_VISIBLE_DEVICES=1 ./profile_qmoe_gemv.sh -o /tmp/qmoe_gemv # @@ -26,6 +28,8 @@ OUTPUT_NAME="qmoe_gemv_profile" PY="${PYTHON:-python}" EXTRA_ARGS=() LIST_CASES=0 +FP32_ACCUM=0 +SPLITK2_SWIGLU=0 while [[ "$#" -gt 0 ]]; do case $1 in @@ -36,6 +40,12 @@ while [[ "$#" -gt 0 ]]; do --list-cases) LIST_CASES=1 ;; + --splitk2-swiglu) + 
SPLITK2_SWIGLU=1 + ;; + --fp32-accum) + FP32_ACCUM=1 + ;; --batch-size|--sequence-length|--hidden-size|--intermediate-size|--num-experts|--top-k|--dtype|--quant-bits|--block-size) EXTRA_ARGS+=("$1" "$2") shift @@ -58,7 +68,7 @@ while [[ "$#" -gt 0 ]]; do ;; *) echo "Unknown option: $1" - echo "Usage: $0 [--list-cases] [--case NAME] [--batch-size N] [--sequence-length N] [--hidden-size N] [--intermediate-size N] [--num-experts N] [--top-k N] [--dtype FLOAT16|BFLOAT16] [--quant-bits 4|8] [--block-size 0|32|64|128] [--warmup N] [--repeat N] [--python PYTHON] [-o NAME]" + echo "Usage: $0 [--list-cases] [--case NAME] [--fp32-accum] [--splitk2-swiglu] [--batch-size N] [--sequence-length N] [--hidden-size N] [--intermediate-size N] [--num-experts N] [--top-k N] [--dtype FLOAT16|BFLOAT16] [--quant-bits 4|8] [--block-size 0|32|64|128] [--warmup N] [--repeat N] [--python PYTHON] [-o NAME]" exit 1 ;; esac @@ -98,21 +108,36 @@ fi if [[ "${#EXTRA_ARGS[@]}" -gt 0 ]]; then echo "Custom args: ${EXTRA_ARGS[*]}" fi +if [[ "${FP32_ACCUM}" -eq 1 ]]; then + echo "GEMV accumulation: fp32" +fi +if [[ "${SPLITK2_SWIGLU}" -eq 1 ]]; then + echo "Split-K2 SwiGLU: enabled for GEMV mode" +fi profile_one() { local mode="$1" local disable_arg="" + local fp32_accum_arg="" + local splitk2_arg="" local base="${OUTPUT_NAME}_${mode}" if [[ "${mode}" == "gemm" ]]; then disable_arg="--disable-gemv" fi + if [[ "${FP32_ACCUM}" -eq 1 ]]; then + fp32_accum_arg="--fp32-accum" + fi + if [[ "${SPLITK2_SWIGLU}" -eq 1 ]]; then + splitk2_arg="--splitk2-swiglu" + fi echo "" echo "---- Profiling ${mode} ----" rm -f "${base}.nsys-rep" "${base}.sqlite" nsys profile -t cuda,nvtx --force-overwrite true -o "${base}" --export=sqlite \ "${PY}" "${SCRIPT_DIR}/profile_qmoe_gemv.py" \ - --case "${CASE}" "${EXTRA_ARGS[@]}" --warmup "${WARMUP}" --repeat "${REPEAT}" --nvtx ${disable_arg} + --case "${CASE}" "${EXTRA_ARGS[@]}" --warmup "${WARMUP}" --repeat "${REPEAT}" --nvtx \ + ${disable_arg} ${fp32_accum_arg} ${splitk2_arg} 
echo "" echo "---- Kernel results (${mode}) ----" diff --git a/onnxruntime/test/python/transformers/test_qmoe_cuda.py b/onnxruntime/test/python/transformers/test_qmoe_cuda.py index 36d2ff66010ff..aa7c8f3db8470 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cuda.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cuda.py @@ -2486,6 +2486,8 @@ def run_qmoe_gemv_benchmark(case): "case": case["name"], "block_size": case.get("block_size", 0), "disable_gemv": os.getenv("ORT_DISABLE_MOE_GEMV") == "1", + "fp32_accum": os.getenv("ORT_MOE_GEMV_FP32_ACCUM", "0"), + "splitk2_swiglu": os.getenv("ORT_MOE_GEMV_SPLITK2_SWIGLU", "0"), "expanded_num_rows": case["batch_size"] * case["sequence_length"] * case["top_k"], "has_invalid_output": bool(torch.isnan(output).any() or torch.isinf(output).any()), "latency_ms": qmoe.last_ort_latency_ms, @@ -2512,6 +2514,20 @@ def test_decode_latency(self): self.assertFalse(result["has_invalid_output"]) print(_QMOE_GEMV_BENCHMARK_RESULT_PREFIX + json.dumps(result, sort_keys=True)) + def test_splitk2_swiglu_decode_latency(self): + previous_splitk2 = os.environ.get("ORT_MOE_GEMV_SPLITK2_SWIGLU") + os.environ["ORT_MOE_GEMV_SPLITK2_SWIGLU"] = "1" + try: + result = run_qmoe_gemv_benchmark_case("gpt_oss_20b_m1_top4_fp16_2880x2880_e32") + self.assertEqual(result["splitk2_swiglu"], "1") + self.assertFalse(result["has_invalid_output"]) + print(_QMOE_GEMV_BENCHMARK_RESULT_PREFIX + json.dumps(result, sort_keys=True)) + finally: + if previous_splitk2 is None: + os.environ.pop("ORT_MOE_GEMV_SPLITK2_SWIGLU", None) + else: + os.environ["ORT_MOE_GEMV_SPLITK2_SWIGLU"] = previous_splitk2 + @unittest.skipIf(True, "Skipping QMoE benchmark tests") class TestQMoESwiGLUBenchmark(unittest.TestCase):