10 changes: 10 additions & 0 deletions xllm/core/common/global_flags.cpp
@@ -331,6 +331,16 @@ DEFINE_bool(enable_online_preempt_offline,

// --- kvcache store config ---

DEFINE_uint32(prefetch_timeout,
0,
"Timeout for prefetching from the KV cache store.");

DEFINE_uint32(prefetch_batch_size,
2,
"Copy batch size for prefetching from the KV cache store.");

DEFINE_uint32(layer_wise_copy_batches, 4, "Number of layer-wise H2D copy batches.");

DEFINE_double(host_blocks_factor,
0.0,
"Host block factor, e.g. host block num = host_blocks_factor * "
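A usage note, not part of the diff: these are standard gflags definitions, so the new values can be overridden at launch and read anywhere through the generated FLAGS_ symbols. A minimal sketch under that assumption (binary name and values are placeholders):

// Hypothetical launch line:
//   ./xllm_server --prefetch_timeout=500 --prefetch_batch_size=4 --layer_wise_copy_batches=8
#include <gflags/gflags.h>
#include <cstdint>

DECLARE_uint32(prefetch_timeout);
DECLARE_uint32(prefetch_batch_size);

// Fall back to a single-block copy if the flag is explicitly set to zero.
uint32_t effective_prefetch_batch_size() {
  return FLAGS_prefetch_batch_size > 0 ? FLAGS_prefetch_batch_size : 1;
}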
6 changes: 6 additions & 0 deletions xllm/core/common/global_flags.h
@@ -153,6 +153,12 @@ DECLARE_bool(use_zero_evict);

DECLARE_int32(max_decode_token_per_sequence);

DECLARE_uint32(prefetch_timeout);

DECLARE_uint32(prefetch_batch_size);

DECLARE_uint32(layer_wise_copy_batches);

DECLARE_string(priority_strategy);

DECLARE_bool(enable_online_preempt_offline);
3 changes: 3 additions & 0 deletions xllm/core/common/options.cpp
@@ -53,6 +53,9 @@ std::string Options::to_string() const {
<< ", enable_service_routing: " << enable_service_routing()
<< ", enable_cache_upload: " << enable_cache_upload()
<< ", enable_kvcache_store: " << enable_kvcache_store()
<< ", prefetch_timeout: " << prefetch_timeout()
<< ", prefetch_bacth_size: " << prefetch_bacth_size()
<< ", layers_wise_copy_batchs: " << layers_wise_copy_batchs()
<< ", store_protocol: " << store_protocol()
<< ", store_master_server_address: " << store_master_server_address()
<< ", store_metadata_server: " << store_metadata_server()
9 changes: 9 additions & 0 deletions xllm/core/common/options.h
@@ -193,6 +193,15 @@ class Options {
// Index ID for the internal server ID; it must be set to different values
// if the model has multiple versions or there are multiple models.
PROPERTY(int64_t, server_idx) = 0;

// Timeout for prefetching from the KV cache store
PROPERTY(uint32_t, prefetch_timeout) = 0;

// Copy batch size for prefetching from the KV cache store
PROPERTY(uint32_t, prefetch_batch_size) = 2;

// Number of layer-wise H2D copy batches
PROPERTY(uint32_t, layer_wise_copy_batches) = 4;
};

} // namespace xllm
13 changes: 7 additions & 6 deletions xllm/core/distributed_runtime/comm_channel.cpp
@@ -372,14 +372,14 @@ void CommChannel::transfer_kv_blocks(

class ClientStreamReceiver : public brpc::StreamInputHandler {
private:
const std::atomic<bool>& termination_flag_;
std::shared_ptr<std::atomic<bool>> termination_flag_;
std::shared_ptr<std::atomic<uint32_t>> success_cnt_;
std::promise<void> close_promise_;
std::atomic<bool> promise_set_{false};

public:
ClientStreamReceiver(const std::atomic<bool>& termination_flag,
std::shared_ptr<std::atomic<uint32_t>>& success_cnt)
ClientStreamReceiver(std::shared_ptr<std::atomic<bool>> termination_flag,
std::shared_ptr<std::atomic<uint32_t>> success_cnt)
: termination_flag_(termination_flag), success_cnt_(success_cnt) {}

~ClientStreamReceiver() {
@@ -398,9 +398,10 @@ class ClientStreamReceiver : public brpc::StreamInputHandler {
int32_t success_cnt = std::stoi(msg_str);

if (success_cnt > 0 &&
!termination_flag_.load(std::memory_order_acquire)) {
!termination_flag_->load(std::memory_order_acquire)) {
success_cnt_->fetch_add(success_cnt, std::memory_order_relaxed);
} else {
termination_flag_->store(true, std::memory_order_release);
brpc::StreamClose(id);
if (!promise_set_.exchange(true)) {
close_promise_.set_value();
@@ -425,9 +426,9 @@
};

void CommChannel::prefetch_from_storage(
const std::atomic<bool>& flag,
const std::vector<BlockTransferInfo>& block_transfer_info,
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
std::shared_ptr<std::atomic<bool>> flag,
std::shared_ptr<std::atomic<uint32_t>> success_cnt) {
proto::BlockTransferInfos pb_block_transfer_info;
if (!block_transfer_info_to_proto(block_transfer_info,
&pb_block_transfer_info)) {
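Reviewer note, not from the PR: the switch from const std::atomic<bool>& to std::shared_ptr<std::atomic<bool>> matters because brpc stream callbacks run asynchronously, so the receiver can outlive the caller's stack frame; shared ownership keeps the flag alive for both sides. A stripped-down sketch of the same pattern, using a toy Receiver type and a plain std::thread instead of brpc:

#include <atomic>
#include <memory>
#include <thread>

struct Receiver {
  std::shared_ptr<std::atomic<bool>> termination_flag;
  void on_message(int success_cnt) {
    if (success_cnt <= 0) {
      // Any failure flips the shared flag so the sender stops issuing work.
      termination_flag->store(true, std::memory_order_release);
    }
  }
};

int main() {
  auto flag = std::make_shared<std::atomic<bool>>(false);
  std::thread t([receiver = Receiver{flag}]() mutable { receiver.on_message(0); });
  t.join();
  // The flag is still valid here even after the Receiver copy is gone.
  return flag->load(std::memory_order_acquire) ? 0 : 1;
}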
4 changes: 2 additions & 2 deletions xllm/core/distributed_runtime/comm_channel.h
@@ -98,9 +98,9 @@ class CommChannel {
const std::vector<BlockTransferInfo>& block_transfer_info);

virtual void prefetch_from_storage(
const std::atomic<bool>& flag,
const std::vector<BlockTransferInfo>& block_transfer_info,
std::shared_ptr<std::atomic<uint32_t>>& success_cnt);
std::shared_ptr<std::atomic<bool>> flag,
std::shared_ptr<std::atomic<uint32_t>> success_cnt);

virtual bool get_last_step_result_async(
folly::Promise<std::optional<RawForwardOutput>>& promise);
8 changes: 4 additions & 4 deletions xllm/core/distributed_runtime/remote_worker.cpp
@@ -313,15 +313,15 @@ void RemoteWorker::transfer_kv_blocks(
}

void RemoteWorker::prefetch_from_storage(
const std::atomic<bool>& flag,
const std::vector<BlockTransferInfo>& block_transfer_info,
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
std::shared_ptr<std::atomic<bool>> flag,
std::shared_ptr<std::atomic<uint32_t>> success_cnt) {
copy_threadpool_.schedule(
[this,
flag = &flag,
block_transfer_info = std::move(block_transfer_info),
flag = flag,
success_cnt = success_cnt]() mutable {
channel_->prefetch_from_storage(flag, block_transfer_info, success_cnt);
channel_->prefetch_from_storage(block_transfer_info, flag, success_cnt);
});
}

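Reviewer note: the lambda now copies both shared_ptrs into the closure instead of capturing the address of a caller-owned reference, so the captured state survives after prefetch_from_storage() returns. A self-contained sketch of that capture pattern, with a toy scheduler standing in for copy_threadpool_:

#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <thread>
#include <vector>

// Toy stand-in for copy_threadpool_.schedule(): runs the task on a separate
// thread and waits for it (a real pool would run it asynchronously).
void schedule(std::function<void()> task) {
  std::thread(std::move(task)).join();
}

void prefetch(std::vector<int> block_transfer_info,
              std::shared_ptr<std::atomic<bool>> flag,
              std::shared_ptr<std::atomic<uint32_t>> success_cnt) {
  // Shared_ptrs are copied into the closure; the transfer list is moved in.
  schedule([info = std::move(block_transfer_info), flag, success_cnt]() {
    if (!flag->load(std::memory_order_acquire)) {
      success_cnt->fetch_add(static_cast<uint32_t>(info.size()),
                             std::memory_order_relaxed);
    }
  });
}

int main() {
  auto flag = std::make_shared<std::atomic<bool>>(false);
  auto cnt = std::make_shared<std::atomic<uint32_t>>(0);
  prefetch({1, 2, 3}, flag, cnt);
  return cnt->load() == 3 ? 0 : 1;
}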
4 changes: 2 additions & 2 deletions xllm/core/distributed_runtime/remote_worker.h
@@ -120,9 +120,9 @@ class RemoteWorker : public WorkerClient {
const std::vector<BlockTransferInfo>& block_transfer_info) override;

virtual void prefetch_from_storage(
const std::atomic<bool>& flag,
const std::vector<BlockTransferInfo>& block_transfer_info,
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) override;
std::shared_ptr<std::atomic<bool>> flag,
std::shared_ptr<std::atomic<uint32_t>> success_cnt) override;

// Run the model and return the output.
virtual folly::SemiFuture<std::optional<ForwardOutput>> step_async(
30 changes: 13 additions & 17 deletions xllm/core/distributed_runtime/worker_service.cpp
@@ -33,8 +33,6 @@ limitations under the License.

namespace xllm {

constexpr uint32_t COPY_BATCH_SIZE = 1;

WorkerService::WorkerService(runtime::Options options,
const torch::Device& device)
: options_(options), device_(device), initialized_(false) {
@@ -404,11 +402,8 @@ void WorkerService::TransferBlocks(
std::vector<BlockTransferInfo> block_transfer_info;
uint64_t batch_id = proto_to_block_transfer_info(*req, block_transfer_info);

if (batch_id == UNINITIALIZED_BATCH_ID) {
resp->set_success_cnt(worker_->transfer_kv_blocks(block_transfer_info));
} else {
worker_->transfer_kv_blocks(batch_id, std::move(block_transfer_info));
}
resp->set_success_cnt(
worker_->transfer_kv_blocks(batch_id, std::move(block_transfer_info)));
return;
}

@@ -477,22 +472,24 @@ void WorkerService::PrefetchFromStorage(
auto close_future = stream_handler->get_close_future();
bool is_completed = false;

for (size_t i = 0; i < transfer_slice.size(); i += COPY_BATCH_SIZE) {
auto current_slice = transfer_slice.slice(
i, std::min(i + COPY_BATCH_SIZE, transfer_slice.size()));
for (size_t i = 0; i < transfer_slice.size();
i += options_.prefetch_batch_size()) {
auto current_slice =
transfer_slice.slice(i,
std::min(i + options_.prefetch_batch_size(),
transfer_slice.size()));

auto success_cnt = worker_->prefetch_from_storage(current_slice);
auto success_cnt = worker_->transfer_kv_blocks(UNINITIALIZED_BATCH_ID,
current_slice);

if (success_cnt != current_slice.size() ||
i + COPY_BATCH_SIZE >= transfer_slice.size()) {
i + options_.prefetch_batch_size() >= transfer_slice.size()) {
is_completed = true;
}

butil::IOBuf buf;
buf.append(std::to_string(success_cnt));
if (brpc::StreamWrite(*stream_id.get(), buf) != 0) {
brpc::StreamClose(*stream_id.get());
is_completed = false;
break;
}

@@ -505,9 +502,8 @@
break;
}
}
if (is_completed) {
close_future.wait();
}

close_future.wait();
brpc::StreamClose(*stream_id.get());
});

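Reviewer note: the loop above now chunks the transfer list by prefetch_batch_size instead of the old hard-coded COPY_BATCH_SIZE of 1, reports progress per chunk over the stream, and stops once a chunk only partially succeeds. The same control flow reduced to a standalone sketch, with an assumed copy callback in place of the worker call and no brpc stream:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Copies blocks in chunks of `batch_size`; `copy_batch` returns how many
// blocks in the chunk were copied successfully.
size_t prefetch_in_batches(
    const std::vector<int>& blocks,
    size_t batch_size,
    const std::function<size_t(const int*, size_t)>& copy_batch) {
  size_t total = 0;
  for (size_t i = 0; i < blocks.size(); i += batch_size) {
    const size_t n = std::min(batch_size, blocks.size() - i);
    const size_t ok = copy_batch(blocks.data() + i, n);
    total += ok;         // per-chunk progress would be streamed back here
    if (ok != n) break;  // partial failure ends the prefetch early
  }
  return total;
}

int main() {
  std::vector<int> blocks(10, 0);
  // Pretend every chunk succeeds completely.
  auto all_ok = [](const int*, size_t n) { return n; };
  return prefetch_in_batches(blocks, /*batch_size=*/4, all_ok) == 10 ? 0 : 1;
}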
2 changes: 2 additions & 0 deletions xllm/core/framework/block/CMakeLists.txt
@@ -11,11 +11,13 @@ cc_library(
block_manager_pool.h
block_manager_impl.h
concurrent_block_manager_impl.h
hierarchy_block_manager_pool.h
SRCS
block.cpp
block_manager_pool.cpp
concurrent_block_manager_impl.cpp
block_manager_impl.cpp
hierarchy_block_manager_pool.cpp
DEPS
$<$<BOOL:${USE_NPU}>:torch_npu>
$<$<BOOL:${USE_NPU}>:graph>
2 changes: 2 additions & 0 deletions xllm/core/framework/block/block_manager.h
@@ -54,6 +54,8 @@ class BlockManager {

virtual void deallocate(const Slice<Block>& blocks) = 0;

virtual void deallocate(std::vector<Block>& blocks) = 0;

virtual std::vector<Block> allocate(size_t num_blocks) = 0;

virtual std::vector<Block> allocate_shared(
7 changes: 6 additions & 1 deletion xllm/core/framework/block/block_manager_impl.cpp
@@ -93,6 +93,12 @@ void BlockManagerImpl::deallocate(const Slice<Block>& blocks) {
}
}

void BlockManagerImpl::deallocate(std::vector<Block>& blocks) {
Slice<Block> slice(blocks);
deallocate(slice);
blocks.clear();
}

bool BlockManagerImpl::has_enough_blocks(uint32_t num_blocks) {
if (num_blocks <= num_free_blocks_) {
return true;
@@ -171,7 +177,6 @@ void BlockManagerImpl::get_merged_kvcache_event(KvCacheEvent* event) const {
if (events != nullptr) {
event->removed_cache.merge(events->removed_cache);
event->stored_cache.merge(events->stored_cache);
event->offload_cache.merge(events->offload_cache);
events->clear();
}
}
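Reviewer note: the new deallocate(std::vector<Block>&) overload delegates to the existing Slice-based path and then clears the caller's vector, so released Block handles cannot be reused by accident. A toy analog of that pattern, using plain ints instead of xllm's Block and Slice types:

#include <cassert>
#include <vector>

struct ToyBlockPool {
  int num_free = 0;

  // Existing path: release blocks through a read-only view.
  void deallocate_view(const std::vector<int>& blocks) {
    num_free += static_cast<int>(blocks.size());
  }

  // New overload: release the blocks, then empty the caller's vector.
  void deallocate(std::vector<int>& blocks) {
    deallocate_view(blocks);
    blocks.clear();
  }
};

int main() {
  ToyBlockPool pool;
  std::vector<int> blocks{11, 12, 13};
  pool.deallocate(blocks);
  assert(pool.num_free == 3 && blocks.empty());
  return 0;
}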
4 changes: 3 additions & 1 deletion xllm/core/framework/block/block_manager_impl.h
@@ -35,6 +35,8 @@ class BlockManagerImpl : public BlockManager {

void deallocate(const Slice<Block>& blocks) override;

void deallocate(std::vector<Block>& blocks) override;

// allocate shared blocks when enable prefix cache
std::vector<Block> allocate_shared(
const Slice<int32_t>& tokens_ids,
@@ -77,7 +79,7 @@
}

float get_gpu_cache_usage_perc() const override {
return 1.0 - num_free_blocks_ * 1.0 / num_total_blocks();
return 1 - static_cast<float>(num_free_blocks_) / num_total_blocks();
}

// call BlockManager to free block used by Block.
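Reviewer note on the get_gpu_cache_usage_perc() change: the explicit static_cast<float> keeps the division in floating point and avoids the implicit double-to-float narrowing of the old 1.0-based expression; numerically the result is the same. A standalone check of that arithmetic:

#include <cstddef>
#include <cstdio>

float gpu_cache_usage(size_t num_free_blocks, size_t num_total_blocks) {
  return 1 - static_cast<float>(num_free_blocks) / num_total_blocks;
}

int main() {
  std::printf("%.2f\n", gpu_cache_usage(25, 100));  // prints 0.75
}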