@@ -614,6 +614,7 @@ bool WorkerImpl::init_model(const std::string& model_weights_path) {
614614 if (!status) {
615615 return false ;
616616 }
617+ layers_per_copy_ = context_.get_model_args ().n_layers () / 4 ;
617618
618619 this ->load_model (std::move (model_loader));
619620
@@ -874,9 +875,14 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
874875 }
875876
876877 const int64_t num_layers = context_.get_model_args ().n_layers ();
878+ uint32_t layers_per_copy = layers_per_copy_;
877879 uint32_t num_batches = block_transfer_info.size () * 2 ;
880+ while (num_batches * layers_per_copy > BATCH_COPY_MAX_SIZE) {
881+ layers_per_copy--;
882+ }
878883
879- auto synchronizer = std::make_shared<NPULayerSynchronizerImpl>(num_layers);
884+ uint32_t copy_cnt = (num_layers + layers_per_copy - 1 ) / layers_per_copy;
885+ auto synchronizer = std::make_shared<NPULayerSynchronizerImpl>(copy_cnt);
880886 {
881887 std::lock_guard<std::mutex> lock (mutex_);
882888 if (layer_wise_load_synchronizer_.count (batch_id) != 0 ) {
@@ -885,47 +891,54 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
885891 layer_wise_load_synchronizer_[batch_id] = synchronizer;
886892 }
887893
888- void ** srcs = new void *[num_batches];
889- void ** dsts = new void *[num_batches];
890- size_t * copy_size = new size_t [num_batches];
891894 aclrtMemcpyBatchAttr attrs[1 ] = {h2d_attrs_};
892895 size_t attrs_indexes[1 ] = {0 };
893896
894897 std::unique_ptr<Stream> stream;
895898 copy_stream_.wait_dequeue (stream);
896899 c10::StreamGuard streamGuard = stream->set_stream_guard ();
897-
898900 aclError ret = 0 ;
899901
900- for (int layer_id = 0 ; layer_id < num_layers; layer_id++) {
901- auto dst_k_cache = kv_caches_.at (layer_id).get_k_cache ();
902- auto dst_v_cache = kv_caches_.at (layer_id).get_v_cache ();
902+ void ** srcs = new void *[num_batches * layers_per_copy];
903+ void ** dsts = new void *[num_batches * layers_per_copy];
904+ size_t * copy_size = new size_t [num_batches * layers_per_copy];
905+
906+ for (int index = 0 ; index < copy_cnt; index++) {
907+ int layer_id = index * layers_per_copy;
903908 size_t fail_index = 0 ;
904909 uint32_t curr_index = 0 ;
905- auto * event = synchronizer->get_event (layer_id);
906- auto * event_flag = synchronizer->get_event_flag (layer_id);
910+ uint32_t layer_cnt = 0 ;
907911
908- for ( const auto & info : block_transfer_info ) {
909- auto src_k_cache = host_kv_caches_ .at (info. src_block_id ).get_k_cache ();
910- auto src_v_cache = host_kv_caches_ .at (info. src_block_id ).get_v_cache ();
912+ while (layer_id < (index + 1 ) * layers_per_copy && layer_id < num_layers ) {
913+ auto dst_k_cache = kv_caches_ .at (layer_id ).get_k_cache ();
914+ auto dst_v_cache = kv_caches_ .at (layer_id ).get_v_cache ();
911915
912- srcs[curr_index] = src_k_cache[layer_id].data_ptr ();
913- dsts[curr_index] = dst_k_cache[info.dst_block_id ].data_ptr ();
914- copy_size[curr_index] = key_cache_size_per_layer_;
915- curr_index++;
916+ for (const auto & info : block_transfer_info) {
917+ auto src_k_cache = host_kv_caches_.at (info.src_block_id ).get_k_cache ();
918+ auto src_v_cache = host_kv_caches_.at (info.src_block_id ).get_v_cache ();
916919
917- srcs[curr_index] = src_v_cache[layer_id].data_ptr ();
918- dsts[curr_index] = dst_v_cache[info.dst_block_id ].data_ptr ();
919- copy_size[curr_index] = value_cache_size_per_layer_;
920- curr_index++;
920+ srcs[curr_index] = src_k_cache[layer_id].data_ptr ();
921+ dsts[curr_index] = dst_k_cache[info.dst_block_id ].data_ptr ();
922+ copy_size[curr_index] = key_cache_size_per_layer_;
923+ curr_index++;
924+
925+ srcs[curr_index] = src_v_cache[layer_id].data_ptr ();
926+ dsts[curr_index] = dst_v_cache[info.dst_block_id ].data_ptr ();
927+ copy_size[curr_index] = value_cache_size_per_layer_;
928+ curr_index++;
929+ }
930+ layer_id++;
931+ layer_cnt++;
921932 }
922933
923934 // TODO(kangmeng): change to async API
935+ CHECK (layer_cnt <= layers_per_copy)
936+ << " layer_cnt should less equal to layers_per_copy." ;
924937 ret = aclrtMemcpyBatch (dsts,
925938 copy_size,
926939 srcs,
927940 copy_size,
928- num_batches,
941+ num_batches * layer_cnt ,
929942 attrs,
930943 attrs_indexes,
931944 1 ,
@@ -935,11 +948,13 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
935948 LOG (ERROR) << " aclrtMemcpyBatch error: " << ret
936949 << " , fail_index:" << fail_index;
937950 } else {
951+ auto * event = synchronizer->get_event (index);
938952 ret = aclrtRecordEvent (*event, stream->get_stream ()->stream ());
939953 if (ret != 0 ) {
940954 LOG (ERROR) << " aclrtRecordEvent error: " << ret;
941955 }
942956 }
957+ auto * event_flag = synchronizer->get_event_flag (index);
943958 event_flag->store (true , std::memory_order_release);
944959 if (ret != 0 ) break ;
945960 }