@@ -638,6 +638,7 @@ bool WorkerImpl::init_model(const std::string& model_weights_path,
638638 if (!status) {
639639 return false ;
640640 }
641+ layers_per_copy_ = context_.get_model_args ().n_layers () / 4 ;
641642
642643 this ->load_model (std::move (model_loader));
643644
@@ -898,9 +899,14 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
898899 }
899900
900901 const int64_t num_layers = context_.get_model_args ().n_layers ();
902+ uint32_t layers_per_copy = layers_per_copy_;
901903 uint32_t num_batches = block_transfer_info.size () * 2 ;
904+ while (num_batches * layers_per_copy > BATCH_COPY_MAX_SIZE) {
905+ layers_per_copy--;
906+ }
902907
903- auto synchronizer = std::make_shared<NPULayerSynchronizerImpl>(num_layers);
908+ uint32_t copy_cnt = (num_layers + layers_per_copy - 1 ) / layers_per_copy;
909+ auto synchronizer = std::make_shared<NPULayerSynchronizerImpl>(copy_cnt);
904910 {
905911 std::lock_guard<std::mutex> lock (mutex_);
906912 if (layer_wise_load_synchronizer_.count (batch_id) != 0 ) {
@@ -909,47 +915,54 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
909915 layer_wise_load_synchronizer_[batch_id] = synchronizer;
910916 }
911917
912- void ** srcs = new void *[num_batches];
913- void ** dsts = new void *[num_batches];
914- size_t * copy_size = new size_t [num_batches];
915918 aclrtMemcpyBatchAttr attrs[1 ] = {h2d_attrs_};
916919 size_t attrs_indexes[1 ] = {0 };
917920
918921 std::unique_ptr<Stream> stream;
919922 copy_stream_.wait_dequeue (stream);
920923 c10::StreamGuard streamGuard = stream->set_stream_guard ();
921-
922924 aclError ret = 0 ;
923925
924- for (int layer_id = 0 ; layer_id < num_layers; layer_id++) {
925- auto dst_k_cache = kv_caches_.at (layer_id).get_k_cache ();
926- auto dst_v_cache = kv_caches_.at (layer_id).get_v_cache ();
926+ void ** srcs = new void *[num_batches * layers_per_copy];
927+ void ** dsts = new void *[num_batches * layers_per_copy];
928+ size_t * copy_size = new size_t [num_batches * layers_per_copy];
929+
930+ for (int index = 0 ; index < copy_cnt; index++) {
931+ int layer_id = index * layers_per_copy;
927932 size_t fail_index = 0 ;
928933 uint32_t curr_index = 0 ;
929- auto * event = synchronizer->get_event (layer_id);
930- auto * event_flag = synchronizer->get_event_flag (layer_id);
934+ uint32_t layer_cnt = 0 ;
931935
932- for ( const auto & info : block_transfer_info ) {
933- auto src_k_cache = host_kv_caches_ .at (info. src_block_id ).get_k_cache ();
934- auto src_v_cache = host_kv_caches_ .at (info. src_block_id ).get_v_cache ();
936+ while (layer_id < (index + 1 ) * layers_per_copy && layer_id < num_layers ) {
937+ auto dst_k_cache = kv_caches_ .at (layer_id ).get_k_cache ();
938+ auto dst_v_cache = kv_caches_ .at (layer_id ).get_v_cache ();
935939
936- srcs[curr_index] = src_k_cache[layer_id].data_ptr ();
937- dsts[curr_index] = dst_k_cache[info.dst_block_id ].data_ptr ();
938- copy_size[curr_index] = key_cache_size_per_layer_;
939- curr_index++;
940+ for (const auto & info : block_transfer_info) {
941+ auto src_k_cache = host_kv_caches_.at (info.src_block_id ).get_k_cache ();
942+ auto src_v_cache = host_kv_caches_.at (info.src_block_id ).get_v_cache ();
940943
941- srcs[curr_index] = src_v_cache[layer_id].data_ptr ();
942- dsts[curr_index] = dst_v_cache[info.dst_block_id ].data_ptr ();
943- copy_size[curr_index] = value_cache_size_per_layer_;
944- curr_index++;
944+ srcs[curr_index] = src_k_cache[layer_id].data_ptr ();
945+ dsts[curr_index] = dst_k_cache[info.dst_block_id ].data_ptr ();
946+ copy_size[curr_index] = key_cache_size_per_layer_;
947+ curr_index++;
948+
949+ srcs[curr_index] = src_v_cache[layer_id].data_ptr ();
950+ dsts[curr_index] = dst_v_cache[info.dst_block_id ].data_ptr ();
951+ copy_size[curr_index] = value_cache_size_per_layer_;
952+ curr_index++;
953+ }
954+ layer_id++;
955+ layer_cnt++;
945956 }
946957
947958 // TODO(kangmeng): change to async API
959+ CHECK (layer_cnt <= layers_per_copy)
960+ << " layer_cnt should less equal to layers_per_copy." ;
948961 ret = aclrtMemcpyBatch (dsts,
949962 copy_size,
950963 srcs,
951964 copy_size,
952- num_batches,
965+ num_batches * layer_cnt ,
953966 attrs,
954967 attrs_indexes,
955968 1 ,
@@ -959,11 +972,13 @@ bool WorkerImpl::h2d_batch_copy(const uint64_t batch_id,
959972 LOG (ERROR) << " aclrtMemcpyBatch error: " << ret
960973 << " , fail_index:" << fail_index;
961974 } else {
975+ auto * event = synchronizer->get_event (index);
962976 ret = aclrtRecordEvent (*event, stream->get_stream ()->stream ());
963977 if (ret != 0 ) {
964978 LOG (ERROR) << " aclrtRecordEvent error: " << ret;
965979 }
966980 }
981+ auto * event_flag = synchronizer->get_event_flag (index);
967982 event_flag->store (true , std::memory_order_release);
968983 if (ret != 0 ) break ;
969984 }
0 commit comments