Skip to content

Commit a5ec87d

Browse files
committed
rm saving behavior
Signed-off-by: Juncheng Gu <jcgu@google.com>
1 parent 21ae0de commit a5ec87d

File tree

3 files changed

+3
-59
lines changed

3 files changed

+3
-59
lines changed

tpu_inference/distributed/offload/tpu_offload_connector.py

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,6 @@
128128
# kv cache layout needed by cpu offloading mechanism
129129
REQUIRED_KV_CACHE_LAYOUT = "NHD"
130130

131-
# default swap op type
132-
DEFAULT_HOST_HBM_SWAP_OP_TYPE = "jax"
133-
134131
BLOCK_SIZE_BUCKETS = [1, 2, 4, 8, 16]
135132

136133
# we keep our operations at vllm's block granularity,
@@ -139,9 +136,7 @@
139136
# 1. [supported] drop: drop the entire partial block
140137
# 2. pad: pad to a full block
141138
# 3. dynamic: keep the partial block as is.
142-
PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal["drop", "pad", "dynamic"]
143-
144-
DEFAULT_TPU_OFFLOAD_CPU_CHUNKS = 1024
139+
PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal["drop"]
145140

146141

147142
@dataclass
@@ -512,24 +507,7 @@ def __init__(self, vllm_config: "VllmConfig"):
512507
# real-chunk-size in save and load
513508
self.cpu_chunk_size = self.block_size
514509

515-
# TODO(jcgu): rm
516-
# define partial_block saving behavior
517-
self.partial_block_save_behavior: PARTIAL_BLOCK_SAVE_BEHAVIOR = \
518-
os.getenv("TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR", "drop")
519-
assert self.partial_block_save_behavior in get_args(
520-
PARTIAL_BLOCK_SAVE_BEHAVIOR
521-
), f"{self.partial_block_save_behavior} not in {get_args(PARTIAL_BLOCK_SAVE_BEHAVIOR)}"
522-
self.partial_block_dynamic_pad_lower_limit = \
523-
int(os.getenv("TPU_OFFLOAD_PARTIAL_BLOCK_DYNAMIC_PAD_LOWER_LIMIT", "0"))
524-
if self.partial_block_save_behavior == "dynamic":
525-
if self.partial_block_dynamic_pad_lower_limit <= 0:
526-
self.partial_block_save_behavior == "drop"
527-
elif self.partial_block_dynamic_pad_lower_limit >= self.block_size:
528-
self.partial_block_save_behavior == "pad"
529-
logger.info(
530-
f" partial_block_save_behavior is configed to {self.partial_block_save_behavior}, but we only support drop now."
531-
)
532-
self.partial_block_save_behavior = "drop"
510+
self.partial_block_save_behavior: PARTIAL_BLOCK_SAVE_BEHAVIOR = "drop"
533511

534512
# config staging buffer
535513
# NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler
@@ -547,7 +525,6 @@ def __init__(self, vllm_config: "VllmConfig"):
547525
f"model_name={model_name}, "
548526
f"decode_save={self.decode_save}, "
549527
f"partial_block_save_behavior={self.partial_block_save_behavior}, "
550-
f"partial_block_dynamic_pad_lower_limit={self.partial_block_dynamic_pad_lower_limit}, "
551528
f"num_staging_blocks={self.num_staging_blocks}.")
552529

553530
def _get_request_block_hashes(self, req: "Request") -> list[BlockHash]:
@@ -668,27 +645,6 @@ def get_num_new_matched_tokens(
668645
# external_computed_tokens, load_kv_async
669646
return num_to_load, False
670647

671-
def _adjust_last_partial_block(self,
672-
last_partial_block_num_tokens: int) -> bool:
673-
"""
674-
adjust prompt token / len based on pre-configed save behavior
675-
when the last block of request's token is partially used.
676-
In order to keep all the saved kv be aligned with block_size,
677-
we may
678-
1. drop the partial block
679-
2. pad the partial block to be a full block
680-
3. drop or pad based on actual num_tokens in the last partial block
681-
682-
Input: num of tokens in the last partial block (could be 0)
683-
Output: the last partial block should be kept (True) or dropped (False)
684-
"""
685-
if self.partial_block_save_behavior == "pad":
686-
return True if last_partial_block_num_tokens > 0 else False
687-
elif self.partial_block_save_behavior == "drop":
688-
return False
689-
elif self.partial_block_save_behavior == "dynamic":
690-
return True if last_partial_block_num_tokens >= self.partial_block_dynamic_pad_lower_limit else False
691-
692648
def update_state_after_alloc(self, request: "Request",
693649
blocks: "KVCacheBlocks",
694650
num_external_tokens: int):

tpu_inference/distributed/offload/utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525

2626
CPU_OFFLOADING_SWAP_OP_TYPE = Literal["jax", "pallas"]
2727

28-
DEFAULT_TPU_OFFLOAD_STAGING_BUFFER_TOKENS = 8192
29-
3028

3129
@dataclass(order=True)
3230
class CacheKey:
@@ -110,10 +108,6 @@ def get_kv_connector_cache_layout():
110108
return None
111109

112110

113-
def get_default_kv_connector_staging_buffer_tokens() -> int:
114-
return DEFAULT_TPU_OFFLOAD_STAGING_BUFFER_TOKENS
115-
116-
117111
SwapFn = Callable[
118112
[
119113
List[jax.Array], # src_kv_caches

tpu_inference/worker/tpu_worker.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626

2727
from tpu_inference import envs, utils
2828
from tpu_inference.distributed import jax_parallel_state
29-
from tpu_inference.distributed.offload.utils import \
30-
get_default_kv_connector_staging_buffer_tokens
3129
from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
3230
get_node_id)
3331
from tpu_inference.layers.common.sharding import ShardingConfigManager
@@ -296,11 +294,7 @@ def determine_available_memory(self) -> int:
296294
kv_transfer_config = self.vllm_config.kv_transfer_config
297295
if kv_transfer_config.kv_connector == "TPUOffloadConnector" and kv_transfer_config.kv_connector_module_path == "tpu_inference.distributed.offload.tpu_offload_connector":
298296
# If kv offloading is enabled, we need to account for the memory used by the KV transfer buffer.
299-
_default_staging_buffer_tokens = get_default_kv_connector_staging_buffer_tokens(
300-
)
301-
staging_buffer_tokens = int(
302-
os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS",
303-
str(_default_staging_buffer_tokens)))
297+
staging_buffer_tokens = envs.TPU_OFFLOAD_STAGING_BUFFER_TOKENS
304298
# calculate staging buffer size
305299
staging_buffer_pages = staging_buffer_tokens // self.vllm_config.cache_config.block_size
306300

0 commit comments

Comments
 (0)