[Feat][TPU Offload] KV cache offload to local cpu buffer #1163
Conversation
This is a really big PR and it's difficult to review. Are you planning to split it into multiple smaller PRs to make it easier for reviewers?
/lgtm
That's true. The core implementation is in the distributed/offload folder (~2,400 lines); the rest is tests. This feature is modularized and makes negligible changes to the core of tpu-inference.
Signed-off-by: Juncheng Gu <jcgu@google.com>
Signed-off-by: dannawang <dannawang@google.com>
Description
TL;DR: adds support for offloading the KV cache to a host CPU buffer (similar to vLLM's native CPU offloading).
This PR offloads computed KV cache, at block/page granularity, for prompt tokens (and optionally generated tokens) to a host CPU buffer, and swaps it back into TPU HBM on cache hits to avoid recomputation.
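Conceptually, the offload/fetch path moves a block-shaped KV array between device HBM and host memory. The following is an illustrative sketch using plain JAX; the block shape and the use of `jax.device_put` for the copies are assumptions for illustration, not the PR's actual implementation:

```python
# Hypothetical sketch (NOT the PR's code): swapping one KV cache block
# between a device and a host CPU buffer with JAX.
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
dev = jax.devices()[0]  # TPU in the real setup; falls back to CPU elsewhere

# A fake KV cache "block": [block_size, num_kv_heads, head_dim] (assumed shape)
block = jax.device_put(jnp.ones((16, 8, 128)), dev)

# Offload: copy the block into a host-side buffer.
host_copy = jax.device_put(block, cpu)

# On a cache hit, bring it back to device HBM to skip recomputation.
restored = jax.device_put(host_copy, dev)
```

In the real connector these copies are batched over many blocks and pre-compiled; the sketch only shows the direction of data movement.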
Implementation
Following the general KV connector interfaces and the logic of vLLM's native CPU offloading, this PR introduces a TPUOffloadConnector, the central component that manages and executes the offloading logic. Within the TPUOffloadConnector, there are components such as the TPUOffloadConnectorScheduler. The TPU-CPU swap operations are the core of this PR; we provide two approaches to move KV cache data.

Usage
This feature can be enabled by setting kv_transfer_config in the vLLM engine:

--kv-transfer-config '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.offload.tpu_offload_connector","kv_role":"kv_both"}'

Example: examples/offload/gke/benchmarks/deploy-cpu-offload.yaml
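Putting the flag together with the connector's environment variables, a server launch might look like the sketch below. The model name is a placeholder and the chosen values are illustrative, not recommendations:

```shell
# Sketch, not a tested recipe: enable the TPU offload connector and
# size the offload buffers via the PR's environment variables.
export TPU_OFFLOAD_NUM_CPU_CHUNKS=2048      # host buffer capacity, in KV blocks
export TPU_OFFLOAD_NUM_STAGING_BLOCKS=128   # HBM staging buffer size
export TPU_OFFLOAD_SWAP_OP_TYPE=jax         # or "pallas"

vllm serve <your-model> \
  --kv-transfer-config '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.offload.tpu_offload_connector","kv_role":"kv_both"}'
```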
It can also be configured through the following environment variables (we plan to move them into KVTransferConfig.kv_connector_extra_config):

- TPU_OFFLOAD_SKIP_JAX_PRECOMPILE: skip pre-compiling the swap functions, default=0. We recommend keeping pre-compilation on. All swap operations are applied at block granularity; with pre-compilation on, a swap request is broken into multiple swap operations following a predefined bucket-size list (1, 2, 4, 8, and 16 blocks) to avoid re-compilation (thanks to @saikat-royc).
- TPU_OFFLOAD_SWAP_OP_TYPE: jax (default) or pallas.
- TPU_OFFLOAD_NUM_CPU_CHUNKS: host CPU buffer capacity in number of chunks (a chunk is equivalent to a KV cache block/page), default=1024.
- TPU_OFFLOAD_NUM_STAGING_BLOCKS: size of the staging buffer in TPU HBM, in blocks, default=128.
- TPU_OFFLOAD_DECODE_SAVE: also save the KV cache of generated (decode) tokens, default=False.
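The bucketed decomposition used to avoid re-compilation can be sketched as a greedy split over the predefined bucket sizes. This is a hypothetical helper written for illustration; the PR's actual decomposition strategy may differ:

```python
# Hypothetical sketch (assumed helper, not the PR's code): decompose a
# swap request of N blocks into pre-compiled bucket sizes so every swap
# operation hits an already-jitted shape and avoids JAX re-compilation.
BUCKETS = (16, 8, 4, 2, 1)  # predefined bucket sizes, largest first

def split_into_buckets(num_blocks: int) -> list[int]:
    """Greedily split num_blocks into a list of bucket-sized swap ops."""
    ops = []
    for b in BUCKETS:
        while num_blocks >= b:
            ops.append(b)
            num_blocks -= b
    return ops

print(split_into_buckets(11))  # greedy split: [8, 2, 1]
```

Because 1 is in the bucket list, any request size decomposes exactly, at the cost of more (smaller) swap operations for sizes that are not a single bucket.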
Tests

pytest -s -v tests/distributed/offload/
pytest -s -v tests/kernels/host_dma_test.py