
Commit 57082c3

feat: extend KVCache store to support MLU format with index cache.
1 parent 449e7a5 commit 57082c3

File tree

xllm/core/framework/kv_cache/kv_cache.cpp
xllm/core/framework/kv_cache/kv_cache_store.cpp
xllm/core/framework/kv_cache/kv_cache_store.h
xllm/core/framework/kv_cache/multi_tier_kv_cache_transfer.cpp
xllm/core/framework/kv_cache/multi_tier_kv_cache_transfer.h

5 files changed: +72 -57 lines changed

xllm/core/framework/kv_cache/kv_cache.cpp

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ torch::Tensor KVCache::get_index_cache() const { return index_cache_; }
 
 std::vector<std::vector<int64_t>> KVCache::get_shapes() {
   std::vector<std::vector<int64_t>> tensor_shapes(3);
-  if (key_cache_.defined()) {
+  if (key_cache_.defined() && key_cache_.numel() != 0) {
     std::vector<int64_t> shape;
     auto sizes = key_cache_.sizes();
     shape.resize(sizes.size());
@@ -47,7 +47,7 @@ std::vector<std::vector<int64_t>> KVCache::get_shapes() {
     tensor_shapes[0] = std::move(shape);
   }
 
-  if (value_cache_.defined() && key_cache_.numel() != 0) {
+  if (value_cache_.defined() && value_cache_.numel() != 0) {
     std::vector<int64_t> shape;
     auto sizes = value_cache_.sizes();
     shape.resize(sizes.size());
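
As background for the new defined() && numel() != 0 guard, a small standalone libtorch sketch (not part of this commit) showing that a defined tensor can still hold zero elements, which is the case the old defined()-only check missed:

#include <torch/torch.h>

#include <iostream>

int main() {
  // A default-constructed tensor is undefined.
  torch::Tensor undefined;
  // An empty tensor is defined but holds no elements.
  torch::Tensor empty_but_defined = torch::empty({0});

  std::cout << std::boolalpha;
  std::cout << "undefined.defined():  " << undefined.defined() << "\n";          // false
  std::cout << "empty.defined():      " << empty_but_defined.defined() << "\n";  // true
  std::cout << "empty.numel() != 0:   " << (empty_but_defined.numel() != 0)
            << "\n";                                                             // false
  return 0;
}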

xllm/core/framework/kv_cache/kv_cache_store.cpp

Lines changed: 53 additions & 39 deletions
@@ -43,16 +43,22 @@ bool KVCacheStore::init(const StoreConfig& config,
   }
   client_ptr_ = client_opt.value();
 
-  auto k_tensor_one_block = host_kv_caches_->at(0).get_k_cache();
-  auto v_tensor_one_block = host_kv_caches_->at(0).get_v_cache();
-
-  k_cache_size_per_block_ =
-      k_tensor_one_block.numel() * k_tensor_one_block.element_size();
-  v_cache_size_per_block_ =
-      v_tensor_one_block.numel() * v_tensor_one_block.element_size();
+  auto k_cache = host_kv_caches_->at(0).get_k_cache();
+  k_cache_size_per_block_ = k_cache.numel() * k_cache.element_size();
+  LOG(INFO) << "key cache size per block: " << k_cache_size_per_block_;
+
+  auto v_cache = host_kv_caches_->at(0).get_v_cache();
+  if (v_cache.defined() && v_cache.numel() != 0) {
+    v_cache_size_per_block_ = v_cache.numel() * v_cache.element_size();
+    LOG(INFO) << "value cache size per block: " << v_cache_size_per_block_;
+  }
 
-  LOG(INFO) << "k_cache_size_per_block: " << k_cache_size_per_block_;
-  LOG(INFO) << "v_cache_size_per_block: " << v_cache_size_per_block_;
+  auto index_cache = host_kv_caches_->at(0).get_index_cache();
+  if (index_cache.defined() && index_cache.numel() != 0) {
+    index_cache_size_per_block_ =
+        index_cache.numel() * index_cache.element_size();
+    LOG(INFO) << "index cache size per block: " << index_cache_size_per_block_;
+  }
 
   if (config_.protocol == "rdma") {
     if (config_.total_size > 0 && config_.tensor_data != nullptr) {
@@ -103,14 +109,28 @@ uint32_t KVCacheStore::batch_put(
 
     str_keys.emplace_back(str_key);
 
+    std::vector<mooncake::Slice> slice;
+    slice.reserve(3);
+
     void* k_cache =
         host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
-    void* v_cache =
-        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    slice.emplace_back(mooncake::Slice{k_cache, k_cache_size_per_block_});
+
+    if (v_cache_size_per_block_ != 0) {
+      void* v_cache =
+          host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+      slice.emplace_back(mooncake::Slice{v_cache, v_cache_size_per_block_});
+    }
 
-    slices.emplace_back(std::vector<mooncake::Slice>{
-        mooncake::Slice{k_cache, k_cache_size_per_block_},
-        mooncake::Slice{v_cache, v_cache_size_per_block_}});
+    if (index_cache_size_per_block_ != 0) {
+      void* index_cache = host_kv_caches_->at(block_info.dst_block_id)
+                              .get_index_cache()
+                              .data_ptr();
+      slice.emplace_back(
+          mooncake::Slice{index_cache, index_cache_size_per_block_});
+    }
+
+    slices.emplace_back(std::move(slice));
   }
 
   if (str_keys.size() == 0) {
@@ -150,16 +170,28 @@ uint32_t KVCacheStore::batch_get(
 
     str_keys.emplace_back(str_key);
 
+    std::vector<mooncake::Slice> slice;
+    slice.reserve(3);
+
     void* k_cache =
         host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
-    void* v_cache =
-        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    slice.emplace_back(mooncake::Slice{k_cache, k_cache_size_per_block_});
 
-    slices.insert(
-        std::make_pair(str_key,
-                       std::vector<mooncake::Slice>{
-                           mooncake::Slice{k_cache, k_cache_size_per_block_},
-                           mooncake::Slice{v_cache, v_cache_size_per_block_}}));
+    if (v_cache_size_per_block_ != 0) {
+      void* v_cache =
+          host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+      slice.emplace_back(mooncake::Slice{v_cache, v_cache_size_per_block_});
+    }
+
+    if (index_cache_size_per_block_ != 0) {
+      void* index_cache = host_kv_caches_->at(block_info.dst_block_id)
+                              .get_index_cache()
+                              .data_ptr();
+      slice.emplace_back(
+          mooncake::Slice{index_cache, index_cache_size_per_block_});
+    }
+
+    slices.insert(std::make_pair(str_key, std::move(slice)));
   }
 
   if (str_keys.size() == 0) {
@@ -177,24 +209,6 @@ uint32_t KVCacheStore::batch_get(
   return success_cnt;
 }
 
-uint32_t KVCacheStore::batch_remove(
-    Slice<BlockTransferInfo>& block_transfer_info) {
-  CHECK(is_initialized_) << "KVCacheStore is not initialized.";
-  uint32_t success_cnt = 0;
-  for (auto block_info : block_transfer_info) {
-    std::string str_key(reinterpret_cast<const char*>(block_info.hash_key),
-                        MURMUR_HASH3_VALUE_LEN);
-    str_key.append(std::to_string(config_.tp_rank));
-
-    auto result = client_ptr_->Remove(str_key);
-
-    if (result.has_value()) {
-      success_cnt++;
-    }
-  }
-  return success_cnt;
-}
-
 uint32_t KVCacheStore::batch_exist(std::vector<std::string>&& keys) {
   if (!is_initialized_) {
     return 0;
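
To make the new per-block layout concrete, here is a minimal, self-contained C++ sketch (not code from this commit) of the slice-assembly pattern batch_put/batch_get now follow: one slice for the key cache, plus optional slices for the value and index caches when their per-block sizes are non-zero. The local Slice struct is a stand-in for mooncake::Slice, and the buffer sizes are hypothetical placeholders.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Local stand-in for mooncake::Slice: a raw pointer plus a byte count.
struct Slice {
  void* ptr;
  size_t size;
};

int main() {
  // Hypothetical host buffers for one block; a layout like the one this
  // commit targets may leave the value cache empty and carry an index cache.
  std::vector<uint8_t> k_block(4096);
  std::vector<uint8_t> v_block;            // empty: no value cache for this layout
  std::vector<uint8_t> index_block(1024);  // index cache present

  const size_t k_size = k_block.size();
  const size_t v_size = v_block.size();
  const size_t index_size = index_block.size();

  std::vector<Slice> slices;
  slices.reserve(3);
  slices.push_back({k_block.data(), k_size});  // key cache is always stored
  if (v_size != 0) {
    slices.push_back({v_block.data(), v_size});  // value cache only if present
  }
  if (index_size != 0) {
    slices.push_back({index_block.data(), index_size});  // index cache only if present
  }

  // A real put/get would hand these slices to the store client for this block's key.
  std::cout << "slices for this block: " << slices.size() << "\n";  // prints 2
  return 0;
}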

xllm/core/framework/kv_cache/kv_cache_store.h

Lines changed: 3 additions & 4 deletions
@@ -49,8 +49,6 @@ class KVCacheStore {
 
   uint32_t batch_get(Slice<BlockTransferInfo>& block_transfer_info);
 
-  uint32_t batch_remove(Slice<BlockTransferInfo>& block_transfer_info);
-
   uint32_t batch_exist(std::vector<std::string>&& keys);
 
   static KVCacheStore& get_instance() {
@@ -71,8 +69,9 @@ class KVCacheStore {
 
   std::vector<xllm::KVCache>* host_kv_caches_;
 
-  uint64_t k_cache_size_per_block_;
-  uint64_t v_cache_size_per_block_;
+  uint64_t k_cache_size_per_block_ = 0;
+  uint64_t v_cache_size_per_block_ = 0;
+  uint64_t index_cache_size_per_block_ = 0;
 
   std::shared_ptr<mooncake::Client> client_ptr_;
 };

xllm/core/framework/kv_cache/multi_tier_kv_cache_transfer.cpp

Lines changed: 13 additions & 11 deletions
@@ -192,6 +192,7 @@ uint32_t MultiTierKVCacheTransfer::offload_kv_blocks(
 
 bool MultiTierKVCacheTransfer::d2h_batch_copy(
     Slice<BlockTransferInfo>& block_transfer_info) {
+#if defined(USE_NPU)
   const int64_t num_layers = options_.layers();
   uint32_t num_batches =
       block_transfer_info.size() * num_layers * cache_tensor_cnt_;
@@ -266,12 +267,14 @@ bool MultiTierKVCacheTransfer::d2h_batch_copy(
   delete[] dsts;
   delete[] srcs;
   delete[] copy_size;
+#endif
   return true;
 }
 
 bool MultiTierKVCacheTransfer::h2d_batch_copy(
     const uint64_t batch_id,
     Slice<BlockTransferInfo>& block_transfer_info) {
+#if defined(USE_NPU)
   CHECK(block_transfer_info.size() < BATCH_COPY_MAX_SIZE / cache_tensor_cnt_)
       << "h2d_batch_copy support copy blocks less than "
       << BATCH_COPY_MAX_SIZE / cache_tensor_cnt_ << ", but got "
@@ -353,16 +356,15 @@ bool MultiTierKVCacheTransfer::h2d_batch_copy(
     layer_cnt++;
   }
 
-  ret = aclrtMemcpyBatchAsync(dsts,
-                              copy_size,
-                              srcs,
-                              copy_size,
-                              num_batches * layer_cnt,
-                              attrs,
-                              attrs_indexes,
-                              1,
-                              &fail_index,
-                              stream->get_stream()->stream());
+  ret = aclrtMemcpyBatch(dsts,
+                         copy_size,
+                         srcs,
+                         copy_size,
+                         num_batches * layer_cnt,
+                         attrs,
+                         attrs_indexes,
+                         1,
+                         &fail_index);
 
   if (ret != 0 || fail_index != SIZE_MAX) {
     LOG(ERROR) << "aclrtMemcpyBatch error: " << ret
@@ -390,7 +392,7 @@ bool MultiTierKVCacheTransfer::h2d_batch_copy(
   delete[] dsts;
   delete[] srcs;
   delete[] copy_size;
-
+#endif
   return true;
 }
 

xllm/core/framework/kv_cache/multi_tier_kv_cache_transfer.h

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,6 @@ limitations under the License.
 
 #include <memory>
 
-#include "acl/acl_rt.h"
 #include "common/types.h"
 #include "framework/kv_cache/kv_cache_store.h"
 #include "framework/model/model_input_params.h"
@@ -28,6 +27,7 @@ limitations under the License.
 #include "util/threadpool.h"
 
 #if defined(USE_NPU)
+#include "acl/acl_rt.h"
 #include "platform/npu/npu_layer_synchronizer.h"
 #endif
 

0 commit comments