Commit 48d2e3d

follow up changes in the upstream; and update test scripts
Signed-off-by: Juncheng Gu <jcgu@google.com>
1 parent 41f3c73 commit 48d2e3d

4 files changed: +5 -13 lines changed


.buildkite/features/KV_Cache_Offload.yml

Lines changed: 1 addition & 6 deletions
@@ -9,12 +9,7 @@ steps:
     commands:
     - |
       .buildkite/scripts/run_in_docker.sh \
-        python3 -m pytest -s -v /workspace/tpu_inference/tests/distributed/offload/tpu_offload_connector_scheduler_test.py \
-        /workspace/tpu_inference/tests/distributed/offload/tpu_offload_connector_worker_test.py \
-        /workspace/tpu_inference/tests/distributed/offload/tpu_offload_cpu_backend_test.py \
-        /workspace/tpu_inference/tests/distributed/offload/tpu_offload_manager_test.py \
-        /workspace/tpu_inference/tests/distributed/offload/tpu_offload_utils_test.py \
-        /workspace/tpu_inference/tests/distributed/offload/tpu_offload_accuracy_test.py
+        python3 -m pytest -s -v /workspace/tpu_inference/tests/distributed/offload/
   - label: "Record correctness test result for KV Cache Offload"
     key: "record_KV_Cache_Offload_CorrectnessTest"
     depends_on: "KV_Cache_Offload_CorrectnessTest"
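
The single directory argument relies on pytest's test discovery instead of an explicit file list. As a rough illustration (this is not pytest's actual collector), the sketch below checks the six file names from the removed lines against pytest's default `python_files` patterns, `test_*.py` and `*_test.py`. It assumes the project does not override `python_files` in its pytest configuration; `would_be_collected` is a hypothetical helper introduced only for this example.

```python
from fnmatch import fnmatch

# pytest's documented default file patterns (the `python_files` option).
# Assumption: the repository does not override them.
DEFAULT_PYTHON_FILES = ("test_*.py", "*_test.py")

# File names taken from the removed lines of the hunk above.
PREVIOUSLY_LISTED = [
    "tpu_offload_connector_scheduler_test.py",
    "tpu_offload_connector_worker_test.py",
    "tpu_offload_cpu_backend_test.py",
    "tpu_offload_manager_test.py",
    "tpu_offload_utils_test.py",
    "tpu_offload_accuracy_test.py",
]


def would_be_collected(filename: str, patterns=DEFAULT_PYTHON_FILES) -> bool:
    """Hypothetical helper: True if the name matches any file pattern."""
    return any(fnmatch(filename, pattern) for pattern in patterns)


# Names that a directory-level run would miss under the default patterns.
missed = [name for name in PREVIOUSLY_LISTED if not would_be_collected(name)]
print(missed)
```

A directory-level run also picks up any other matching test file that lives in `tests/distributed/offload/`, so the set of executed tests can grow without further edits to the pipeline.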

examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml

Lines changed: 1 addition & 6 deletions
@@ -17,12 +17,7 @@ spec:
     command:
     - /bin/bash
     - -c
-    - "pytest -sv tests/distributed/offload/tpu_offload_cpu_backend_test.py"
-    - "pytest -sv tests/distributed/offload/tpu_offload_connector_worker_test.py"
-    - "pytest -sv tests/distributed/offload/tpu_offload_connector_scheduler_test.py"
-    - "pytest -sv tests/distributed/offload/tpu_offload_utils_test.py"
-    - "pytest -sv tests/distributed/offload/tpu_offload_manager_test.py"
-    - "pytest -sv tests/distributed/offload/tpu_offload_accuracy_test.py"
+    - "pytest -sv tests/distributed/offload/"
     env:
     - name: HUGGING_FACE_HUB_TOKEN
       valueFrom:
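
One reading of this hunk: a container `command` list becomes the process argument vector, and `/bin/bash -c` executes only the first string after `-c`, with any later strings bound to the positional parameters `$0`, `$1`, and so on. The sketch below reproduces that bash behavior locally with `subprocess`. Whether the old pod spec actually skipped the remaining five test files is an inference from this behavior, not something the commit states; `run_bash_c` is a hypothetical helper, and the example assumes `/bin/bash` is available.

```python
import subprocess


def run_bash_c(*args: str) -> str:
    """Hypothetical helper: run `/bin/bash -c <args...>` and return stdout."""
    result = subprocess.run(
        ["/bin/bash", "-c", *args],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout


# Only the first string after -c is executed as the command.
only_first = run_bash_c("echo first", "echo second")

# The extra strings are visible as $0 and $1 inside that command.
positional = run_bash_c('echo "$0|$1"', "echo second", "echo third")

print(repr(only_first))
print(repr(positional))
```

Collapsing the list into one command string, as the added line does, leaves nothing after the command for bash to treat as positional parameters.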

tests/distributed/offload/tpu_offload_connector_worker_test.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()
         cc.reset_cache()
+        jax.clear_caches()
 
     def create_mesh(self, axis_shapes, axis_names):
         """Creates a JAX device mesh with the default device order."""

tpu_inference/worker/tpu_worker.py

Lines changed: 2 additions & 1 deletion
@@ -289,7 +289,8 @@ def determine_available_memory(self) -> int:
 
         kv_cache_specs = self.model_runner.get_kv_cache_spec()
         num_layers = len(kv_cache_specs)
-        vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
+        vllm_page_size_bytes = get_uniform_page_size(
+            list(kv_cache_specs.values()))
         stage_buffer_size_bytes = staging_buffer_pages * num_layers * vllm_page_size_bytes
 
         total_hbm_avail = total_hbm_avail - stage_buffer_size_bytes
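
The call-site change suggests that `get_kv_cache_spec()` returns a mapping from layer name to spec, while `get_uniform_page_size` now expects a sequence of specs. The sketch below walks through the same arithmetic with hypothetical stand-ins: `FakeSpec` and `uniform_page_size` are invented for illustration and are not vLLM's real classes or implementation, and all numbers are made up.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeSpec:
    """Hypothetical stand-in for a per-layer KV cache spec."""
    page_size_bytes: int


def uniform_page_size(specs: list[FakeSpec]) -> int:
    """Hypothetical stand-in: return the page size shared by all specs."""
    sizes = {spec.page_size_bytes for spec in specs}
    if len(sizes) != 1:
        raise ValueError(f"expected one uniform page size, got {sorted(sizes)}")
    return sizes.pop()


# Assumed shape of the runner's result: layer name -> spec.
kv_cache_specs = {"layer.0": FakeSpec(4096), "layer.1": FakeSpec(4096)}
num_layers = len(kv_cache_specs)

# Pass the spec objects, not the dict itself (iterating a dict yields keys).
page_size_bytes = uniform_page_size(list(kv_cache_specs.values()))

staging_buffer_pages = 8      # made-up value
total_hbm_avail = 1_000_000   # made-up value, in bytes

# Reserve room for the staging buffer before sizing the KV cache.
stage_buffer_size_bytes = staging_buffer_pages * num_layers * page_size_bytes
total_hbm_avail = total_hbm_avail - stage_buffer_size_bytes
print(stage_buffer_size_bytes, total_hbm_avail)
```

Passing the dict directly to a function that iterates over specs would hand it layer-name strings instead of spec objects, which is the kind of mismatch the `list(kv_cache_specs.values())` wrapper avoids.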
