Skip to content

Commit 5644ffb

Browse files
authored
First check-in to add a CI/CD test on TPU v7x (#1270)
Signed-off-by: Qiliang Cui <derrhein@gmail.com>
1 parent bf54fc5 commit 5644ffb

File tree

7 files changed

+47
-5
lines changed

7 files changed

+47
-5
lines changed

.buildkite/pipeline_jax.yml

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,22 @@ steps:
124124
--ignore=/workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py \
125125
--cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference --cov-report term-missing --cov-fail-under=69
126126
127+
- label: "JAX unit tests - tpuv7x"
128+
key: test_7_tpu7x
129+
soft_fail: true
130+
agents:
131+
queue: tpu_v7x_2_queue
132+
commands:
133+
- |
134+
IS_FOR_V7X=true .buildkite/scripts/run_in_docker.sh \
135+
python3 -m pytest -s -v -x /workspace/tpu_inference/tests/ \
136+
--ignore=/workspace/tpu_inference/tests/kernels \
137+
--ignore=/workspace/tpu_inference/tests/lora \
138+
--ignore=/workspace/tpu_inference/tests/e2e \
139+
--ignore=/workspace/tpu_inference/tpu_inference/mock \
140+
--ignore=/workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py \
141+
--cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference --cov-report term-missing --cov-fail-under=67
142+
127143
- label: "JAX unit tests - kernels"
128144
key: test_8
129145
soft_fail: true
@@ -269,6 +285,7 @@ steps:
269285
- test_5
270286
- test_6
271287
- test_7
288+
- test_7_tpu7x
272289
- test_8
273290
- test_9
274291
- test_10
@@ -282,4 +299,4 @@ steps:
282299
commands:
283300
- |
284301
.buildkite/scripts/check_results.sh \
285-
"TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_15 test_16
302+
"TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_7_tpu7x test_8 test_9 test_10 test_11 test_12 test_13 test_15 test_16

.buildkite/scripts/setup_docker_env.sh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,10 @@ setup_environment() {
5555

5656
echo "Cleanup complete."
5757

58-
echo "Installing Python dependencies"
59-
python3 -m pip install --progress-bar off buildkite-test-collector==0.1.9
60-
echo "Python dependencies installed"
61-
6258
VLLM_COMMIT_HASH=$(buildkite-agent meta-data get "VLLM_COMMIT_HASH" --default "")
6359

6460
docker build \
6561
--build-arg VLLM_COMMIT_HASH="${VLLM_COMMIT_HASH}" \
62+
--build-arg IS_FOR_V7X="${IS_FOR_V7X:-false}" \
6663
--no-cache -f docker/Dockerfile -t "${IMAGE_NAME}:${BUILDKITE_COMMIT}" .
6764
}

tests/layers/jax/sample/test_rejection_sampler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,9 @@ def test_rejection_sampling_approximates_target_distribution(self):
11811181
We expect that as sample size increases, the distance to the target
11821182
distribution decreases much more than the distance to random distributions.
11831183
"""
1184+
if 'TPU7x' in jax.devices()[0].device_kind:
1185+
pytest.skip("Skipping test on TPU TPU7x.")
1186+
11841187
vocab_size = 10
11851188
k = 2
11861189
num_reference_probs = 100

tests/layers/vllm/test_compressed_tensors_w8a8_int8.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def test_loading_model(model, mesh):
129129
])
130130
@pytest.mark.parametrize("enable_sp", [False, True])
131131
def test_row_parallel_linear(model, bias, mesh, enable_sp):
132+
if 'TPU7x' in jax.devices()[0].device_kind:
133+
pytest.skip("Skipping test on TPU TPU7x.")
134+
132135
dtype = torch.bfloat16
133136

134137
engine_args = EngineArgs(

tests/layers/vllm/test_mxfp4.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def test_quant_override(model, mesh):
116116
@pytest.mark.parametrize("topk", [2])
117117
def test_mxfp4_fused_moe(mesh, num_tokens, intermediate_size, hidden_size,
118118
num_experts, topk):
119+
if 'TPU7x' in jax.devices()[0].device_kind:
120+
pytest.skip("Skipping test on TPU TPU7x.")
121+
119122
torch.manual_seed(42)
120123
dtype = torch.bfloat16
121124

@@ -205,6 +208,10 @@ def test_mxfp4_fused_moe(mesh, num_tokens, intermediate_size, hidden_size,
205208
@pytest.mark.parametrize("topk", [2])
206209
def test_mxfp4_fused_moe_use_kernel(mesh, num_tokens, intermediate_size,
207210
hidden_size, num_experts, topk):
211+
212+
if 'TPU7x' in jax.devices()[0].device_kind:
213+
pytest.skip("Skipping test on TPU TPU7x.")
214+
208215
torch.manual_seed(42)
209216
dtype = torch.bfloat16
210217

tests/layers/vllm/test_unquantized.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,9 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
415415
@pytest.mark.parametrize("topk", [2])
416416
def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
417417
num_experts, topk):
418+
if 'TPU7x' in jax.devices()[0].device_kind:
419+
pytest.skip("Skipping test on TPU TPU7x.")
420+
418421
torch.manual_seed(42)
419422
dtype = torch.bfloat16
420423

@@ -494,6 +497,9 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
494497
@pytest.mark.parametrize("topk", [2])
495498
def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
496499
num_experts, topk):
500+
if 'TPU7x' in jax.devices()[0].device_kind:
501+
pytest.skip("Skipping test on TPU TPU7x.")
502+
497503
torch.manual_seed(42)
498504
dtype = torch.bfloat16
499505

@@ -560,6 +566,9 @@ def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
560566
@pytest.mark.parametrize("activation", ["silu", "swigluoai"])
561567
def test_fused_moe_activation(mesh, num_tokens, intermediate_size, hidden_size,
562568
num_experts, topk, activation):
569+
if 'TPU7x' in jax.devices()[0].device_kind:
570+
pytest.skip("Skipping test on TPU TPU7x.")
571+
563572
torch.manual_seed(42)
564573
dtype = torch.bfloat16
565574

@@ -619,6 +628,8 @@ def test_fused_moe_activation(mesh, num_tokens, intermediate_size, hidden_size,
619628
@pytest.mark.parametrize("has_bias", [False, True])
620629
def test_fused_moe_use_kernel(mesh, num_tokens, intermediate_size, hidden_size,
621630
num_experts, topk, has_bias):
631+
if 'TPU7x' in jax.devices()[0].device_kind:
632+
pytest.skip("Skipping test on TPU TPU7x.")
622633

623634
if jax.local_device_count() < 8:
624635
pytest.skip("Test requires at least 8 devices")

tests/models/jax/test_llama_eagle3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ def test_eagle3_decoder_layer_init(self, mock_vllm_config: MockVllmConfig,
126126
def test_forward_pass(self, mock_vllm_config: MockVllmConfig, rng: PRNGKey,
127127
mesh: Mesh, mock_model_inputs):
128128
"""Tests the forward pass of the EagleLlama3ForCausalLM model."""
129+
130+
if 'TPU7x' in jax.devices()[0].device_kind:
131+
pytest.skip("Skipping test on TPU TPU7x.")
132+
129133
draft_model_config = mock_vllm_config.speculative_config.draft_model_config
130134
hf_config = draft_model_config.hf_config
131135
model = EagleLlama3ForCausalLM(mock_vllm_config, rng, mesh)

0 commit comments

Comments
 (0)