Commit 21ae0de

offload envs
Signed-off-by: Juncheng Gu <jcgu@google.com>
1 parent cd5cce2 commit 21ae0de

File tree: 5 files changed (+29, -31 lines)

tests/distributed/offload/tpu_offload_connector_scheduler_test.py

Lines changed: 0 additions & 7 deletions
@@ -62,18 +62,12 @@ def scheduler_factory():
     def _scheduler(
         block_size: int = _DEFAULT_BLOCK_SIZE,
         offload_decode_save: int = 0,
-        offload_partial_block_save_behavior: str = "drop",
-        offload_partial_block_dynamic_pad_lower_limit: int = 0,
         offload_staging_buffer_tokens: int = -1,
         offload_num_cpu_chunks: int = DEFAULT_TPU_OFFLOAD_CPU_CHUNKS,
     ):
         # update config
         vllm_config = MockVllmConfig(block_size=block_size)
         os.environ["TPU_OFFLOAD_DECODE_SAVE"] = str(offload_decode_save)
-        os.environ[
-            "TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR"] = offload_partial_block_save_behavior
-        os.environ["TPU_OFFLOAD_PARTIAL_BLOCK_DYNAMIC_PAD_LOWER_LIMIT"] = str(
-            offload_partial_block_dynamic_pad_lower_limit)
         if offload_staging_buffer_tokens >= 0:
             os.environ["TPU_OFFLOAD_STAGING_BUFFER_TOKENS"] = str(
                 offload_staging_buffer_tokens)
@@ -238,7 +232,6 @@ def test_build_connector_meta_new_prefill(self, scheduler_factory,
         """
         num_staging_blocks = num_staging_tokens // _DEFAULT_BLOCK_SIZE
         scheduler = scheduler_factory(
-            offload_partial_block_save_behavior="drop",
             offload_staging_buffer_tokens=num_staging_tokens,
             offload_num_cpu_chunks=100)
 
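Since the factory still injects settings through os.environ, the new envs module has to pick them up at access time rather than at import time. A minimal sketch of what these tests rely on, assuming the lazy read-on-access semantics of the registry shown at the bottom of this diff:

import os

from tpu_inference import envs

# Each envs.* access re-parses the environment, so setting os.environ
# before constructing the scheduler is sufficient in these tests.
os.environ["TPU_OFFLOAD_DECODE_SAVE"] = "1"
assert envs.TPU_OFFLOAD_DECODE_SAVE is True  # "1" parsed to bool

os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = "256"
assert envs.TPU_OFFLOAD_NUM_CPU_CHUNKS == 256  # "256" parsed to int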

tpu_inference/distributed/offload/cpu_backend.py

Lines changed: 0 additions & 3 deletions
@@ -10,9 +10,6 @@
 
 logger = init_logger(__name__)
 
-GB = 1024**3
-DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB
-
 
 class LocalCPUBackend:
     """

tpu_inference/distributed/offload/offload_manager.py

Lines changed: 0 additions & 3 deletions
@@ -12,9 +12,6 @@
 
 logger = init_logger(__name__)
 
-GB = 1024**3
-DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB
-
 ChunkHash = BlockHash
 
 
tpu_inference/distributed/offload/tpu_offload_connector.py

Lines changed: 9 additions & 18 deletions
@@ -112,13 +112,13 @@
 from vllm.v1.request import Request
 from vllm.forward_context import ForwardContext
 
+from tpu_inference import envs
 from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend
 from tpu_inference.distributed.offload.offload_manager import (
     LRUCacheManager, StagingBufferManager)
 from tpu_inference.distributed.offload.utils import (
     CPU_OFFLOADING_SWAP_OP_TYPE, CpuChunkId, KVCacheSwapFn, ReqId,
-    TokenProcessor, get_default_kv_connector_staging_buffer_tokens,
-    get_kv_cache_swap_fn, jitted_insert_kv_cache_slices)
+    TokenProcessor, get_kv_cache_swap_fn, jitted_insert_kv_cache_slices)
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.kv_cache_manager import KVCacheManager
 from tpu_inference.runner.tpu_runner import TPUModelRunner
@@ -480,9 +480,7 @@ def __init__(self, vllm_config: "VllmConfig"):
         self.block_size = vllm_config.cache_config.block_size
 
         # offloading manager
-        self.num_cpu_chunks = int(
-            os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS",
-                      str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS)))
+        self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS
         self.offload_manager = LRUCacheManager(
             num_cpu_chunks=self.num_cpu_chunks)
 
@@ -506,14 +504,15 @@ def __init__(self, vllm_config: "VllmConfig"):
         self.token_processor = TokenProcessor(model_name=model_name,
                                               chunk_size=self.block_size)
 
-        self.decode_save = os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0") == "1"
+        self.decode_save = envs.TPU_OFFLOAD_DECODE_SAVE
         # NOTE(jcgu): currently, let's make chunk_size == block_size
         # chunk_size == n * block_size lead to
         # 1. multi-size chunks
         # 2. complicated resize (split, concatenate) operations due to
         # real-chunk-size in save and load
         self.cpu_chunk_size = self.block_size
 
+        # TODO(jcgu): rm
         # define partial_block saving behavior
         self.partial_block_save_behavior: PARTIAL_BLOCK_SAVE_BEHAVIOR = \
             os.getenv("TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR", "drop")
@@ -535,11 +534,7 @@ def __init__(self, vllm_config: "VllmConfig"):
         # config staging buffer
         # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler
         # otherwise, we can only use # of tokens as input, instead of buffer size in GB
-        _default_staging_buffer_tokens = get_default_kv_connector_staging_buffer_tokens(
-        )
-        num_staging_buffer_tokens = int(
-            os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS",
-                      str(_default_staging_buffer_tokens)))
+        num_staging_buffer_tokens = envs.TPU_OFFLOAD_STAGING_BUFFER_TOKENS
         self.num_staging_blocks = num_staging_buffer_tokens // self.block_size
         self.staging_buffer_manager = StagingBufferManager(
             num_blocks=self.num_staging_blocks)
@@ -1214,15 +1209,13 @@ def __init__(self, vllm_config: VllmConfig,
 
         self.runner: Optional[TPUModelRunner] = None
         self.mesh: Optional[Mesh] = None
-        self.swap_op_type = os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE",
-                                      default=DEFAULT_HOST_HBM_SWAP_OP_TYPE)
+        self.swap_op_type = envs.TPU_OFFLOAD_SWAP_OP_TYPE
         assert self.swap_op_type in get_args(CPU_OFFLOADING_SWAP_OP_TYPE)
         # TODO(jcgu): check libtpu compatibility for pallas dma kernel
         logger.info(
             f"(cpu offloading) swap operation type is {self.swap_op_type}")
 
-        self.use_bucketed_swap_ops = os.getenv(
-            "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0") == "0"
+        self.use_bucketed_swap_ops = not envs.TPU_OFFLOAD_SKIP_JAX_PRECOMPILE
         logger.info(
             f"(cpu offloading) use_bucketed_swap_ops={self.use_bucketed_swap_ops}"
         )
@@ -1231,9 +1224,7 @@ def __init__(self, vllm_config: VllmConfig,
         self.swap_out_fn: KVCacheSwapFn = None
 
         # cpu cache
-        self.num_cpu_chunks = int(
-            os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS",
-                      str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS)))
+        self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS
         self.cpu_backend = LocalCPUBackend(num_cpu_chunks=self.num_cpu_chunks)
         # The worker needs its own token processor to generate keys.
         model_name = self.vllm_config.model_config.model
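Net effect in the connector: each raw os.getenv read with inline parsing collapses into a typed attribute access on envs. A before/after sketch using names from the diff above; the access semantics follow from the registry shown in the next file:

import os

from tpu_inference import envs

# Before: every call site parsed strings and carried its own default.
decode_save = os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0") == "1"

# After: the registry owns the default and the type conversion.
decode_save = envs.TPU_OFFLOAD_DECODE_SAVE  # already a bool

# The flag is inverted at the call site: precompilation is skipped only
# when the env var opts out, so bucketed swap ops stay on by default.
use_bucketed_swap_ops = not envs.TPU_OFFLOAD_SKIP_JAX_PRECOMPILE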

tpu_inference/envs.py

Lines changed: 20 additions & 0 deletions
@@ -24,6 +24,11 @@
 NUM_SLICES: int = 1
 RAY_USAGE_STATS_ENABLED: str = "0"
 VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
+TPU_OFFLOAD_SKIP_JAX_PRECOMPILE: bool = False
+TPU_OFFLOAD_SWAP_OP_TYPE: str = "jax"
+TPU_OFFLOAD_DECODE_SAVE: bool = False
+TPU_OFFLOAD_NUM_CPU_CHUNKS: int = 1024
+TPU_OFFLOAD_STAGING_BUFFER_TOKENS: int = 8192
 
 
 def env_with_choices(
@@ -122,6 +127,21 @@ def _get_validated_env() -> str | None:
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
     env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
+    # kv offload to dram: save kv in the decode phase
+    "TPU_OFFLOAD_DECODE_SAVE":
+    lambda: bool(int(os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0"))),
+    # kv offload to dram: swap function type: jax, or pallas
+    "TPU_OFFLOAD_SWAP_OP_TYPE":
+    lambda: os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE", "jax"),
+    # kv offload to dram: dram space size in # of chunks / blocks
+    "TPU_OFFLOAD_NUM_CPU_CHUNKS":
+    lambda: int(os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", "1024")),
+    # kv offload to dram: skip jax precompilation of the swap ops
+    "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE":
+    lambda: bool(int(os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0"))),
+    # kv offload to dram: size of the hbm staging buffer for swaps, in tokens
+    "TPU_OFFLOAD_STAGING_BUFFER_TOKENS":
+    lambda: int(os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS", "16384")),
 }
 
 
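For the envs.TPU_OFFLOAD_* attribute reads in the connector to hit these lambdas, the module presumably resolves attributes through the dict at access time, in the style of vLLM's envs.py. A hedged sketch of that mechanism; the environment_variables dict name and the __getattr__ hook are assumptions, not shown in this diff:

import os
from typing import Any, Callable

# Assumed registry shape, mirroring the entries added above.
environment_variables: dict[str, Callable[[], Any]] = {
    "TPU_OFFLOAD_NUM_CPU_CHUNKS":
    lambda: int(os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", "1024")),
}


def __getattr__(name: str) -> Any:
    # Module-level __getattr__ (PEP 562) runs on every `envs.NAME`
    # access, so values track os.environ lazily instead of being
    # frozen at import time.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")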