Commit 3e35b5a

feat: enhance Qwen3-MoE to support TP settings beyond 4. (#434)
1 parent f024ad5 commit 3e35b5a

File tree

2 files changed: +40 -9 lines changed

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp

Lines changed: 39 additions & 9 deletions
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <gflags/gflags.h>
 
+#include <unordered_set>
+
 #include "common/global_flags.h"
 
 namespace xllm {
@@ -233,6 +235,8 @@ NpuQwen3MoeDecoderLayerImpl::NpuQwen3MoeDecoderLayerImpl(
   CHECK_EQ(parallel_args.world_size(), dp_size_ * dp_local_tp_size_);
   dp_local_tp_rank_ = parallel_args.rank() % dp_local_tp_size_;
 
+  n_kv_heads_ = static_cast<int32_t>(model_args.n_kv_heads().value());
+
   param_from_args(prefill_param_, model_args, parallel_args, true);
   param_from_args(decode_param_, model_args, parallel_args, false);
   initialize_tensors(options);
@@ -345,8 +349,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters(
   param.rmsnormQKNorm = true;
   param.hiddenSizePerAttentionHead = args.head_dim();
   std::optional<long int> optionalValue = args.n_kv_heads();
-  param.numKeyValueHeadsPerRank =
-      static_cast<int>(optionalValue.value()) / parallel_args.world_size();
+  param.numKeyValueHeadsPerRank = std::max(
+      1, static_cast<int>(optionalValue.value()) / parallel_args.world_size());
   param.numAttentionHeadsPerRank = args.n_heads() / dp_local_tp_size_;
 
   param.attnLinearTransposeType = {1, -1, -1, 1, -1, -1};
@@ -390,8 +394,16 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_mlp_parameters(
 void NpuQwen3MoeDecoderLayerImpl::initialize_parallel_parameters(
     atb_speed::qwen::MoeDecoderLayerParam& param,
     const ParallelArgs& parallel_args) {
-  param.lmHeadLocalTp = 0;
+  param.lmHeadLocalTp = dp_local_tp_size_;
   param.mapping = parallel_args.mapping();
+  param.tensorParallelInfo = {parallel_args.rank(),
+                              parallel_args.world_size(),
+                              FLAGS_communication_backend,
+                              FLAGS_rank_tablefile,
+                              nullptr,
+                              ""};
+
+  param.PrintParam();
   param.maxDecodeDpTokenSize = 0;  // TODO
 }

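The FLAGS_communication_backend and FLAGS_rank_tablefile values wired into tensorParallelInfo come from common/global_flags.h via gflags (hence the gflags include at the top of the file). A hedged sketch of what such definitions presumably look like; the flag names match the diff, but the types and defaults here are assumptions:

#include <gflags/gflags.h>

// Assumed string-valued flags; the real definitions live in xllm's
// common/global_flags.h and may differ in type and default.
DEFINE_string(communication_backend, "hccl",
              "Collective-communication backend for tensor parallelism.");
DEFINE_string(rank_tablefile, "",
              "Path to the rank table file describing the NPU cluster.");
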
@@ -543,13 +555,31 @@ void NpuQwen3MoeDecoderLayerImpl::process_general_weights(
   const int index = get_mapped_index(name, weight_mapping);
   const bool is_sharded = shard_map.count(index);
   torch::Tensor tmp_tensor;
-
+  int32_t tp_rank = dp_local_tp_rank_;
+  int32_t tp_size = dp_local_tp_size_;
+
+  static const std::unordered_set<int> qkv_tensor_indices = {IN_QKV_WEIGHT_1,
+                                                             IN_QKV_WEIGHT_2,
+                                                             IN_QKV_BIAS_1,
+                                                             IN_QKV_BIAS_2,
+                                                             IN_QKV_DESCALE_1,
+                                                             IN_QKV_DESCALE_2,
+                                                             IN_QKV_OFFSET_1,
+                                                             IN_QKV_OFFSET_2,
+                                                             IN_QKV_SCALE_1,
+                                                             IN_QKV_SCALE_2};
+
+  if (qkv_tensor_indices.count(index) > 0) {
+    if (n_kv_heads_ < dp_local_tp_size_) {
+      int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_);
+
+      tp_rank = tp_rank / repeat_times;
+      tp_size = n_kv_heads_;
+    }
+  }
   if (is_sharded) {
-    tmp_tensor = get_sharded_tensor(state_dict,
-                                    name,
-                                    shard_map.at(index),
-                                    dp_local_tp_rank_,
-                                    dp_local_tp_size_)
+    tmp_tensor = get_sharded_tensor(
+        state_dict, name, shard_map.at(index), tp_rank, tp_size)
                      .to(device_);
   } else {
     tmp_tensor = tensor.to(device_);

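The shard remapping above is the heart of the change: once the local TP degree exceeds the KV head count, QKV tensors cannot be split any further, so consecutive ranks reuse (replicate) the same KV shard, while all other weights keep the full dp_local_tp_size_ split. A self-contained sketch of the mapping, using the same integer arithmetic as the diff (effective_qkv_shard is a hypothetical helper name, not part of xllm):

#include <cstdint>
#include <cstdio>
#include <utility>

// Returns the (shard index, shard count) actually used when slicing a
// QKV tensor for a given rank, mirroring process_general_weights().
std::pair<int32_t, int32_t> effective_qkv_shard(int32_t tp_rank,
                                                int32_t tp_size,
                                                int32_t n_kv_heads) {
  if (n_kv_heads < tp_size) {
    // Each KV shard is shared by `repeat_times` consecutive ranks.
    const int32_t repeat_times = tp_size / n_kv_heads;
    return {tp_rank / repeat_times, n_kv_heads};
  }
  return {tp_rank, tp_size};
}

int main() {
  // Illustrative Qwen3-MoE setup: 4 KV heads, local TP degree 8.
  const int32_t n_kv_heads = 4;
  const int32_t tp_size = 8;
  for (int32_t rank = 0; rank < tp_size; ++rank) {
    const auto [shard, n_shards] =
        effective_qkv_shard(rank, tp_size, n_kv_heads);
    std::printf("rank %d -> KV shard %d of %d\n", rank, shard, n_shards);
  }
  return 0;
}

With 4 KV heads and TP = 8, ranks 0 and 1 load KV shard 0, ranks 2 and 3 load shard 1, and so on; together with the std::max(1, ...) clamp above, each rank then serves exactly one KV head.
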
xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h

Lines changed: 1 addition & 0 deletions
@@ -190,6 +190,7 @@ class NpuQwen3MoeDecoderLayerImpl : public NpuBaseLayer {
   int32_t start_expert_id_;
   int32_t end_expert_id_;
   int32_t ep_rank_;
+  int32_t n_kv_heads_;
 
   int32_t dp_size_;
   int32_t dp_local_tp_size_;
