Skip to content

Commit f64f488

Browse files
authored
Refactor moe codebase (#1199)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 7056450 commit f64f488

File tree

5 files changed

+229
-336
lines changed

5 files changed

+229
-336
lines changed

tests/layers/vllm/test_mxfp4.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def test_mxfp4_fused_moe_use_kernel(mesh, num_tokens, intermediate_size,
231231
vllm_config = engine_args.create_engine_config()
232232
vllm_config.model_config.dtype = dtype
233233
vllm_config.parallel_config = ParallelConfig(
234-
tensor_parallel_size=mesh.devices.size, enable_expert_paralle=True)
234+
tensor_parallel_size=mesh.devices.size)
235235

236236
quant_config = get_tpu_quantization_config(vllm_config, mesh)
237237
with set_current_vllm_config(vllm_config):
@@ -247,6 +247,8 @@ def test_mxfp4_fused_moe_use_kernel(mesh, num_tokens, intermediate_size,
247247
quant_config=quant_config,
248248
has_bias=True,
249249
)
250+
vllm_fused_moe.moe_parallel_config.use_ep = True
251+
250252
vllm_fused_moe.w13_weight.data = w1_weight
251253
vllm_fused_moe.w2_weight.data = w2_weight
252254
vllm_fused_moe.w13_weight_scale.data = w1_weight_scale

tests/layers/vllm/test_unquantized.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,6 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
444444
)
445445
vllm_config = engine_args.create_engine_config()
446446
vllm_config.model_config.dtype = dtype
447-
vllm_config.parallel_config = ParallelConfig(enable_expert_paralle=use_ep)
448447

449448
quant_config = get_tpu_quantization_config(vllm_config, mesh)
450449
with set_current_vllm_config(vllm_config):
@@ -459,6 +458,7 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
459458
dp_size=1,
460459
quant_config=quant_config,
461460
)
461+
vllm_fused_moe.moe_parallel_config.use_ep = use_ep
462462
vllm_fused_moe.w13_weight.data = w1
463463
vllm_fused_moe.w2_weight.data = w2
464464

@@ -678,7 +678,7 @@ def test_fused_moe_use_kernel(mesh, num_tokens, intermediate_size, hidden_size,
678678
vllm_config = engine_args.create_engine_config()
679679
vllm_config.model_config.dtype = dtype
680680
vllm_config.parallel_config = ParallelConfig(
681-
tensor_parallel_size=mesh.devices.size, enable_expert_paralle=True)
681+
tensor_parallel_size=mesh.devices.size)
682682

683683
quant_config = get_tpu_quantization_config(vllm_config, mesh)
684684
with set_current_vllm_config(vllm_config):

0 commit comments

Comments (0)