128128# kv cache layout needed by cpu offloading mechanism
129129REQUIRED_KV_CACHE_LAYOUT = "NHD"
130130
131- # default swap op type
132- DEFAULT_HOST_HBM_SWAP_OP_TYPE = "jax"
133-
134131BLOCK_SIZE_BUCKETS = [1 , 2 , 4 , 8 , 16 ]
135132
136133# we keep our operations at vllm's block granularity,
139136# 1. [supported] drop: drop the entire partial block
140137# 2. pad: pad to a full block
141138# 3. dynamic: keep the partial block as is.
142- PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal ["drop" , "pad" , "dynamic" ]
143-
144- DEFAULT_TPU_OFFLOAD_CPU_CHUNKS = 1024
139+ PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal ["drop" ]
145140
146141
147142@dataclass
@@ -512,24 +507,7 @@ def __init__(self, vllm_config: "VllmConfig"):
512507 # real-chunk-size in save and load
513508 self .cpu_chunk_size = self .block_size
514509
515- # TODO(jcgu): rm
516- # define partial_block saving behavior
517- self .partial_block_save_behavior : PARTIAL_BLOCK_SAVE_BEHAVIOR = \
518- os .getenv ("TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR" , "drop" )
519- assert self .partial_block_save_behavior in get_args (
520- PARTIAL_BLOCK_SAVE_BEHAVIOR
521- ), f"{ self .partial_block_save_behavior } not in { get_args (PARTIAL_BLOCK_SAVE_BEHAVIOR )} "
522- self .partial_block_dynamic_pad_lower_limit = \
523- int (os .getenv ("TPU_OFFLOAD_PARTIAL_BLOCK_DYNAMIC_PAD_LOWER_LIMIT" , "0" ))
524- if self .partial_block_save_behavior == "dynamic" :
525- if self .partial_block_dynamic_pad_lower_limit <= 0 :
526- self .partial_block_save_behavior == "drop"
527- elif self .partial_block_dynamic_pad_lower_limit >= self .block_size :
528- self .partial_block_save_behavior == "pad"
529- logger .info (
530- f" partial_block_save_behavior is configed to { self .partial_block_save_behavior } , but we only support drop now."
531- )
532- self .partial_block_save_behavior = "drop"
510+ self .partial_block_save_behavior : PARTIAL_BLOCK_SAVE_BEHAVIOR = "drop"
533511
534512 # config staging buffer
535513 # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler
@@ -547,7 +525,6 @@ def __init__(self, vllm_config: "VllmConfig"):
547525 f"model_name={ model_name } , "
548526 f"decode_save={ self .decode_save } , "
549527 f"partial_block_save_behavior={ self .partial_block_save_behavior } , "
550- f"partial_block_dynamic_pad_lower_limit={ self .partial_block_dynamic_pad_lower_limit } , "
551528 f"num_staging_blocks={ self .num_staging_blocks } ." )
552529
553530 def _get_request_block_hashes (self , req : "Request" ) -> list [BlockHash ]:
@@ -668,27 +645,6 @@ def get_num_new_matched_tokens(
668645 # external_computed_tokens, load_kv_async
669646 return num_to_load , False
670647
671- def _adjust_last_partial_block (self ,
672- last_partial_block_num_tokens : int ) -> bool :
673- """
674- adjust prompt token / len based on pre-configed save behavior
675- when the last block of request's token is partially used.
676- In order to keep all the saved kv be aligned with block_size,
677- we may
678- 1. drop the partial block
679- 2. pad the partial block to be a full block
680- 3. drop or pad based on actual num_tokens in the last partial block
681-
682- Input: num of tokens in the last partial block (could be 0)
683- Output: the last partial block should be kept (True) or dropped (False)
684- """
685- if self .partial_block_save_behavior == "pad" :
686- return True if last_partial_block_num_tokens > 0 else False
687- elif self .partial_block_save_behavior == "drop" :
688- return False
689- elif self .partial_block_save_behavior == "dynamic" :
690- return True if last_partial_block_num_tokens >= self .partial_block_dynamic_pad_lower_limit else False
691-
692648 def update_state_after_alloc (self , request : "Request" ,
693649 blocks : "KVCacheBlocks" ,
694650 num_external_tokens : int ):
0 commit comments