Skip to content

Commit c67cf12

Browse files
committed
feat: support async layer wise batch copy.
1 parent 0868300 commit c67cf12

File tree

16 files changed

+107
-45
lines changed

16 files changed

+107
-45
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ DEFINE_uint32(prefetch_timeout,
168168
0,
169169
"Prefetch timeout for prefetch from kv cache store.");
170170

171+
DEFINE_uint32(prefetch_bacth_size,
172+
2,
173+
"Prefetch from kvcache store copy batch size.");
174+
175+
DEFINE_uint32(layers_wise_copy_batchs, 4, "Layer wise H2D copy batchs.");
176+
171177
// --- parallel config ---
172178

173179
DEFINE_int32(dp_size, 1, "Data parallel size for MLA attention.");

xllm/core/common/global_flags.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ DECLARE_int32(max_decode_token_per_sequence);
155155

156156
DECLARE_uint32(prefetch_timeout);
157157

158+
DECLARE_uint32(prefetch_bacth_size);
159+
160+
DECLARE_uint32(layers_wise_copy_batchs);
161+
158162
DECLARE_string(priority_strategy);
159163

160164
DECLARE_bool(enable_online_preempt_offline);

xllm/core/common/options.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ std::string Options::to_string() const {
5454
<< ", enable_cache_upload: " << enable_cache_upload()
5555
<< ", enable_kvcache_store: " << enable_kvcache_store()
5656
<< ", prefetch_timeout: " << prefetch_timeout()
57+
<< ", prefetch_bacth_size: " << prefetch_bacth_size()
58+
<< ", layers_wise_copy_batchs: " << layers_wise_copy_batchs()
5759
<< ", store_protocol: " << store_protocol()
5860
<< ", store_master_server_address: " << store_master_server_address()
5961
<< ", store_metadata_server: " << store_metadata_server()

xllm/core/common/options.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ class Options {
192192

193193
// Prefetch timeout for prefetch from kv cache store
194194
PROPERTY(uint32_t, prefetch_timeout) = 0;
195+
196+
// Prefetch from kvcache store copy batch size
197+
PROPERTY(uint32_t, prefetch_bacth_size) = 2;
198+
199+
// Layer wise H2D copy batchs
200+
PROPERTY(uint32_t, layers_wise_copy_batchs) = 4;
195201
};
196202

197203
} // namespace xllm

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ limitations under the License.
3333

3434
namespace xllm {
3535

36-
constexpr uint32_t COPY_BATCH_SIZE = 1;
37-
3836
WorkerService::WorkerService(runtime::Options options,
3937
const torch::Device& device)
4038
: options_(options), device_(device), initialized_(false) {
@@ -477,21 +475,23 @@ void WorkerService::PrefetchFromStorage(
477475
auto close_future = stream_handler->get_close_future();
478476
bool is_completed = false;
479477

480-
for (size_t i = 0; i < transfer_slice.size(); i += COPY_BATCH_SIZE) {
481-
auto current_slice = transfer_slice.slice(
482-
i, std::min(i + COPY_BATCH_SIZE, transfer_slice.size()));
478+
for (size_t i = 0; i < transfer_slice.size();
479+
i += options_.prefetch_bacth_size()) {
480+
auto current_slice =
481+
transfer_slice.slice(i,
482+
std::min(i + options_.prefetch_bacth_size(),
483+
transfer_slice.size()));
483484

484485
auto success_cnt = worker_->prefetch_from_storage(current_slice);
485486

486487
if (success_cnt != current_slice.size() ||
487-
i + COPY_BATCH_SIZE >= transfer_slice.size()) {
488+
i + options_.prefetch_bacth_size() >= transfer_slice.size()) {
488489
is_completed = true;
489490
}
490491

491492
butil::IOBuf buf;
492493
buf.append(std::to_string(success_cnt));
493494
if (brpc::StreamWrite(*stream_id.get(), buf) != 0) {
494-
is_completed = false;
495495
break;
496496
}
497497

xllm/core/framework/request/sequence.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ void Sequence::reset() {
382382
kv_state_.reset();
383383
host_kv_state_.reset();
384384
timer_.reset();
385+
is_timeout_set_ = false;
385386
volatile_num_prompt_tokens_ = num_tokens_;
386387
}
387388

@@ -462,12 +463,13 @@ bool Sequence::update_prefetch_result(uint32_t timeout) {
462463
}
463464

464465
if (timeout != 0 && !termination_flag_.load(std::memory_order_acquire)) {
465-
if (timer_ != nullptr) {
466-
timer_ = std::make_shared<Timer>();
466+
if (!is_timeout_set_) {
467+
timer_.reset();
468+
is_timeout_set_ = true;
467469
return false;
468470
}
469471

470-
if (timer_->elapsed_milliseconds() < timeout) {
472+
if (timer_.elapsed_milliseconds() < timeout) {
471473
return false;
472474
}
473475
}

xllm/core/framework/request/sequence.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,8 @@ class Sequence final {
364364
std::atomic<bool> termination_flag_{false};
365365
std::vector<std::shared_ptr<std::atomic<uint32_t>>> prefetch_results_;
366366

367-
std::shared_ptr<Timer> timer_ = nullptr;
367+
Timer timer_;
368+
bool is_timeout_set_ = false;
368369
};
369370

370371
} // namespace xllm

xllm/core/kernels/npu/xllm_ops/top_k_top_p.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ limitations under the License.
2020
#include <vector>
2121

2222
#include "acl/acl.h"
23-
#include "aclnn_apply_top_k_top_p.h"
23+
#include "aclnnop/aclnn_apply_top_k_top_p.h"
2424
#include "acltensor_utils.h"
2525
#include "util/tensor_helper.h"
2626

xllm/core/runtime/master.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
218218
.store_master_server_address(options_.store_master_server_address())
219219
.store_metadata_server(options_.store_metadata_server())
220220
.store_local_hostname(options_.store_local_hostname())
221+
.prefetch_bacth_size(options_.prefetch_bacth_size())
222+
.layers_wise_copy_batchs(options_.layers_wise_copy_batchs())
221223
.enable_continuous_kvcache(options_.enable_continuous_kvcache())
222224
.enable_offline_inference(options_.enable_offline_inference())
223225
.spawn_worker_path(options_.spawn_worker_path())

xllm/core/runtime/options.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,12 @@ struct Options {
158158
// value used if port is not included)
159159
PROPERTY(std::string, store_local_hostname) = "";
160160

161+
// Prefetch from kvcache store copy batch size
162+
PROPERTY(uint32_t, prefetch_bacth_size) = 2;
163+
164+
// Layer wise H2D copy batchs
165+
PROPERTY(uint32_t, layers_wise_copy_batchs) = 4;
166+
161167
// dit
162168
// max requests per batch
163169
PROPERTY(int, max_requests_per_batch) = 0;

0 commit comments

Comments
 (0)