Commit df3b091

staging_tokens --> staging_blocks

Size the KV-offload staging buffer in blocks instead of tokens: TPU_OFFLOAD_STAGING_BUFFER_TOKENS is replaced by TPU_OFFLOAD_NUM_STAGING_BLOCKS across the connector, the env registry, the worker, and the scheduler tests.
Signed-off-by: Juncheng Gu <jcgu@google.com>
1 parent 0fc7dad commit df3b091

File tree

tests/distributed/offload/tpu_offload_connector_scheduler_test.py
tpu_inference/distributed/offload/tpu_offload_connector.py
tpu_inference/envs.py
tpu_inference/worker/tpu_worker.py

4 files changed: +26 −35 lines changed

tests/distributed/offload/tpu_offload_connector_scheduler_test.py

Lines changed: 9 additions & 11 deletions
@@ -61,15 +61,15 @@ def scheduler_factory():
     def _scheduler(
         block_size: int = _DEFAULT_BLOCK_SIZE,
         offload_decode_save: int = 0,
-        offload_staging_buffer_tokens: int = -1,
+        offload_num_staging_blocks: int = -1,
         offload_num_cpu_chunks: int = -1,
     ):
         # update config
         vllm_config = MockVllmConfig(block_size=block_size)
         os.environ["TPU_OFFLOAD_DECODE_SAVE"] = str(offload_decode_save)
-        if offload_staging_buffer_tokens >= 0:
-            os.environ["TPU_OFFLOAD_STAGING_BUFFER_TOKENS"] = str(
-                offload_staging_buffer_tokens)
+        if offload_num_staging_blocks >= 0:
+            os.environ["TPU_OFFLOAD_NUM_STAGING_BLOCKS"] = str(
+                offload_num_staging_blocks)
         if offload_num_cpu_chunks > 0:
             os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str(
                 offload_num_cpu_chunks)
@@ -111,9 +111,8 @@ def test_get_num_new_matched_tokens_hit(self, scheduler_factory,
         5. skip 1 block + full-hit + only 1 staging block
         6. skip 1 block + full-hit + no staging block
         """
-        num_staging_tokens = num_staging_blocks * _DEFAULT_BLOCK_SIZE
         scheduler = scheduler_factory(
-            offload_staging_buffer_tokens=num_staging_tokens)
+            offload_num_staging_blocks=num_staging_blocks)
         prompt_len = scheduler.block_size * num_prompt_blocks
         num_computed_tokens = scheduler.block_size * num_computed_blocks
         num_blocks_to_load = num_matched_blocks - num_computed_blocks
@@ -231,7 +230,7 @@ def test_build_connector_meta_new_prefill(self, scheduler_factory,
         """
         num_staging_blocks = num_staging_tokens // _DEFAULT_BLOCK_SIZE
         scheduler = scheduler_factory(
-            offload_staging_buffer_tokens=num_staging_tokens,
+            offload_num_staging_blocks=num_staging_blocks,
             offload_num_cpu_chunks=100)

         # calculate the groundtruth
@@ -347,10 +346,9 @@ def test_build_connector_meta_decode_with_save(self, scheduler_factory,
         2. the N-th decode (hits block boundary) + not decode_save (no save)
         """

-        scheduler = scheduler_factory(
-            offload_decode_save=decode_save,
-            offload_staging_buffer_tokens=_DEFAULT_BLOCK_SIZE * 10,
-            offload_num_cpu_chunks=10)
+        scheduler = scheduler_factory(offload_decode_save=decode_save,
+                                      offload_num_staging_blocks=10,
+                                      offload_num_cpu_chunks=10)

         prompt_tokens = list(range(prompt_len))
         generated_tokens = list(range(prompt_len, seq_len))
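With these changes, the fixture is parametrized directly in staging blocks, removing the token-to-block division from each test. A minimal sketch of a test written against the new parameter (illustrative; the `num_staging_blocks` attribute on the scheduler is an assumption based on the connector's `__init__` shown below):

    # Illustrative test using the renamed fixture parameter.
    def test_staging_buffer_in_blocks(scheduler_factory):
        scheduler = scheduler_factory(offload_num_staging_blocks=10,
                                      offload_num_cpu_chunks=10)
        # The factory exports TPU_OFFLOAD_NUM_STAGING_BLOCKS=10 before
        # building the scheduler, so no block-size arithmetic is needed here.
        assert scheduler.num_staging_blocks == 10  # assumed attribute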

tpu_inference/distributed/offload/tpu_offload_connector.py

Lines changed: 7 additions & 12 deletions
@@ -511,9 +511,8 @@ def __init__(self, vllm_config: "VllmConfig"):

         # config staging buffer
         # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler
-        # otherwise, we can only use # of tokens as input, instead of buffer size in GB
-        num_staging_buffer_tokens = envs.TPU_OFFLOAD_STAGING_BUFFER_TOKENS
-        self.num_staging_blocks = num_staging_buffer_tokens // self.block_size
+        # otherwise, we can only use # of blocks as input, instead of buffer size in GB
+        self.num_staging_blocks = envs.TPU_OFFLOAD_NUM_STAGING_BLOCKS
         self.staging_buffer_manager = StagingBufferManager(
             num_blocks=self.num_staging_blocks)
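As the NOTE explains, `page_size_bytes` is not reachable from the scheduler, so the buffer must be sized in blocks rather than GB. If it ever becomes reachable, the conversion would be a one-liner; a sketch under that assumption (the 4 GB budget and the `page_size_bytes` value are hypothetical):

    # Hypothetical: derive the block budget from a byte budget.
    buffer_size_gb = 4                                  # desired staging buffer size
    page_size_bytes = 2 * 1024 * 1024                   # illustrative per-block KV bytes
    num_staging_blocks = (buffer_size_gb * 1024**3) // page_size_bytes  # 2048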

@@ -698,19 +697,15 @@ def _prepare_req_meta(
         block_hashes = self._get_request_block_hashes(_request)
         self.offload_manager.touch(block_hashes)

-        # only consider the tokens covered by block_hashes
+        # only consider the tokens covered by block_hashes;
+        # currently full blocks only
         num_total_blocks = len(block_hashes)
         num_total_tokens = min(num_total_blocks * self.block_size,
                                len(tracker.token_ids))
         num_full_blocks = num_total_tokens // self.block_size
-        num_full_blocks_tokens = num_full_blocks * self.block_size
-        # adjust last partial block
-        last_partial_block_num_tokens = num_total_tokens - num_full_blocks_tokens
-        need_last_block = self._adjust_last_partial_block(
-            last_partial_block_num_tokens)
-        adjusted_num_total_tokens = num_total_tokens if need_last_block else num_full_blocks_tokens
-        adjusted_num_total_blocks = num_full_blocks + (1 if need_last_block
-                                                       else 0)
+        num_full_block_tokens = num_full_blocks * self.block_size
+        adjusted_num_total_tokens = num_full_block_tokens
+        adjusted_num_total_blocks = num_full_blocks
         assert adjusted_num_total_blocks <= len(tracker.block_ids)

         has_new_tokens = adjusted_num_total_tokens > tracker.save_watermark
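The replacement removes the partial-block path entirely: only fully filled blocks are saved, and a trailing partial block is ignored until it fills. A worked example of the new arithmetic (all values illustrative):

    block_size = 16                       # illustrative
    num_total_blocks = 7                  # len(block_hashes)
    num_tracker_tokens = 100              # len(tracker.token_ids)

    num_total_tokens = min(num_total_blocks * block_size, num_tracker_tokens)  # min(112, 100) = 100
    num_full_blocks = num_total_tokens // block_size          # 100 // 16 = 6
    adjusted_num_total_tokens = num_full_blocks * block_size  # 96; the 4-token tail is skipped
    adjusted_num_total_blocks = num_full_blocks               # 6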

tpu_inference/envs.py

Lines changed: 9 additions & 9 deletions
@@ -28,7 +28,7 @@
 TPU_OFFLOAD_SWAP_OP_TYPE: str = "jax"
 TPU_OFFLOAD_DECODE_SAVE: bool = False
 TPU_OFFLOAD_NUM_CPU_CHUNKS: int = 1024
-TPU_OFFLOAD_STAGING_BUFFER_TOKENS: int = 8192
+TPU_OFFLOAD_NUM_STAGING_BLOCKS: int = 128


 def env_with_choices(

@@ -127,21 +127,21 @@ def _get_validated_env() -> str | None:
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
     env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
-    # kv offload to dram: save kv in the decode phase
-    "TPU_OFFLOAD_DECODE_SAVE":
-    lambda: bool(int(os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0"))),
+    # kv offload to dram: skip pre-compiling swap-related jax functions
+    "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE":
+    lambda: bool(int(os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0"))),
     # kv offload to dram: swap function type: jax, or pallas
     "TPU_OFFLOAD_SWAP_OP_TYPE":
     lambda: os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE", "jax"),
+    # kv offload to dram: save kv in the decode phase
+    "TPU_OFFLOAD_DECODE_SAVE":
+    lambda: bool(int(os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0"))),
     # kv offload to dram: dram space size in # of chunks / blocks
     "TPU_OFFLOAD_NUM_CPU_CHUNKS":
     lambda: int(os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", "1024")),
-    # kv offload to dram: dram space size in # of chunks / blocks
-    "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE":
-    lambda: bool(int(os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0"))),
     # kv offload to dram: size of staging buffer (hbm) for swap
-    "TPU_OFFLOAD_STAGING_BUFFER_TOKENS":
-    lambda: int(os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS", "16384")),
+    "TPU_OFFLOAD_NUM_STAGING_BLOCKS":
+    lambda: int(os.getenv("TPU_OFFLOAD_NUM_STAGING_BLOCKS", "128")),
 }
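A minimal usage sketch for the renamed knob (assuming `tpu_inference.envs` resolves attribute access through the registry above, in the usual vLLM-style envs pattern):

    import os

    # Override the 128-block default before the connector is constructed.
    os.environ["TPU_OFFLOAD_NUM_STAGING_BLOCKS"] = "256"

    from tpu_inference import envs
    print(envs.TPU_OFFLOAD_NUM_STAGING_BLOCKS)  # 256, parsed to int by the lambda above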

tpu_inference/worker/tpu_worker.py

Lines changed: 1 addition & 3 deletions
@@ -294,9 +294,7 @@ def determine_available_memory(self) -> int:
         kv_transfer_config = self.vllm_config.kv_transfer_config
         if kv_transfer_config.kv_connector == "TPUOffloadConnector" and kv_transfer_config.kv_connector_module_path == "tpu_inference.distributed.offload.tpu_offload_connector":
             # If kv offloading is enabled, we need to account for the memory used by the KV transfer buffer.
-            staging_buffer_tokens = envs.TPU_OFFLOAD_STAGING_BUFFER_TOKENS
-            # calculate staging buffer size
-            staging_buffer_pages = staging_buffer_tokens // self.vllm_config.cache_config.block_size
+            staging_buffer_pages = envs.TPU_OFFLOAD_NUM_STAGING_BLOCKS

         kv_cache_specs = self.model_runner.get_kv_cache_spec()
         num_layers = len(kv_cache_specs)
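The visible context suggests the worker then converts the page count into bytes across all KV-cache layers and deducts it from available HBM; the exact code continues past this hunk. A hedged sketch of that accounting (every name other than `staging_buffer_pages`, `kv_cache_specs`, and `num_layers` is an assumption, and the per-page byte count is illustrative):

    # Inferred shape of the deduction; not the literal implementation.
    staging_buffer_pages = 128                  # envs.TPU_OFFLOAD_NUM_STAGING_BLOCKS
    page_size_bytes = 2 * 16 * 8 * 128 * 2      # K/V * block_size * heads * head_dim * bf16 (illustrative)
    num_layers = 32                             # len(kv_cache_specs), illustrative
    staging_buffer_bytes = staging_buffer_pages * page_size_bytes * num_layers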
