Commit 43f8f1e

worker_test: multi requests; acc_test: precompile
Signed-off-by: Juncheng Gu <jcgu@google.com>
1 parent ff4d31f commit 43f8f1e

File tree: 3 files changed (+83, -58 lines)

tests/distributed/offload/tpu_offload_accuracy_test.py

Lines changed: 14 additions & 10 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

+import itertools
 import os
 import time

@@ -49,12 +50,13 @@ def _test_kv_cache_cpu_offloading_accuracy(
     sampling_config: SamplingParams,
     kv_transfer_config: KVTransferConfig,
     swap_op_type: str,
+    skip_precompile: str,
     decode_save: str,
 ):
     with monkeypatch.context():
         os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-        os.environ['TPU_OFFLOAD_SKIP_JAX_PRECOMPILE'] = '1'
         os.environ['TPU_OFFLOAD_SWAP_OP_TYPE'] = swap_op_type
+        os.environ['TPU_OFFLOAD_SKIP_JAX_PRECOMPILE'] = skip_precompile
         os.environ['TPU_OFFLOAD_DECODE_SAVE'] = decode_save
         llm = LLM(model="meta-llama/Llama-3.2-3B",
                   max_model_len=1024,
@@ -98,12 +100,14 @@ def test_kv_cache_cpu_offloading_accuracy(
 ):
     swap_op_types = ["pallas", "jax"]
     decode_saves = ["0", "1"]
-    for swap_op_type in swap_op_types:
-        for decode_save in decode_saves:
-            _test_kv_cache_cpu_offloading_accuracy(
-                monkeypatch,
-                sampling_config,
-                kv_transfer_config,
-                swap_op_type,
-                decode_save,
-            )
+    skip_precompile = ["0", "1"]
+    for swap_op_type, decode_save, _skip_precompile in itertools.product(
+            swap_op_types, decode_saves, skip_precompile):
+        _test_kv_cache_cpu_offloading_accuracy(
+            monkeypatch,
+            sampling_config,
+            kv_transfer_config,
+            swap_op_type,
+            _skip_precompile,
+            decode_save,
+        )
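
Note: the rewritten driver covers the same swap_op_type x decode_save combinations the nested loops did, plus the new skip_precompile axis. A minimal standalone sketch (not part of the commit) of how itertools.product expands this into 2 x 2 x 2 = 8 invocations:

import itertools

swap_op_types = ["pallas", "jax"]
decode_saves = ["0", "1"]
skip_precompile = ["0", "1"]

# Each tuple drives one accuracy run; the test sets the env vars
# TPU_OFFLOAD_SWAP_OP_TYPE, TPU_OFFLOAD_DECODE_SAVE, and
# TPU_OFFLOAD_SKIP_JAX_PRECOMPILE from these values.
for swap_op_type, decode_save, _skip_precompile in itertools.product(
        swap_op_types, decode_saves, skip_precompile):
    print(swap_op_type, decode_save, _skip_precompile)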

tests/distributed/offload/tpu_offload_connector_worker_test.py

Lines changed: 46 additions & 15 deletions
@@ -70,7 +70,7 @@ def setUp(self):
         self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE)
         self.num_layers = 80
         self.num_blocks = 128
-        self.num_cpu_chunks = 24
+        self.num_cpu_chunks = 128
         self.block_size = self.vllm_config.cache_config.block_size
         self.num_heads = 8
         self.head_size = 128
@@ -205,40 +205,57 @@ def test_precompile_run_success(self, swap_op_type: str):

     @parameterized.named_parameters(
         dict(
-            testcase_name="_regular_single_block_save",
+            testcase_name="_single_block",
             num_blocks_to_save=1,
             num_requests=1,
         ),
         dict(
-            testcase_name="_regular_multi_requests_single_block_save",
-            num_blocks_to_save=2,
-            num_requests=4,
+            testcase_name="_multi_requests_single_block",
+            num_blocks_to_save=1,
+            num_requests=6,
         ),
         dict(
-            testcase_name="_regular_multi_block_save",
+            testcase_name="_multi_blocks",
             num_blocks_to_save=5,
             num_requests=1,
         ),
         dict(
-            testcase_name="_regular_multi_block_save_with_compile_jax",
+            testcase_name="_multi_requests_multi_blocks",
+            num_blocks_to_save=5,
+            num_requests=6,
+        ),
+        dict(
+            testcase_name="_multi_blocks_with_compile_jax",
             num_blocks_to_save=5,
             num_requests=1,
             use_precompiled_swap_ops=True,
         ),
         dict(
-            testcase_name=
-            "_regular_multi_request_single_block_save_with_compile_jax",
+            testcase_name="_multi_requests_single_block_with_compile_jax",
             num_blocks_to_save=1,
             num_requests=6,
             use_precompiled_swap_ops=True,
         ),
         dict(
-            testcase_name="_regular_multi_block_save_with_compile_pallas",
+            testcase_name="_multi_requests_multi_blocks_with_compile_jax",
+            num_blocks_to_save=5,
+            num_requests=6,
+            use_precompiled_swap_ops=True,
+        ),
+        dict(
+            testcase_name="_multi_blocks_with_compile_pallas",
             num_blocks_to_save=5,
             num_requests=1,
             use_precompiled_swap_ops=True,
             swap_op_type="pallas",
         ),
+        dict(
+            testcase_name="_multi_requests_multi_blocks_with_compile_pallas",
+            num_blocks_to_save=5,
+            num_requests=6,
+            use_precompiled_swap_ops=True,
+            swap_op_type="pallas",
+        ),
         dict(
             testcase_name="_final_save",
             num_blocks_to_save=1,
@@ -370,13 +387,13 @@ def test_tpu_connector_save(

     @parameterized.named_parameters(
         dict(
-            testcase_name="_single_block_",
+            testcase_name="_single_block",
             num_blocks_to_operate=1,
             num_requests=1,
         ),
         dict(
-            testcase_name="_multi_requests_",
-            num_blocks_to_operate=2,
+            testcase_name="_multi_requests_single_block",
+            num_blocks_to_operate=1,
             num_requests=4,
         ),
         dict(
@@ -387,9 +404,23 @@ def test_tpu_connector_save(
             swap_op_type="jax",
         ),
         dict(
-            testcase_name="_multi_blocks_compile_pallas",
+            testcase_name="_multi_requests_single_block_compile_jax",
+            num_blocks_to_operate=1,
+            num_requests=6,
+            use_precompiled_swap_ops=True,
+            swap_op_type="jax",
+        ),
+        dict(
+            testcase_name="_multi_requests_multi_blocks_compile_jax",
             num_blocks_to_operate=5,
-            num_requests=1,
+            num_requests=6,
+            use_precompiled_swap_ops=True,
+            swap_op_type="jax",
+        ),
+        dict(
+            testcase_name="_multi_requests_multi_blocks_compile_pallas",
+            num_blocks_to_operate=5,
+            num_requests=6,
             use_precompiled_swap_ops=True,
             swap_op_type="pallas",
         ),
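
Note: these test tables follow absl's parameterized.named_parameters pattern: each dict becomes one generated test case, named by appending testcase_name to the test method's name, with the remaining keys passed as keyword arguments. A minimal self-contained sketch (assuming absl-py; the class and method names here are illustrative, not from the repo):

from absl.testing import absltest, parameterized


class SaveTest(parameterized.TestCase):

    @parameterized.named_parameters(
        dict(testcase_name="_single_block",
             num_blocks_to_save=1,
             num_requests=1),
        dict(testcase_name="_multi_requests_single_block",
             num_blocks_to_save=1,
             num_requests=6),
    )
    def test_save(self, num_blocks_to_save: int, num_requests: int):
        # Generated as test_save_single_block and
        # test_save_multi_requests_single_block.
        self.assertGreater(num_blocks_to_save * num_requests, 0)


if __name__ == "__main__":
    absltest.main()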

tpu_inference/distributed/offload/tpu_offload_connector.py

Lines changed: 23 additions & 33 deletions
@@ -118,7 +118,7 @@
     LRUCacheManager, StagingBufferManager)
 from tpu_inference.distributed.offload.utils import (
     CPU_OFFLOADING_SWAP_OP_TYPE, CpuChunkId, KVCacheSwapFn, ReqId,
-    TokenProcessor, get_kv_cache_swap_fn, jitted_insert_kv_cache_slices)
+    get_kv_cache_swap_fn, jitted_insert_kv_cache_slices)
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.kv_cache_manager import KVCacheManager
 from tpu_inference.runner.tpu_runner import TPUModelRunner
@@ -496,8 +496,6 @@ def __init__(self, vllm_config: "VllmConfig"):
         self._reqs_being_loaded = defaultdict[ReqId, set[CpuChunkId]](set)

         model_name = self.vllm_config.model_config.model
-        self.token_processor = TokenProcessor(model_name=model_name,
-                                              chunk_size=self.block_size)

         self.decode_save = envs.TPU_OFFLOAD_DECODE_SAVE
         # NOTE(jcgu): currently, let's make chunk_size == block_size
@@ -528,7 +526,7 @@ def __init__(self, vllm_config: "VllmConfig"):

     def _get_request_block_hashes(self, req: "Request") -> list[BlockHash]:
         # request's original block_hashes do not include the last partial block
-        # TODO(jcgu): switch back to token_processor
+        # TODO(jcgu): add an option to use local token_processor
         return req.block_hashes

     def get_num_new_matched_tokens(
@@ -1160,19 +1158,14 @@ def __init__(self, vllm_config: VllmConfig,

         self.runner: Optional[TPUModelRunner] = None
         self.mesh: Optional[Mesh] = None
+        self.swap_in_fn: KVCacheSwapFn = None
+        self.swap_out_fn: KVCacheSwapFn = None
         self.swap_op_type = envs.TPU_OFFLOAD_SWAP_OP_TYPE
-        assert self.swap_op_type in get_args(CPU_OFFLOADING_SWAP_OP_TYPE)
         # TODO(jcgu): check libtpu compatibility for pallas dma kernel
-        logger.info(
-            f"(cpu offloading) swap operation type is {self.swap_op_type}")
-
+        assert self.swap_op_type in get_args(CPU_OFFLOADING_SWAP_OP_TYPE)
         self.use_bucketed_swap_ops = not envs.TPU_OFFLOAD_SKIP_JAX_PRECOMPILE
-        logger.info(
-            f"(cpu offloading) use_bucketed_swap_ops={self.use_bucketed_swap_ops}"
-        )
-
-        self.swap_in_fn: KVCacheSwapFn = None
-        self.swap_out_fn: KVCacheSwapFn = None
+        logger.info(f" swap operation type is {self.swap_op_type}, "
+                    f"use_bucketed_swap_ops={self.use_bucketed_swap_ops}.")

         # cpu cache
         self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS
@@ -1181,13 +1174,11 @@ def __init__(self, vllm_config: VllmConfig,
         model_name = self.vllm_config.model_config.model
         logger.info(
             f"Model name is {model_name}, KV block_size={self.block_size}")
-        self.token_processor = TokenProcessor(model_name=model_name,
-                                              chunk_size=self.block_size)

         self.cpu_chunk_size = self.block_size
         # Thread pool for asynchronous TPU->CPU copies
-        self.save_executor = ThreadPoolExecutor(max_workers=4,
-                                                thread_name_prefix="tpu_saver")
+        self.save_executor = ThreadPoolExecutor(
+            max_workers=4, thread_name_prefix="tpu_save_handler")
         self.finished_save_reqs: set[ReqId] = set()
         self.finished_load_reqs: set[ReqId] = set()
         # Tracks if wait_for_save has been called for the current step's metadata.
@@ -1298,10 +1289,11 @@ def _precompile_kv_swap_operations(self):

         # 3. Pre-compile CPU -> TPU transfer (used in load)
         split_size_list = [self.block_size] * num_blocks
-        chunked_dummy_kv_cpu = [
-            jax.lax.split(flat_layer_cache, split_size_list, axis=0)
-            for flat_layer_cache in dummy_kv_cpu
-        ]
+        chunked_dummy_kv_cpu = jax.tree.map(
+            lambda flat_layer_cache: jax.lax.split(
+                flat_layer_cache, split_size_list, axis=0),
+            dummy_kv_cpu)
+
         chunked_dummy_kv_tpu = self.swap_in_fn(chunked_dummy_kv_cpu)
         jax.block_until_ready(chunked_dummy_kv_tpu)

@@ -1374,13 +1366,13 @@ def _bucketed_swap_out_fn(

         # Fast path: handle bucket-sized transfers
         if num_blocks in BLOCK_SIZE_BUCKETS:
+            split_size_list = [self.block_size] * num_blocks
             flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu)
             jax.block_until_ready(flat_kv_caches_cpu)
-            split_size_list = [self.block_size] * num_blocks
-            return [
-                jax.lax.split(flat_layer_cache, split_size_list, axis=0)
-                for flat_layer_cache in flat_kv_caches_cpu
-            ]
+            return jax.tree.map(
+                lambda flat_layer_cache: jax.lax.split(
+                    flat_layer_cache, split_size_list, axis=0),
+                flat_kv_caches_cpu)

         # Bucket decomposition path
         decomposed_block_sizes = self._decompose_into_buckets(num_blocks)
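
Note: _decompose_into_buckets is referenced here but not changed by this commit. A hypothetical sketch of a greedy decomposition, assuming BLOCK_SIZE_BUCKETS is a set of precompiled transfer sizes (the actual values and logic live elsewhere in the repo):

# Assumed example buckets; the real BLOCK_SIZE_BUCKETS is defined in the repo.
BLOCK_SIZE_BUCKETS = [1, 2, 4, 8, 16]

def decompose_into_buckets(num_blocks: int) -> list[int]:
    """Greedily split num_blocks into bucket-sized pieces, largest first,
    so every piece can hit a precompiled swap function."""
    pieces = []
    for bucket in sorted(BLOCK_SIZE_BUCKETS, reverse=True):
        while num_blocks >= bucket:
            pieces.append(bucket)
            num_blocks -= bucket
    assert num_blocks == 0, "buckets must include 1 to cover any count"
    return pieces

# e.g. decompose_into_buckets(11) -> [8, 2, 1]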
@@ -1580,12 +1572,10 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int],
             # NOTE(jcgu): we keep cpu_chunk_size == block_size
             split_size_list = [self.cpu_chunk_size
                                ] * num_blocks_to_save
-            chunks_on_cpu = [
-                jax.lax.split(flat_layer_cache,
-                              split_size_list,
-                              axis=0)
-                for flat_layer_cache in flat_kv_caches_cpu
-            ]
+            chunks_on_cpu = jax.tree.map(
+                lambda flat_layer_cache: jax.lax.split(
+                    flat_layer_cache, split_size_list, axis=0),
+                flat_kv_caches_cpu)

             if chunks_on_cpu and chunks_on_cpu[0]:
                 jax.block_until_ready(chunks_on_cpu)
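
Note: the three jax.tree.map hunks in this file are a pure refactor: a Python list of per-layer arrays is itself a pytree, so mapping jax.lax.split over it yields the same nested structure as the old list comprehensions. A minimal sketch (shapes illustrative, not from the repo):

import jax
import jax.numpy as jnp

block_size, num_blocks, num_layers = 16, 4, 2
split_size_list = [block_size] * num_blocks
flat_kv_caches_cpu = [
    jnp.ones((block_size * num_blocks, 8, 128)) for _ in range(num_layers)
]

# Old form: explicit comprehension over the per-layer caches.
chunks_v1 = [
    jax.lax.split(flat_layer_cache, split_size_list, axis=0)
    for flat_layer_cache in flat_kv_caches_cpu
]

# New form: jax.tree.map applies the same split to every leaf (layer).
chunks_v2 = jax.tree.map(
    lambda flat_layer_cache: jax.lax.split(
        flat_layer_cache, split_size_list, axis=0),
    flat_kv_caches_cpu)

assert len(chunks_v1) == len(chunks_v2) == num_layers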
