Skip to content

Commit f52f4dd

Browse files
committed
feat: optimize layer wise copy.
1 parent a94b9be commit f52f4dd

File tree

16 files changed

+82
-73
lines changed

16 files changed

+82
-73
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ DEFINE_int32(
164164
256,
165165
"Max decode token per sequence which used for ZeroEvictionScheduler.");
166166

167+
DEFINE_uint32(prefetch_timeout,
168+
0,
169+
"Prefetch timeout for prefetch from kv cache store.");
170+
167171
// --- parallel config ---
168172

169173
DEFINE_int32(dp_size, 1, "Data parallel size for MLA attention.");

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ DECLARE_bool(use_zero_evict);
153153

154154
DECLARE_int32(max_decode_token_per_sequence);
155155

156+
DECLARE_uint32(prefetch_timeout);
157+
156158
DECLARE_string(priority_strategy);
157159

158160
DECLARE_bool(enable_online_preempt_offline);

xllm/core/common/options.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ std::string Options::to_string() const {
5353
<< ", enable_service_routing: " << enable_service_routing()
5454
<< ", enable_cache_upload: " << enable_cache_upload()
5555
<< ", enable_kvcache_store: " << enable_kvcache_store()
56+
<< ", prefetch_timeout: " << prefetch_timeout()
5657
<< ", store_protocol: " << store_protocol()
5758
<< ", store_master_server_address: " << store_master_server_address()
5859
<< ", store_metadata_server: " << store_metadata_server()

xllm/core/common/options.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ class Options {
189189

190190
// whether the worker and master are on the same machine.
191191
PROPERTY(bool, is_local) = false;
192+
193+
// Prefetch timeout for prefetch from kv cache store
194+
PROPERTY(uint32_t, prefetch_timeout) = 0;
192195
};
193196

194197
} // namespace xllm

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,6 @@ void WorkerService::PrefetchFromStorage(
489489
butil::IOBuf buf;
490490
buf.append(std::to_string(success_cnt));
491491
if (brpc::StreamWrite(*stream_id.get(), buf) != 0) {
492-
brpc::StreamClose(*stream_id.get());
493492
is_completed = false;
494493
break;
495494
}

xllm/core/framework/request/sequence.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ void Sequence::add_host_kv_blocks(const std::vector<Block>& blocks) {
381381
void Sequence::reset() {
382382
kv_state_.reset();
383383
host_kv_state_.reset();
384-
timeout_checker_.reset();
384+
timer_.reset();
385385
volatile_num_prompt_tokens_ = num_tokens_;
386386
}
387387

@@ -456,14 +456,20 @@ Slice<int32_t> Sequence::get_generated_tokens() const {
456456
return {tokens_.data(), 0};
457457
}
458458

459-
bool Sequence::update_prefetch_result() {
459+
bool Sequence::update_prefetch_result(uint32_t timeout) {
460460
if (prefetch_results_.empty()) {
461461
return true;
462462
}
463463

464-
if (!termination_flag_.load(std::memory_order_acquire) &&
465-
timeout_checker_.check_timeout()) {
466-
return false;
464+
if (timeout != 0 && !termination_flag_.load(std::memory_order_acquire)) {
465+
if (timer_ != nullptr) {
466+
timer_ = std::make_shared<Timer>();
467+
return false;
468+
}
469+
470+
if (timer_->elapsed_milliseconds() < timeout) {
471+
return false;
472+
}
467473
}
468474

469475
termination_flag_.store(true, std::memory_order_release);

xllm/core/framework/request/sequence.h

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ limitations under the License.
3535
#include "sequence_kv_state.h"
3636
#include "sequence_logprob_state.h"
3737
#include "stopping_checker.h"
38+
#include "util/timer.h"
3839

3940
namespace xllm {
4041

@@ -82,44 +83,6 @@ struct SequenceParams {
8283
StoppingChecker* stopping_checker; // not owned
8384
};
8485

85-
static uint32_t timeout_ms = 0;
86-
class TimeoutChecker {
87-
private:
88-
std::chrono::steady_clock::time_point timeout_start_;
89-
bool is_timeout_set_ = false;
90-
91-
public:
92-
TimeoutChecker() { init(); }
93-
94-
bool check_timeout() {
95-
if (!is_timeout_set_) {
96-
timeout_start_ = std::chrono::steady_clock::now();
97-
is_timeout_set_ = true;
98-
99-
return false;
100-
} else {
101-
auto now = std::chrono::steady_clock::now();
102-
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
103-
now - timeout_start_);
104-
105-
return elapsed.count() >= timeout_ms;
106-
}
107-
}
108-
109-
void reset() { is_timeout_set_ = false; }
110-
111-
private:
112-
static void init_timeout() {
113-
const char* env_str = std::getenv("PREFETCH_TIMEOUT_MS");
114-
timeout_ms = env_str ? std::strtoul(env_str, nullptr, 10) : 0;
115-
LOG(INFO) << "Prefetch timeout set as: " << timeout_ms;
116-
}
117-
static void init() {
118-
static std::once_flag flag_;
119-
std::call_once(flag_, init_timeout);
120-
}
121-
};
122-
12386
class Sequence final {
12487
public:
12588
Sequence(size_t index,
@@ -286,7 +249,7 @@ class Sequence final {
286249
return &prefetch_results_;
287250
}
288251

289-
bool update_prefetch_result();
252+
bool update_prefetch_result(uint32_t timeout = 30);
290253

291254
void reset();
292255

@@ -401,7 +364,7 @@ class Sequence final {
401364
std::atomic<bool> termination_flag_{false};
402365
std::vector<std::shared_ptr<std::atomic<uint32_t>>> prefetch_results_;
403366

404-
TimeoutChecker timeout_checker_;
367+
std::shared_ptr<Timer> timer_ = nullptr;
405368
};
406369

407370
} // namespace xllm

xllm/core/platform/npu/npu_layer_synchronizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class NPULayerSynchronizerImpl {
3131
aclrtEvent* get_event(const int64_t layer_index);
3232
std::atomic<bool>* get_event_flag(const int64_t layer_index);
3333
bool synchronize_layer(const int64_t layer_index);
34+
uint32_t get_event_size() { return events_.size(); };
3435

3536
private:
3637
std::vector<aclrtEvent> events_;

xllm/core/runtime/llm_master.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ LLMMaster::LLMMaster(const Options& options)
9595
.disable_ttft_profiling(options_.disable_ttft_profiling())
9696
.enable_forward_interruption(options_.enable_forward_interruption())
9797
.max_global_ttft_ms(options_.max_global_ttft_ms())
98-
.max_global_tpot_ms(options_.max_global_tpot_ms());
98+
.max_global_tpot_ms(options_.max_global_tpot_ms())
99+
.prefetch_timeout(options_.prefetch_timeout());
99100
scheduler_ = create_continuous_scheduler(engine_.get(), scheduler_options);
100101

101102
if (options_.enable_service_routing()) {

xllm/core/runtime/worker_impl.cpp

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ bool WorkerImpl::init_model(const std::string& model_weights_path) {
614614
if (!status) {
615615
return false;
616616
}
617+
layers_per_copy_ = context_.get_model_args().n_layers() / 4;
617618

618619
this->load_model(std::move(model_loader));
619620

@@ -874,9 +875,14 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
874875
}
875876

876877
const int64_t num_layers = context_.get_model_args().n_layers();
878+
uint32_t layers_per_copy = layers_per_copy_;
877879
uint32_t num_batches = block_transfer_info.size() * 2;
880+
while (num_batches * layers_per_copy > BATCH_COPY_MAX_SIZE) {
881+
layers_per_copy--;
882+
}
878883

879-
auto synchronizer = std::make_shared<NPULayerSynchronizerImpl>(num_layers);
884+
uint32_t copy_cnt = (num_layers + layers_per_copy - 1) / layers_per_copy;
885+
auto synchronizer = std::make_shared<NPULayerSynchronizerImpl>(copy_cnt);
880886
{
881887
std::lock_guard<std::mutex> lock(mutex_);
882888
if (layer_wise_load_synchronizer_.count(batch_id) != 0) {
@@ -885,47 +891,54 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
885891
layer_wise_load_synchronizer_[batch_id] = synchronizer;
886892
}
887893

888-
void** srcs = new void*[num_batches];
889-
void** dsts = new void*[num_batches];
890-
size_t* copy_size = new size_t[num_batches];
891894
aclrtMemcpyBatchAttr attrs[1] = {h2d_attrs_};
892895
size_t attrs_indexes[1] = {0};
893896

894897
std::unique_ptr<Stream> stream;
895898
copy_stream_.wait_dequeue(stream);
896899
c10::StreamGuard streamGuard = stream->set_stream_guard();
897-
898900
aclError ret = 0;
899901

900-
for (int layer_id = 0; layer_id < num_layers; layer_id++) {
901-
auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
902-
auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
902+
void** srcs = new void*[num_batches * layers_per_copy];
903+
void** dsts = new void*[num_batches * layers_per_copy];
904+
size_t* copy_size = new size_t[num_batches * layers_per_copy];
905+
906+
for (int index = 0; index < copy_cnt; index++) {
907+
int layer_id = index * layers_per_copy;
903908
size_t fail_index = 0;
904909
uint32_t curr_index = 0;
905-
auto* event = synchronizer->get_event(layer_id);
906-
auto* event_flag = synchronizer->get_event_flag(layer_id);
910+
uint32_t layer_cnt = 0;
907911

908-
for (const auto& info : block_transfer_info) {
909-
auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
910-
auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
912+
while (layer_id < (index + 1) * layers_per_copy && layer_id < num_layers) {
913+
auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
914+
auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
911915

912-
srcs[curr_index] = src_k_cache[layer_id].data_ptr();
913-
dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
914-
copy_size[curr_index] = key_cache_size_per_layer_;
915-
curr_index++;
916+
for (const auto& info : block_transfer_info) {
917+
auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
918+
auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
916919

917-
srcs[curr_index] = src_v_cache[layer_id].data_ptr();
918-
dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
919-
copy_size[curr_index] = value_cache_size_per_layer_;
920-
curr_index++;
920+
srcs[curr_index] = src_k_cache[layer_id].data_ptr();
921+
dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
922+
copy_size[curr_index] = key_cache_size_per_layer_;
923+
curr_index++;
924+
925+
srcs[curr_index] = src_v_cache[layer_id].data_ptr();
926+
dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
927+
copy_size[curr_index] = value_cache_size_per_layer_;
928+
curr_index++;
929+
}
930+
layer_id++;
931+
layer_cnt++;
921932
}
922933

923934
// TODO(kangmeng): change to async API
935+
CHECK(layer_cnt <= layers_per_copy)
936+
<< "layer_cnt should less equal to layers_per_copy.";
924937
ret = aclrtMemcpyBatch(dsts,
925938
copy_size,
926939
srcs,
927940
copy_size,
928-
num_batches,
941+
num_batches * layer_cnt,
929942
attrs,
930943
attrs_indexes,
931944
1,
@@ -935,11 +948,13 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
935948
LOG(ERROR) << "aclrtMemcpyBatch error: " << ret
936949
<< ", fail_index:" << fail_index;
937950
} else {
951+
auto* event = synchronizer->get_event(index);
938952
ret = aclrtRecordEvent(*event, stream->get_stream()->stream());
939953
if (ret != 0) {
940954
LOG(ERROR) << "aclrtRecordEvent error: " << ret;
941955
}
942956
}
957+
auto* event_flag = synchronizer->get_event_flag(index);
943958
event_flag->store(true, std::memory_order_release);
944959
if (ret != 0) break;
945960
}

0 commit comments

Comments
 (0)