 from vllm.v1.request import Request
 from vllm.forward_context import ForwardContext

+from tpu_inference import envs
 from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend
 from tpu_inference.distributed.offload.offload_manager import (
     LRUCacheManager, StagingBufferManager)
 from tpu_inference.distributed.offload.utils import (
     CPU_OFFLOADING_SWAP_OP_TYPE, CpuChunkId, KVCacheSwapFn, ReqId,
-    TokenProcessor, get_default_kv_connector_staging_buffer_tokens,
-    get_kv_cache_swap_fn, jitted_insert_kv_cache_slices)
+    TokenProcessor, get_kv_cache_swap_fn, jitted_insert_kv_cache_slices)
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.kv_cache_manager import KVCacheManager
 from tpu_inference.runner.tpu_runner import TPUModelRunner
@@ -480,9 +480,7 @@ def __init__(self, vllm_config: "VllmConfig"):
         self.block_size = vllm_config.cache_config.block_size

         # offloading manager
-        self.num_cpu_chunks = int(
-            os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS",
-                      str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS)))
+        self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS
         self.offload_manager = LRUCacheManager(
             num_cpu_chunks=self.num_cpu_chunks)

@@ -506,14 +504,15 @@ def __init__(self, vllm_config: "VllmConfig"):
         self.token_processor = TokenProcessor(model_name=model_name,
                                               chunk_size=self.block_size)

-        self.decode_save = os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0") == "1"
+        self.decode_save = envs.TPU_OFFLOAD_DECODE_SAVE
         # NOTE(jcgu): currently, let's make chunk_size == block_size
         # chunk_size == n * block_size leads to
         # 1. multi-size chunks
         # 2. complicated resize (split, concatenate) operations due to
         #    real-chunk-size in save and load
         self.cpu_chunk_size = self.block_size

+        # TODO(jcgu): rm
         # define partial_block saving behavior
         self.partial_block_save_behavior: PARTIAL_BLOCK_SAVE_BEHAVIOR = \
             os.getenv("TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR", "drop")
@@ -535,11 +534,7 @@ def __init__(self, vllm_config: "VllmConfig"):
         # config staging buffer
         # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler
         # otherwise, we can only use # of tokens as input, instead of buffer size in GB
-        _default_staging_buffer_tokens = get_default_kv_connector_staging_buffer_tokens(
-        )
-        num_staging_buffer_tokens = int(
-            os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS",
-                      str(_default_staging_buffer_tokens)))
+        num_staging_buffer_tokens = envs.TPU_OFFLOAD_STAGING_BUFFER_TOKENS
         self.num_staging_blocks = num_staging_buffer_tokens // self.block_size
         self.staging_buffer_manager = StagingBufferManager(
             num_blocks=self.num_staging_blocks)
@@ -1214,15 +1209,13 @@ def __init__(self, vllm_config: VllmConfig,

         self.runner: Optional[TPUModelRunner] = None
         self.mesh: Optional[Mesh] = None
-        self.swap_op_type = os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE",
-                                      default=DEFAULT_HOST_HBM_SWAP_OP_TYPE)
+        self.swap_op_type = envs.TPU_OFFLOAD_SWAP_OP_TYPE
         assert self.swap_op_type in get_args(CPU_OFFLOADING_SWAP_OP_TYPE)
         # TODO(jcgu): check libtpu compatibility for pallas dma kernel
         logger.info(
             f"(cpu offloading) swap operation type is {self.swap_op_type}")

-        self.use_bucketed_swap_ops = os.getenv(
-            "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0") == "0"
+        self.use_bucketed_swap_ops = not envs.TPU_OFFLOAD_SKIP_JAX_PRECOMPILE
         logger.info(
             f"(cpu offloading) use_bucketed_swap_ops={self.use_bucketed_swap_ops}"
         )
@@ -1231,9 +1224,7 @@ def __init__(self, vllm_config: VllmConfig,
         self.swap_out_fn: KVCacheSwapFn = None

         # cpu cache
-        self.num_cpu_chunks = int(
-            os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS",
-                      str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS)))
+        self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS
         self.cpu_backend = LocalCPUBackend(num_cpu_chunks=self.num_cpu_chunks)
         # The worker needs its own token processor to generate keys.
         model_name = self.vllm_config.model_config.model
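
The tpu_inference envs module that these call sites now rely on is not itself part of this diff. As a rough sketch only (the attribute names come from the diff, but the defaults and module structure below are assumptions, modeled on the lazy module-level __getattr__ pattern that vLLM's vllm/envs.py uses), the new accesses could be backed by something like:

    import os
    from typing import Any, Callable

    # Hypothetical envs module; the real tpu_inference/envs.py may differ.
    # Each entry parses one TPU_OFFLOAD_* variable into its typed value.
    environment_variables: dict[str, Callable[[], Any]] = {
        # int-typed: replaces int(os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", ...));
        # "128" is a placeholder for DEFAULT_TPU_OFFLOAD_CPU_CHUNKS.
        "TPU_OFFLOAD_NUM_CPU_CHUNKS":
        lambda: int(os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", "128")),
        # bool-typed: replaces os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0") == "1".
        "TPU_OFFLOAD_DECODE_SAVE":
        lambda: os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0") == "1",
        "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE":
        lambda: os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0") == "1",
        # "8192" stands in for get_default_kv_connector_staging_buffer_tokens().
        "TPU_OFFLOAD_STAGING_BUFFER_TOKENS":
        lambda: int(os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS", "8192")),
        # "jax" stands in for DEFAULT_HOST_HBM_SWAP_OP_TYPE.
        "TPU_OFFLOAD_SWAP_OP_TYPE":
        lambda: os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE", "jax"),
    }

    def __getattr__(name: str) -> Any:
        # PEP 562 module-level __getattr__: envs.TPU_OFFLOAD_DECODE_SAVE
        # re-reads the environment on every access and returns the typed
        # value rather than a raw string.
        if name in environment_variables:
            return environment_variables[name]()
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

Under that reading, the worker's use_bucketed_swap_ops = not envs.TPU_OFFLOAD_SKIP_JAX_PRECOMPILE preserves the old semantics: bucketed swap ops stay enabled unless TPU_OFFLOAD_SKIP_JAX_PRECOMPILE=1 is set in the environment.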