Skip to content

Commit f143701

Browse files
committed
feat: extend KVCache store to support MLU format with index cache.
1 parent 8dfb319 commit f143701

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+457
-525
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -164,16 +164,6 @@ DEFINE_int32(
164164
256,
165165
"Max decode token per sequence which used for ZeroEvictionScheduler.");
166166

167-
DEFINE_uint32(prefetch_timeout,
168-
0,
169-
"Prefetch timeout for prefetch from kv cache store.");
170-
171-
DEFINE_uint32(prefetch_bacth_size,
172-
2,
173-
"Prefetch from kvcache store copy batch size.");
174-
175-
DEFINE_uint32(layers_wise_copy_batchs, 4, "Layer wise H2D copy batchs.");
176-
177167
// --- parallel config ---
178168

179169
DEFINE_int32(dp_size, 1, "Data parallel size for MLA attention.");
@@ -341,6 +331,16 @@ DEFINE_bool(enable_online_preempt_offline,
341331

342332
// --- kvcache store config ---
343333

334+
DEFINE_uint32(prefetch_timeout,
335+
0,
336+
"Prefetch timeout for prefetch from kv cache store.");
337+
338+
DEFINE_uint32(prefetch_bacth_size,
339+
2,
340+
"Prefetch from kvcache store copy batch size.");
341+
342+
DEFINE_uint32(layers_wise_copy_batchs, 4, "Layer wise H2D copy batchs.");
343+
344344
DEFINE_double(host_blocks_factor,
345345
0.0,
346346
"Host block factor, e.g. host block num = host_blocks_factor * "

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -372,14 +372,14 @@ void CommChannel::transfer_kv_blocks(
372372

373373
class ClientStreamReceiver : public brpc::StreamInputHandler {
374374
private:
375-
std::atomic<bool>* termination_flag_;
375+
std::shared_ptr<std::atomic<bool>> termination_flag_;
376376
std::shared_ptr<std::atomic<uint32_t>> success_cnt_;
377377
std::promise<void> close_promise_;
378378
std::atomic<bool> promise_set_{false};
379379

380380
public:
381-
ClientStreamReceiver(std::atomic<bool>* termination_flag,
382-
std::shared_ptr<std::atomic<uint32_t>>& success_cnt)
381+
ClientStreamReceiver(std::shared_ptr<std::atomic<bool>> termination_flag,
382+
std::shared_ptr<std::atomic<uint32_t>> success_cnt)
383383
: termination_flag_(termination_flag), success_cnt_(success_cnt) {}
384384

385385
~ClientStreamReceiver() {
@@ -427,8 +427,8 @@ class ClientStreamReceiver : public brpc::StreamInputHandler {
427427

428428
void CommChannel::prefetch_from_storage(
429429
const std::vector<BlockTransferInfo>& block_transfer_info,
430-
std::atomic<bool>* flag,
431-
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
430+
std::shared_ptr<std::atomic<bool>> flag,
431+
std::shared_ptr<std::atomic<uint32_t>> success_cnt) {
432432
proto::BlockTransferInfos pb_block_transfer_info;
433433
if (!block_transfer_info_to_proto(block_transfer_info,
434434
&pb_block_transfer_info)) {

xllm/core/distributed_runtime/comm_channel.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -99,8 +99,8 @@ class CommChannel {
9999

100100
virtual void prefetch_from_storage(
101101
const std::vector<BlockTransferInfo>& block_transfer_info,
102-
std::atomic<bool>* flag,
103-
std::shared_ptr<std::atomic<uint32_t>>& success_cnt);
102+
std::shared_ptr<std::atomic<bool>> flag,
103+
std::shared_ptr<std::atomic<uint32_t>> success_cnt);
104104

105105
virtual bool get_last_step_result_async(
106106
folly::Promise<std::optional<RawForwardOutput>>& promise);

xllm/core/distributed_runtime/remote_worker.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -314,12 +314,12 @@ void RemoteWorker::transfer_kv_blocks(
314314

315315
void RemoteWorker::prefetch_from_storage(
316316
const std::vector<BlockTransferInfo>& block_transfer_info,
317-
std::atomic<bool>* flag,
318-
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
317+
std::shared_ptr<std::atomic<bool>> flag,
318+
std::shared_ptr<std::atomic<uint32_t>> success_cnt) {
319319
copy_threadpool_.schedule(
320320
[this,
321-
flag = flag,
322321
block_transfer_info = std::move(block_transfer_info),
322+
flag = flag,
323323
success_cnt = success_cnt]() mutable {
324324
channel_->prefetch_from_storage(block_transfer_info, flag, success_cnt);
325325
});

xllm/core/distributed_runtime/remote_worker.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -121,8 +121,8 @@ class RemoteWorker : public WorkerClient {
121121

122122
virtual void prefetch_from_storage(
123123
const std::vector<BlockTransferInfo>& block_transfer_info,
124-
std::atomic<bool>* flag,
125-
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) override;
124+
std::shared_ptr<std::atomic<bool>> flag,
125+
std::shared_ptr<std::atomic<uint32_t>> success_cnt) override;
126126

127127
// Run the model and return the output.
128128
virtual folly::SemiFuture<std::optional<ForwardOutput>> step_async(

xllm/core/framework/block/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,13 +11,13 @@ cc_library(
1111
block_manager_pool.h
1212
block_manager_impl.h
1313
concurrent_block_manager_impl.h
14-
multi_tier_block_manager_pool.h
14+
hierarchy_block_manager_pool.h
1515
SRCS
1616
block.cpp
1717
block_manager_pool.cpp
1818
concurrent_block_manager_impl.cpp
1919
block_manager_impl.cpp
20-
multi_tier_block_manager_pool.cpp
20+
hierarchy_block_manager_pool.cpp
2121
DEPS
2222
$<$<BOOL:${USE_NPU}>:torch_npu>
2323
$<$<BOOL:${USE_NPU}>:graph>

xllm/core/framework/block/block_manager_pool.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ class BlockManagerPool : public KVCacheManager {
7676
int32_t get_manager_with_max_free_blocks() const;
7777
int32_t get_dp_rank(Sequence* sequence) const;
7878

79-
void process_beam_search(Sequence* sequence, bool need_swap = false);
79+
bool process_beam_search(Sequence* sequence, bool need_swap = false);
8080

8181
private:
8282
std::vector<std::vector<BlockTransferInfo>> swap_block_transfer_infos_;

xllm/core/framework/block/multi_tier_block_manager_pool.cpp renamed to xllm/core/framework/block/hierarchy_block_manager_pool.cpp

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "multi_tier_block_manager_pool.h"
16+
#include "hierarchy_block_manager_pool.h"
1717

1818
#include "block_manager_impl.h"
1919
#include "concurrent_block_manager_impl.h"
2020

2121
namespace xllm {
2222

23-
MultiTierBlockManagerPool::MultiTierBlockManagerPool(
23+
HierarchyBlockManagerPool::HierarchyBlockManagerPool(
2424
const BlockManagerPool::Options& options,
2525
Engine* engine,
2626
int32_t dp_size)
@@ -52,7 +52,7 @@ MultiTierBlockManagerPool::MultiTierBlockManagerPool(
5252
saved_device_blocks_.resize(host_block_managers_.size());
5353
}
5454

55-
void MultiTierBlockManagerPool::deallocate(Sequence* sequence) {
55+
void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
5656
DCHECK(sequence != nullptr);
5757
// add blocks to the prefix cache
5858
int32_t dp_rank = BlockManagerPool::get_dp_rank(sequence);
@@ -65,7 +65,7 @@ void MultiTierBlockManagerPool::deallocate(Sequence* sequence) {
6565
return;
6666
}
6767

68-
int cached_block_num =
68+
size_t cached_block_num =
6969
sequence->host_kv_state().kv_cache_tokens_num() / options_.block_size();
7070

7171
if (host_blocks->size() > 0) {
@@ -82,7 +82,7 @@ void MultiTierBlockManagerPool::deallocate(Sequence* sequence) {
8282
sequence->host_kv_state().add_kv_blocks(
8383
host_block_managers_[dp_rank]->allocate(needed_block_num));
8484

85-
for (int i = cached_block_num; i < host_blocks->size(); i++) {
85+
for (size_t i = cached_block_num; i < host_blocks->size(); i++) {
8686
if (blocks->at(i).ref_count() != 2) {
8787
continue;
8888
}
@@ -107,7 +107,7 @@ void MultiTierBlockManagerPool::deallocate(Sequence* sequence) {
107107
sequence->reset();
108108
}
109109

110-
bool MultiTierBlockManagerPool::allocate(Sequence* sequence,
110+
bool HierarchyBlockManagerPool::allocate(Sequence* sequence,
111111
size_t num_tokens) {
112112
BlockManagerPool::allocate(sequence, num_tokens);
113113

@@ -137,7 +137,7 @@ bool MultiTierBlockManagerPool::allocate(Sequence* sequence,
137137
return true;
138138
}
139139

140-
void MultiTierBlockManagerPool::allocate_host_shared(Sequence* sequence) {
140+
void HierarchyBlockManagerPool::allocate_host_shared(Sequence* sequence) {
141141
if (options_.enable_prefix_cache()) {
142142
int32_t dp_rank = BlockManagerPool::get_dp_rank(sequence);
143143
std::vector<Block> shared_blocks =
@@ -146,7 +146,7 @@ void MultiTierBlockManagerPool::allocate_host_shared(Sequence* sequence) {
146146
}
147147
}
148148

149-
void MultiTierBlockManagerPool::prefetch_from_storage(
149+
void HierarchyBlockManagerPool::prefetch_from_storage(
150150
std::shared_ptr<Request>& request) {
151151
if (!options_.enable_kvcache_store()) {
152152
return;
@@ -202,7 +202,7 @@ void MultiTierBlockManagerPool::prefetch_from_storage(
202202
}
203203
}
204204

205-
bool MultiTierBlockManagerPool::update_prefetch_result(
205+
bool HierarchyBlockManagerPool::update_prefetch_result(
206206
std::shared_ptr<Request>& request,
207207
const uint32_t timeout) {
208208
if (!options_.enable_kvcache_store()) {
@@ -216,8 +216,9 @@ bool MultiTierBlockManagerPool::update_prefetch_result(
216216
return prefetch_result;
217217
}
218218

219-
void MultiTierBlockManagerPool::transfer_blocks(std::vector<Batch>* batches) {
220-
if (batches != nullptr) {
219+
void HierarchyBlockManagerPool::transfer_blocks(
220+
std::optional<std::vector<Batch>> batches) {
221+
if (batches.has_value()) {
221222
// load blocks from host to device
222223
for (int i = 0; i < batches->size(); i++) {
223224
if (!load_block_transfer_infos_[i].empty()) {
@@ -265,7 +266,7 @@ void MultiTierBlockManagerPool::transfer_blocks(std::vector<Batch>* batches) {
265266
saved_device_blocks_.resize(host_block_managers_.size());
266267
}
267268

268-
void MultiTierBlockManagerPool::get_merged_kvcache_event(
269+
void HierarchyBlockManagerPool::get_merged_kvcache_event(
269270
KvCacheEvent* event) const {
270271
if (host_block_managers_.empty()) {
271272
BlockManagerPool::get_merged_kvcache_event(event);

xllm/core/framework/block/multi_tier_block_manager_pool.h renamed to xllm/core/framework/block/hierarchy_block_manager_pool.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,18 +22,18 @@ namespace xllm {
2222

2323
class Engine;
2424

25-
class MultiTierBlockManagerPool : public BlockManagerPool {
25+
class HierarchyBlockManagerPool : public BlockManagerPool {
2626
public:
27-
explicit MultiTierBlockManagerPool(const BlockManagerPool::Options& options,
27+
explicit HierarchyBlockManagerPool(const BlockManagerPool::Options& options,
2828
Engine* engine,
2929
int32_t dp_size = 1);
30-
~MultiTierBlockManagerPool() = default;
30+
~HierarchyBlockManagerPool() = default;
3131

3232
bool allocate(Sequence* sequence, size_t num_tokens) override;
3333

3434
void deallocate(Sequence* sequence) override;
3535

36-
void transfer_blocks(std::vector<Batch>* batches = nullptr) override;
36+
void transfer_blocks(std::optional<std::vector<Batch>> batches) override;
3737

3838
void prefetch_from_storage(std::shared_ptr<Request>& request) override;
3939

xllm/core/framework/block/kv_cache_manager.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ class KVCacheManager {
3333
virtual bool allocate(std::vector<Sequence*>& sequences) = 0;
3434
virtual bool allocate(Sequence* sequence, size_t num_tokens) = 0;
3535

36-
virtual void transfer_blocks(std::vector<Batch>* batches = nullptr) {
36+
virtual void transfer_blocks(std::optional<std::vector<Batch>> batches) {
3737
return;
3838
};
3939

0 commit comments

Comments
 (0)