Skip to content

Commit e663031

Browse files
committed
refactor: abstract multi-tier kv cache transfer from WorkerImpl.
1 parent da8d4d8 commit e663031

18 files changed

+996
-762
lines changed

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -402,11 +402,8 @@ void WorkerService::TransferBlocks(
402402
std::vector<BlockTransferInfo> block_transfer_info;
403403
uint64_t batch_id = proto_to_block_transfer_info(*req, block_transfer_info);
404404

405-
if (batch_id == UNINITIALIZED_BATCH_ID) {
406-
resp->set_success_cnt(worker_->transfer_kv_blocks(block_transfer_info));
407-
} else {
408-
worker_->transfer_kv_blocks(batch_id, std::move(block_transfer_info));
409-
}
405+
resp->set_success_cnt(
406+
worker_->transfer_kv_blocks(batch_id, std::move(block_transfer_info)));
410407
return;
411408
}
412409

@@ -482,7 +479,8 @@ void WorkerService::PrefetchFromStorage(
482479
std::min(i + options_.prefetch_bacth_size(),
483480
transfer_slice.size()));
484481

485-
auto success_cnt = worker_->prefetch_from_storage(current_slice);
482+
auto success_cnt = worker_->transfer_kv_blocks(UNINITIALIZED_BATCH_ID,
483+
current_slice);
486484

487485
if (success_cnt != current_slice.size() ||
488486
i + options_.prefetch_bacth_size() >= transfer_slice.size()) {

xllm/core/framework/kv_cache/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library(
1515
$<$<BOOL:${USE_NPU}>:llm_data_dist_transfer.h>
1616
$<$<BOOL:${USE_NPU}>:spec_kv_cache_transfer.h>
1717
kv_cache_store.h
18+
multi_tier_kv_cache_transfer.h
1819
SRCS
1920
embedding_allocator.cpp
2021
$<$<BOOL:${USE_NPU}>:hccl_kv_cache_transfer.cpp>
@@ -23,6 +24,7 @@ cc_library(
2324
$<$<BOOL:${USE_NPU}>:llm_data_dist_transfer.cpp>
2425
$<$<BOOL:${USE_NPU}>:spec_kv_cache_transfer.cpp>
2526
kv_cache_store.cpp
27+
multi_tier_kv_cache_transfer.cpp
2628
DEPS
2729
:common
2830
$<$<BOOL:${USE_NPU}>:graph>

xllm/core/framework/kv_cache/kv_cache.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,41 @@ torch::Tensor KVCache::get_k_cache() const { return key_cache_; }
3535
torch::Tensor KVCache::get_v_cache() const { return value_cache_; }
3636
torch::Tensor KVCache::get_index_cache() const { return index_cache_; }
3737

38+
std::vector<std::vector<int64_t>> KVCache::get_shapes() {
39+
std::vector<std::vector<int64_t>> tensor_shapes(3);
40+
if (key_cache_.defined()) {
41+
std::vector<int64_t> shape;
42+
auto sizes = key_cache_.sizes();
43+
shape.resize(sizes.size());
44+
for (int i = 0; i < sizes.size(); ++i) {
45+
shape[i] = sizes[i];
46+
}
47+
tensor_shapes[0] = std::move(shape);
48+
}
49+
50+
if (value_cache_.defined() && key_cache_.numel() != 0) {
51+
std::vector<int64_t> shape;
52+
auto sizes = value_cache_.sizes();
53+
shape.resize(sizes.size());
54+
for (int i = 0; i < sizes.size(); ++i) {
55+
shape[i] = sizes[i];
56+
}
57+
tensor_shapes[1] = std::move(shape);
58+
}
59+
60+
if (index_cache_.defined() && index_cache_.numel() != 0) {
61+
std::vector<int64_t> shape;
62+
auto sizes = index_cache_.sizes();
63+
shape.resize(sizes.size());
64+
for (int i = 0; i < sizes.size(); ++i) {
65+
shape[i] = sizes[i];
66+
}
67+
tensor_shapes[2] = std::move(shape);
68+
}
69+
70+
return tensor_shapes;
71+
}
72+
3873
void KVCache::swap_blocks(torch::Tensor& src_tensor,
3974
torch::Tensor& dst_tensor) {
4075
// batch select keys and values

xllm/core/framework/kv_cache/kv_cache.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class KVCache final {
4040
torch::Tensor get_v_cache() const;
4141
torch::Tensor get_index_cache() const;
4242

43+
std::vector<std::vector<int64_t>> get_shapes();
44+
4345
std::shared_ptr<XTensor> get_k_xtensor() const;
4446
std::shared_ptr<XTensor> get_v_xtensor() const;
4547

0 commit comments

Comments
 (0)