118118 LRUCacheManager , StagingBufferManager )
119119from tpu_inference .distributed .offload .utils import (
120120 CPU_OFFLOADING_SWAP_OP_TYPE , CpuChunkId , KVCacheSwapFn , ReqId ,
121- TokenProcessor , get_kv_cache_swap_fn , jitted_insert_kv_cache_slices )
121+ get_kv_cache_swap_fn , jitted_insert_kv_cache_slices )
122122from tpu_inference .logger import init_logger
123123from tpu_inference .runner .kv_cache_manager import KVCacheManager
124124from tpu_inference .runner .tpu_runner import TPUModelRunner
@@ -496,8 +496,6 @@ def __init__(self, vllm_config: "VllmConfig"):
496496 self ._reqs_being_loaded = defaultdict [ReqId , set [CpuChunkId ]](set )
497497
498498 model_name = self .vllm_config .model_config .model
499- self .token_processor = TokenProcessor (model_name = model_name ,
500- chunk_size = self .block_size )
501499
502500 self .decode_save = envs .TPU_OFFLOAD_DECODE_SAVE
503501 # NOTE(jcgu): currently, let's make chunk_size == block_size
@@ -528,7 +526,7 @@ def __init__(self, vllm_config: "VllmConfig"):
528526
def _get_request_block_hashes(self, req: "Request") -> list[BlockHash]:
    """Return the block hashes already attached to ``req``.

    The hashes are read directly from ``req.block_hashes``; nothing is
    recomputed or copied here.

    Args:
        req: Request object exposing a ``block_hashes`` attribute.

    Returns:
        The same ``req.block_hashes`` object, unchanged.
    """
    # request's original block_hashes do not include the last partial block
    # TODO(jcgu): add an option to use local token_processor
    return req.block_hashes
533531
534532 def get_num_new_matched_tokens (
@@ -1160,19 +1158,14 @@ def __init__(self, vllm_config: VllmConfig,
11601158
11611159 self .runner : Optional [TPUModelRunner ] = None
11621160 self .mesh : Optional [Mesh ] = None
1161+ self .swap_in_fn : KVCacheSwapFn = None
1162+ self .swap_out_fn : KVCacheSwapFn = None
11631163 self .swap_op_type = envs .TPU_OFFLOAD_SWAP_OP_TYPE
1164- assert self .swap_op_type in get_args (CPU_OFFLOADING_SWAP_OP_TYPE )
11651164 # TODO(jcgu): check libtpu compatibility for pallas dma kernel
1166- logger .info (
1167- f"(cpu offloading) swap operation type is { self .swap_op_type } " )
1168-
1165+ assert self .swap_op_type in get_args (CPU_OFFLOADING_SWAP_OP_TYPE )
11691166 self .use_bucketed_swap_ops = not envs .TPU_OFFLOAD_SKIP_JAX_PRECOMPILE
1170- logger .info (
1171- f"(cpu offloading) use_bucketed_swap_ops={ self .use_bucketed_swap_ops } "
1172- )
1173-
1174- self .swap_in_fn : KVCacheSwapFn = None
1175- self .swap_out_fn : KVCacheSwapFn = None
1167+ logger .info (f" swap operation type is { self .swap_op_type } , "
1168+ f"use_bucketed_swap_ops={ self .use_bucketed_swap_ops } ." )
11761169
11771170 # cpu cache
11781171 self .num_cpu_chunks = envs .TPU_OFFLOAD_NUM_CPU_CHUNKS
@@ -1181,13 +1174,11 @@ def __init__(self, vllm_config: VllmConfig,
11811174 model_name = self .vllm_config .model_config .model
11821175 logger .info (
11831176 f"Model name is { model_name } , KV block_size={ self .block_size } " )
1184- self .token_processor = TokenProcessor (model_name = model_name ,
1185- chunk_size = self .block_size )
11861177
11871178 self .cpu_chunk_size = self .block_size
11881179 # Thread pool for asynchronous TPU->CPU copies
1189- self .save_executor = ThreadPoolExecutor (max_workers = 4 ,
1190- thread_name_prefix = "tpu_saver " )
1180+ self .save_executor = ThreadPoolExecutor (
1181+ max_workers = 4 , thread_name_prefix = "tpu_save_handler " )
11911182 self .finished_save_reqs : set [ReqId ] = set ()
11921183 self .finished_load_reqs : set [ReqId ] = set ()
11931184 # Tracks if wait_for_save has been called for the current step's metadata.
@@ -1298,10 +1289,11 @@ def _precompile_kv_swap_operations(self):
12981289
12991290 # 3. Pre-compile CPU -> TPU transfer (used in load)
13001291 split_size_list = [self .block_size ] * num_blocks
1301- chunked_dummy_kv_cpu = [
1302- jax .lax .split (flat_layer_cache , split_size_list , axis = 0 )
1303- for flat_layer_cache in dummy_kv_cpu
1304- ]
1292+ chunked_dummy_kv_cpu = jax .tree .map (
1293+ lambda flat_layer_cache : jax .lax .split (
1294+ flat_layer_cache , split_size_list , axis = 0 ),
1295+ dummy_kv_cpu )
1296+
13051297 chunked_dummy_kv_tpu = self .swap_in_fn (chunked_dummy_kv_cpu )
13061298 jax .block_until_ready (chunked_dummy_kv_tpu )
13071299
@@ -1374,13 +1366,13 @@ def _bucketed_swap_out_fn(
13741366
13751367 # Fast path: handle bucket-sized transfers
13761368 if num_blocks in BLOCK_SIZE_BUCKETS :
1369+ split_size_list = [self .block_size ] * num_blocks
13771370 flat_kv_caches_cpu = self .swap_out_fn (flat_kv_caches_tpu )
13781371 jax .block_until_ready (flat_kv_caches_cpu )
1379- split_size_list = [self .block_size ] * num_blocks
1380- return [
1381- jax .lax .split (flat_layer_cache , split_size_list , axis = 0 )
1382- for flat_layer_cache in flat_kv_caches_cpu
1383- ]
1372+ return jax .tree .map (
1373+ lambda flat_layer_cache : jax .lax .split (
1374+ flat_layer_cache , split_size_list , axis = 0 ),
1375+ flat_kv_caches_cpu )
13841376
13851377 # Bucket decomposition path
13861378 decomposed_block_sizes = self ._decompose_into_buckets (num_blocks )
@@ -1580,12 +1572,10 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int],
15801572 # NOTE(jcgu): we keep cpu_chunk_size == block_size
15811573 split_size_list = [self .cpu_chunk_size
15821574 ] * num_blocks_to_save
1583- chunks_on_cpu = [
1584- jax .lax .split (flat_layer_cache ,
1585- split_size_list ,
1586- axis = 0 )
1587- for flat_layer_cache in flat_kv_caches_cpu
1588- ]
1575+ chunks_on_cpu = jax .tree .map (
1576+ lambda flat_layer_cache : jax .lax .split (
1577+ flat_layer_cache , split_size_list , axis = 0 ),
1578+ flat_kv_caches_cpu )
15891579
15901580 if chunks_on_cpu and chunks_on_cpu [0 ]:
15911581 jax .block_until_ready (chunks_on_cpu )
0 commit comments