@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <gflags/gflags.h>
 
+#include <unordered_set>
+
 #include "common/global_flags.h"
 
 namespace xllm {
@@ -233,6 +235,8 @@ NpuQwen3MoeDecoderLayerImpl::NpuQwen3MoeDecoderLayerImpl(
   CHECK_EQ(parallel_args.world_size(), dp_size_ * dp_local_tp_size_);
   dp_local_tp_rank_ = parallel_args.rank() % dp_local_tp_size_;
 
+  n_kv_heads_ = static_cast<int32_t>(model_args.n_kv_heads().value());
+
   param_from_args(prefill_param_, model_args, parallel_args, true);
   param_from_args(decode_param_, model_args, parallel_args, false);
   initialize_tensors(options);
@@ -345,8 +349,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters(
   param.rmsnormQKNorm = true;
   param.hiddenSizePerAttentionHead = args.head_dim();
   std::optional<long int> optionalValue = args.n_kv_heads();
-  param.numKeyValueHeadsPerRank =
-      static_cast<int>(optionalValue.value()) / parallel_args.world_size();
+  param.numKeyValueHeadsPerRank = std::max(
+      1, static_cast<int>(optionalValue.value()) / parallel_args.world_size());
   param.numAttentionHeadsPerRank = args.n_heads() / dp_local_tp_size_;
 
   param.attnLinearTransposeType = {1, -1, -1, 1, -1, -1};
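
For context on the `std::max` clamp: with grouped-query attention the number of KV heads can be smaller than the world size, in which case the plain integer division yields 0 KV heads per rank; clamping to at least 1 means every rank keeps one (replicated) KV head. A minimal standalone sketch of the arithmetic, not part of the patch (the helper name is hypothetical):

```cpp
#include <algorithm>
#include <cstdio>

// Hypothetical helper mirroring the clamped per-rank KV-head count above.
int kv_heads_per_rank(int n_kv_heads, int world_size) {
  return std::max(1, n_kv_heads / world_size);
}

int main() {
  // e.g. 4 KV heads across 8 ranks: each rank keeps 1 (replicated) head
  // instead of the 0 the unclamped division would give.
  std::printf("%d\n", kv_heads_per_rank(4, 8));  // 1
  std::printf("%d\n", kv_heads_per_rank(4, 2));  // 2 (unchanged case)
  return 0;
}
```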
@@ -390,8 +394,16 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_mlp_parameters(
 void NpuQwen3MoeDecoderLayerImpl::initialize_parallel_parameters(
     atb_speed::qwen::MoeDecoderLayerParam& param,
     const ParallelArgs& parallel_args) {
-  param.lmHeadLocalTp = 0;
+  param.lmHeadLocalTp = dp_local_tp_size_;
   param.mapping = parallel_args.mapping();
+  param.tensorParallelInfo = {parallel_args.rank(),
+                              parallel_args.world_size(),
+                              FLAGS_communication_backend,
+                              FLAGS_rank_tablefile,
+                              nullptr,
+                              ""};
+
+  param.PrintParam();
   param.maxDecodeDpTokenSize = 0;  // TODO
 }
 
@@ -543,13 +555,31 @@ void NpuQwen3MoeDecoderLayerImpl::process_general_weights(
   const int index = get_mapped_index(name, weight_mapping);
   const bool is_sharded = shard_map.count(index);
   torch::Tensor tmp_tensor;
-
+  int32_t tp_rank = dp_local_tp_rank_;
+  int32_t tp_size = dp_local_tp_size_;
+
+  static const std::unordered_set<int> qkv_tensor_indices = {IN_QKV_WEIGHT_1,
+                                                             IN_QKV_WEIGHT_2,
+                                                             IN_QKV_BIAS_1,
+                                                             IN_QKV_BIAS_2,
+                                                             IN_QKV_DESCALE_1,
+                                                             IN_QKV_DESCALE_2,
+                                                             IN_QKV_OFFSET_1,
+                                                             IN_QKV_OFFSET_2,
+                                                             IN_QKV_SCALE_1,
+                                                             IN_QKV_SCALE_2};
+
+  if (qkv_tensor_indices.count(index) > 0) {
+    if (n_kv_heads_ < dp_local_tp_size_) {
+      int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_);
+
+      tp_rank = tp_rank / repeat_times;
+      tp_size = n_kv_heads_;
+    }
+  }
   if (is_sharded) {
-    tmp_tensor = get_sharded_tensor(state_dict,
-                                    name,
-                                    shard_map.at(index),
-                                    dp_local_tp_rank_,
-                                    dp_local_tp_size_)
+    tmp_tensor = get_sharded_tensor(
+        state_dict, name, shard_map.at(index), tp_rank, tp_size)
                      .to(device_);
   } else {
     tmp_tensor = tensor.to(device_);
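
The QKV shard remapping above pairs with the earlier clamp: when KV heads are replicated, ranks that share a replica must load the same weight shard, so the effective shard index becomes `rank / repeat_times` and only `n_kv_heads` distinct shards exist. A small standalone sketch of that mapping, assuming the same semantics; the function and its name are illustrative, not from the patch:

```cpp
#include <cstdio>
#include <utility>

// Illustrative remapping of (tp_rank, tp_size) for QKV-style tensors when
// n_kv_heads < tp_size, mirroring the logic in process_general_weights.
std::pair<int, int> qkv_shard(int tp_rank, int tp_size, int n_kv_heads) {
  if (n_kv_heads < tp_size) {
    int repeat_times = tp_size / n_kv_heads;   // ranks sharing one KV shard
    return {tp_rank / repeat_times, n_kv_heads};
  }
  return {tp_rank, tp_size};
}

int main() {
  // 4 KV heads, TP=8: ranks 0-1 load shard 0, ranks 2-3 load shard 1, etc.
  for (int rank = 0; rank < 8; ++rank) {
    auto [shard_rank, shard_size] =
        qkv_shard(rank, /*tp_size=*/8, /*n_kv_heads=*/4);
    std::printf("rank %d -> shard %d of %d\n", rank, shard_rank, shard_size);
  }
  return 0;
}
```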