@@ -415,6 +415,9 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
 @pytest.mark.parametrize("topk", [2])
 def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
                    num_experts, topk):
+    if 'TPU7x' in jax.devices()[0].device_kind:
+        pytest.skip("Skipping test on TPU TPU7x.")
+
     torch.manual_seed(42)
     dtype = torch.bfloat16
 
@@ -494,6 +497,9 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
 @pytest.mark.parametrize("topk", [2])
 def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
                         num_experts, topk):
+    if 'TPU7x' in jax.devices()[0].device_kind:
+        pytest.skip("Skipping test on TPU TPU7x.")
+
     torch.manual_seed(42)
     dtype = torch.bfloat16
 
@@ -560,6 +566,9 @@ def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
 @pytest.mark.parametrize("activation", ["silu", "swigluoai"])
 def test_fused_moe_activation(mesh, num_tokens, intermediate_size, hidden_size,
                               num_experts, topk, activation):
+    if 'TPU7x' in jax.devices()[0].device_kind:
+        pytest.skip("Skipping test on TPU TPU7x.")
+
     torch.manual_seed(42)
     dtype = torch.bfloat16
 
@@ -619,6 +628,8 @@ def test_fused_moe_activation(mesh, num_tokens, intermediate_size, hidden_size,
 @pytest.mark.parametrize("has_bias", [False, True])
 def test_fused_moe_use_kernel(mesh, num_tokens, intermediate_size, hidden_size,
                               num_experts, topk, has_bias):
+    if 'TPU7x' in jax.devices()[0].device_kind:
+        pytest.skip("Skipping test on TPU TPU7x.")
 
     if jax.local_device_count() < 8:
         pytest.skip("Test requires at least 8 devices")
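Note: the same TPU7x guard is repeated verbatim in all four tests. A minimal sketch (not part of this diff) of how it could be factored into a shared helper, assuming the tests live in the same module; the helper name skip_if_tpu7x is hypothetical:

import jax
import pytest


def skip_if_tpu7x():
    # Hypothetical helper: skip the calling test when the first local device
    # is a TPU7x, mirroring the guard added in the hunks above.
    if 'TPU7x' in jax.devices()[0].device_kind:
        pytest.skip("Skipping test on TPU TPU7x.")

Each test would then call skip_if_tpu7x() as its first statement instead of repeating the check inline.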