Skip to content

Commit 0868300

Browse files
committed
feat: optimize layer wise copy.
1 parent 875da2e commit 0868300

File tree

15 files changed

+72
-71
lines changed

15 files changed

+72
-71
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ DEFINE_int32(
164164
256,
165165
"Max decode token per sequence which used for ZeroEvictionScheduler.");
166166

167+
DEFINE_uint32(prefetch_timeout,
168+
0,
169+
"Prefetch timeout for prefetch from kv cache store.");
170+
167171
// --- parallel config ---
168172

169173
DEFINE_int32(dp_size, 1, "Data parallel size for MLA attention.");

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ DECLARE_bool(use_zero_evict);
153153

154154
DECLARE_int32(max_decode_token_per_sequence);
155155

156+
DECLARE_uint32(prefetch_timeout);
157+
156158
DECLARE_string(priority_strategy);
157159

158160
DECLARE_bool(enable_online_preempt_offline);

xllm/core/common/options.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ std::string Options::to_string() const {
5353
<< ", enable_service_routing: " << enable_service_routing()
5454
<< ", enable_cache_upload: " << enable_cache_upload()
5555
<< ", enable_kvcache_store: " << enable_kvcache_store()
56+
<< ", prefetch_timeout: " << prefetch_timeout()
5657
<< ", store_protocol: " << store_protocol()
5758
<< ", store_master_server_address: " << store_master_server_address()
5859
<< ", store_metadata_server: " << store_metadata_server()

xllm/core/common/options.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ class Options {
189189

190190
// whether the worker and master are on the same machine.
191191
PROPERTY(bool, is_local) = false;
192+
193+
// Timeout for prefetching from the kv cache store.
194+
PROPERTY(uint32_t, prefetch_timeout) = 0;
192195
};
193196

194197
} // namespace xllm

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,6 @@ void WorkerService::PrefetchFromStorage(
491491
butil::IOBuf buf;
492492
buf.append(std::to_string(success_cnt));
493493
if (brpc::StreamWrite(*stream_id.get(), buf) != 0) {
494-
brpc::StreamClose(*stream_id.get());
495494
is_completed = false;
496495
break;
497496
}

xllm/core/framework/request/sequence.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ void Sequence::add_host_kv_blocks(const std::vector<Block>& blocks) {
381381
void Sequence::reset() {
382382
kv_state_.reset();
383383
host_kv_state_.reset();
384-
timeout_checker_.reset();
384+
timer_.reset();
385385
volatile_num_prompt_tokens_ = num_tokens_;
386386
}
387387

@@ -456,14 +456,20 @@ Slice<int32_t> Sequence::get_generated_tokens() const {
456456
return {tokens_.data(), 0};
457457
}
458458

459-
bool Sequence::update_prefetch_result() {
459+
bool Sequence::update_prefetch_result(uint32_t timeout) {
460460
if (prefetch_results_.empty()) {
461461
return true;
462462
}
463463

464-
if (!termination_flag_.load(std::memory_order_acquire) &&
465-
timeout_checker_.check_timeout()) {
466-
return false;
464+
// Timeout guard: while the prefetch has not been terminated and a non-zero
// timeout is given, return false until `timeout` milliseconds have elapsed
// since the first call that reached this point.
if (timeout != 0 && !termination_flag_.load(std::memory_order_acquire)) {
  // timer_ is null initially and after reset(), so it has to be created
  // here before it is dereferenced below. The first call only starts the
  // clock and reports "not finished yet".
  if (timer_ == nullptr) {
    timer_ = std::make_shared<Timer>();
    return false;
  }

  // Still inside the timeout window: keep waiting.
  if (timer_->elapsed_milliseconds() < timeout) {
    return false;
  }
}
468474

469475
termination_flag_.store(true, std::memory_order_release);

xllm/core/framework/request/sequence.h

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ limitations under the License.
3535
#include "sequence_kv_state.h"
3636
#include "sequence_logprob_state.h"
3737
#include "stopping_checker.h"
38+
#include "util/timer.h"
3839

3940
namespace xllm {
4041

@@ -82,44 +83,6 @@ struct SequenceParams {
8283
StoppingChecker* stopping_checker; // not owned
8384
};
8485

85-
static uint32_t timeout_ms = 0;
86-
class TimeoutChecker {
87-
private:
88-
std::chrono::steady_clock::time_point timeout_start_;
89-
bool is_timeout_set_ = false;
90-
91-
public:
92-
TimeoutChecker() { init(); }
93-
94-
bool check_timeout() {
95-
if (!is_timeout_set_) {
96-
timeout_start_ = std::chrono::steady_clock::now();
97-
is_timeout_set_ = true;
98-
99-
return false;
100-
} else {
101-
auto now = std::chrono::steady_clock::now();
102-
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
103-
now - timeout_start_);
104-
105-
return elapsed.count() >= timeout_ms;
106-
}
107-
}
108-
109-
void reset() { is_timeout_set_ = false; }
110-
111-
private:
112-
static void init_timeout() {
113-
const char* env_str = std::getenv("PREFETCH_TIMEOUT_MS");
114-
timeout_ms = env_str ? std::strtoul(env_str, nullptr, 10) : 0;
115-
LOG(INFO) << "Prefetch timeout set as: " << timeout_ms;
116-
}
117-
static void init() {
118-
static std::once_flag flag_;
119-
std::call_once(flag_, init_timeout);
120-
}
121-
};
122-
12386
class Sequence final {
12487
public:
12588
Sequence(size_t index,
@@ -286,7 +249,7 @@ class Sequence final {
286249
return &prefetch_results_;
287250
}
288251

289-
bool update_prefetch_result();
252+
bool update_prefetch_result(uint32_t timeout = 30);
290253

291254
void reset();
292255

@@ -401,7 +364,7 @@ class Sequence final {
401364
std::atomic<bool> termination_flag_{false};
402365
std::vector<std::shared_ptr<std::atomic<uint32_t>>> prefetch_results_;
403366

404-
TimeoutChecker timeout_checker_;
367+
std::shared_ptr<Timer> timer_ = nullptr;
405368
};
406369

407370
} // namespace xllm

xllm/core/platform/npu/npu_layer_synchronizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class NPULayerSynchronizerImpl {
3131
aclrtEvent* get_event(const int64_t layer_index);
3232
std::atomic<bool>* get_event_flag(const int64_t layer_index);
3333
bool synchronize_layer(const int64_t layer_index);
34+
// Returns the number of entries in events_, narrowed to uint32_t.
uint32_t get_event_size() {
  return static_cast<uint32_t>(events_.size());
}
3435

3536
private:
3637
std::vector<aclrtEvent> events_;

xllm/core/runtime/llm_master.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ LLMMaster::LLMMaster(const Options& options)
9595
.disable_ttft_profiling(options_.disable_ttft_profiling())
9696
.enable_forward_interruption(options_.enable_forward_interruption())
9797
.max_global_ttft_ms(options_.max_global_ttft_ms())
98-
.max_global_tpot_ms(options_.max_global_tpot_ms());
98+
.max_global_tpot_ms(options_.max_global_tpot_ms())
99+
.prefetch_timeout(options_.prefetch_timeout());
99100
scheduler_ = create_continuous_scheduler(engine_.get(), scheduler_options);
100101

101102
if (options_.enable_service_routing()) {

xllm/core/runtime/worker_impl.cpp

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,7 @@ bool WorkerImpl::init_model(const std::string& model_weights_path,
638638
if (!status) {
639639
return false;
640640
}
641+
layers_per_copy_ = context_.get_model_args().n_layers() / 4;
641642

642643
this->load_model(std::move(model_loader));
643644

@@ -898,9 +899,14 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
898899
}
899900

900901
const int64_t num_layers = context_.get_model_args().n_layers();
902+
uint32_t layers_per_copy = layers_per_copy_;
901903
uint32_t num_batches = block_transfer_info.size() * 2;
904+
while (num_batches * layers_per_copy > BATCH_COPY_MAX_SIZE) {
905+
layers_per_copy--;
906+
}
902907

903-
auto synchronizer = std::make_shared<NPULayerSynchronizerImpl>(num_layers);
908+
uint32_t copy_cnt = (num_layers + layers_per_copy - 1) / layers_per_copy;
909+
auto synchronizer = std::make_shared<NPULayerSynchronizerImpl>(copy_cnt);
904910
{
905911
std::lock_guard<std::mutex> lock(mutex_);
906912
if (layer_wise_load_synchronizer_.count(batch_id) != 0) {
@@ -909,47 +915,54 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
909915
layer_wise_load_synchronizer_[batch_id] = synchronizer;
910916
}
911917

912-
void** srcs = new void*[num_batches];
913-
void** dsts = new void*[num_batches];
914-
size_t* copy_size = new size_t[num_batches];
915918
aclrtMemcpyBatchAttr attrs[1] = {h2d_attrs_};
916919
size_t attrs_indexes[1] = {0};
917920

918921
std::unique_ptr<Stream> stream;
919922
copy_stream_.wait_dequeue(stream);
920923
c10::StreamGuard streamGuard = stream->set_stream_guard();
921-
922924
aclError ret = 0;
923925

924-
for (int layer_id = 0; layer_id < num_layers; layer_id++) {
925-
auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
926-
auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
926+
void** srcs = new void*[num_batches * layers_per_copy];
927+
void** dsts = new void*[num_batches * layers_per_copy];
928+
size_t* copy_size = new size_t[num_batches * layers_per_copy];
929+
930+
for (int index = 0; index < copy_cnt; index++) {
931+
int layer_id = index * layers_per_copy;
927932
size_t fail_index = 0;
928933
uint32_t curr_index = 0;
929-
auto* event = synchronizer->get_event(layer_id);
930-
auto* event_flag = synchronizer->get_event_flag(layer_id);
934+
uint32_t layer_cnt = 0;
931935

932-
for (const auto& info : block_transfer_info) {
933-
auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
934-
auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
936+
while (layer_id < (index + 1) * layers_per_copy && layer_id < num_layers) {
937+
auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
938+
auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
935939

936-
srcs[curr_index] = src_k_cache[layer_id].data_ptr();
937-
dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
938-
copy_size[curr_index] = key_cache_size_per_layer_;
939-
curr_index++;
940+
for (const auto& info : block_transfer_info) {
941+
auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
942+
auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
940943

941-
srcs[curr_index] = src_v_cache[layer_id].data_ptr();
942-
dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
943-
copy_size[curr_index] = value_cache_size_per_layer_;
944-
curr_index++;
944+
srcs[curr_index] = src_k_cache[layer_id].data_ptr();
945+
dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
946+
copy_size[curr_index] = key_cache_size_per_layer_;
947+
curr_index++;
948+
949+
srcs[curr_index] = src_v_cache[layer_id].data_ptr();
950+
dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
951+
copy_size[curr_index] = value_cache_size_per_layer_;
952+
curr_index++;
953+
}
954+
layer_id++;
955+
layer_cnt++;
945956
}
946957

947958
// TODO(kangmeng): change to async API
959+
CHECK(layer_cnt <= layers_per_copy)
960+
<< "layer_cnt should less equal to layers_per_copy.";
948961
ret = aclrtMemcpyBatch(dsts,
949962
copy_size,
950963
srcs,
951964
copy_size,
952-
num_batches,
965+
num_batches * layer_cnt,
953966
attrs,
954967
attrs_indexes,
955968
1,
@@ -959,11 +972,13 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
959972
LOG(ERROR) << "aclrtMemcpyBatch error: " << ret
960973
<< ", fail_index:" << fail_index;
961974
} else {
975+
auto* event = synchronizer->get_event(index);
962976
ret = aclrtRecordEvent(*event, stream->get_stream()->stream());
963977
if (ret != 0) {
964978
LOG(ERROR) << "aclrtRecordEvent error: " << ret;
965979
}
966980
}
981+
auto* event_flag = synchronizer->get_event_flag(index);
967982
event_flag->store(true, std::memory_order_release);
968983
if (ret != 0) break;
969984
}

0 commit comments

Comments
 (0)