Skip to content

Commit 2901e56

Browse files
committed
update unit-test yml
Signed-off-by: Juncheng Gu <jcgu@google.com>
1 parent 12c4885 commit 2901e56

File tree

4 files changed

+27
-10
lines changed

4 files changed

+27
-10
lines changed

.buildkite/features/KV_Cache_Offload.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ steps:
44
- label: "Correctness tests for KV Cache Offload"
55
key: "KV_Cache_Offload_CorrectnessTest"
66
soft_fail: true
7+
env:
8+
USE_V6E8_QUEUE: "True"
9+
VLLM_LOG_LEVEL: "INFO"
710
agents:
8-
queue: tpu_v6e_queue
11+
queue: tpu_v6e_8_queue
912
commands:
1013
- |
1114
.buildkite/scripts/run_in_docker.sh \

.buildkite/pipeline_jax.yml

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ steps:
122122
--ignore=/workspace/tpu_inference/tests/e2e \
123123
--ignore=/workspace/tpu_inference/tpu_inference/mock \
124124
--ignore=/workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py \
125-
--ignore=/workspace/tpu_inference/tests/distributed/offload/test_offload_accuracy_test.py \
125+
--ignore=/workspace/tpu_inference/tests/distributed/offload \
126126
--cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference --cov-report term-missing --cov-fail-under=69
127127
128128
- label: "JAX unit tests - kernels"
@@ -138,6 +138,7 @@ steps:
138138
--ignore=/workspace/tpu_inference/tests/kernels/ragged_paged_attention_kernel_v2_test.py \
139139
--ignore=/workspace/tpu_inference/tests/kernels/ragged_kv_cache_update_v2_test.py \
140140
--ignore=/workspace/tpu_inference/tests/kernels/collectives \
141+
--ignore=/workspace/tpu_inference/tests/kernels/host_dma_test.py \
141142
--ignore=/workspace/tpu_inference/tests/kernels/fused_moe_v1_test.py
142143
else
143144
echo "Skipping: no changes detected in kernels, tests/kernels, or requirements.txt"
@@ -256,6 +257,21 @@ steps:
256257
echo "Skipping: NIGHTLY environment variable not set"
257258
exit 0
258259
fi
260+
261+
- label: "kv cache offload tests on multi chips"
262+
key: test_17
263+
soft_fail: true
264+
env:
265+
USE_V6E8_QUEUE: "True"
266+
VLLM_LOG_LEVEL: "INFO"
267+
agents:
268+
queue: tpu_v6e_8_queue
269+
commands:
270+
- |
271+
.buildkite/scripts/run_in_docker.sh \
272+
python3 -m pytest -s -v -x /workspace/tpu_inference/tests/distributed/offload/ \
273+
/workspace/tpu_inference/tests/kernels/host_dma_test.py \
274+
--ignore=/workspace/tpu_inference/tests/distributed/offload/tpu_offload_accuracy_test.py
259275
# -----------------------------------------------------------------
260276
# NOTIFICATION STEP
261277
# -----------------------------------------------------------------
@@ -278,9 +294,10 @@ steps:
278294
- test_13
279295
- test_15
280296
- test_16
297+
- test_17
281298
agents:
282299
queue: cpu
283300
commands:
284301
- |
285302
.buildkite/scripts/check_results.sh \
286-
"TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_15 test_16
303+
"TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_15 test_16 test_17

tests/distributed/offload/tpu_offload_utils_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def setUp(self):
1717
"""Set up common parameters for the tests."""
1818
self.num_layers = 2
1919
self.num_tokens = 256
20-
self.num_kv_heads = 8
20+
num_devices = len(list(jax.devices()))
21+
self.num_kv_heads = num_devices
2122
self.head_dim = 128
2223
self.block_size = 16
2324
self.num_blocks = self.num_tokens // self.block_size
@@ -37,7 +38,7 @@ def setUp(self):
3738

3839
self.cache_dtype = jnp.bfloat16
3940

40-
self.mesh = self.create_mesh((1, 8), ("data", "model"))
41+
self.mesh = self.create_mesh((1, num_devices), ("data", "model"))
4142
partition_spec = PartitionSpec(None, None, "model")
4243
self.device_sharding = NamedSharding(self.mesh,
4344
partition_spec,

tests/kernels/host_dma_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import jax.numpy as jnp
77
import numpy as np
88
from absl.testing import absltest, parameterized
9-
from jax._src import compilation_cache as cc
109
from jax._src import test_util as jtu
1110
from jax.sharding import NamedSharding, PartitionSpec
1211

@@ -15,7 +14,6 @@
1514
DATA_LOCATION = Literal["device", "host"]
1615

1716

18-
# TODO(jcgu): add into CI tests
1917
@jtu.with_config(jax_numpy_dtype_promotion='strict')
2018
class HostHbmDmaTest(jtu.JaxTestCase):
2119

@@ -27,9 +25,7 @@ def setUp(self):
2725

2826
def tearDown(self):
2927
super().tearDown()
30-
# Reset the cache after each test.
31-
# This can also be achieved by running with JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE=True
32-
cc.reset_cache()
28+
jax.clear_caches()
3329

3430
def create_mesh(self, axis_shapes, axis_names):
3531
"""Creates a JAX device mesh with the default device order."""

0 commit comments

Comments
 (0)