Skip to content

Commit 561983f — feat: implement chunked prefill and prefix cache for Qwen3 MoE. (#435)
1 parent: eaf82da

File tree: 3 files changed (+41 lines added, -12 lines removed)

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,20 @@ if(USE_NPU)
2828
if(DEVICE_TYPE STREQUAL "USE_A3")
2929
message("downloading a3 arm xllm kernels")
3030
file(DOWNLOAD
31-
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a3.arm.rpm"
31+
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a3.arm.rpm"
3232
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
3333
)
3434
else()
3535
if(DEVICE_ARCH STREQUAL "ARM")
3636
message("downloading a2 arm xllm_kernels")
3737
file(DOWNLOAD
38-
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.arm.rpm"
38+
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.arm.rpm"
3939
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
4040
)
4141
else()
4242
message("downloading a2 x86 xllm_kernels")
4343
file(DOWNLOAD
44-
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.x86.rpm"
44+
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.x86.rpm"
4545
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
4646
)
4747
endif()

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters(
321321

322322
param.mlpLinearTransposeType = {-1, -1, -1, -1};
323323

324+
param.enableSplitFuse =
325+
(FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
324326
if (quantize_type_.empty()) {
325327
param.moeLinearTransposeType = std::vector<int>{1, 1, -1, 1};
326328
} else {
@@ -403,7 +405,6 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_parallel_parameters(
403405
nullptr,
404406
""};
405407

406-
param.PrintParam();
407408
param.maxDecodeDpTokenSize = 0; // TODO
408409
}
409410

@@ -895,7 +896,9 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
895896
std::atomic<bool>* event_flag,
896897
int node_id) {
897898
atb::Status st;
898-
if (input_params.global_empty_kv_cache) {
899+
bool is_prefill = input_params.decode_seq_range.second !=
900+
input_params.q_seq_lens.size(0) - 1;
901+
if (is_prefill) {
899902
build_node_variant_pack(prefill_node_,
900903
x,
901904
cos_pos,
@@ -998,6 +1001,14 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack(
9981001
atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots);
9991002
}
10001003

1004+
if (is_prefill &&
1005+
(FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) {
1006+
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16) =
1007+
atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens);
1008+
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16).hostData =
1009+
const_cast<int32_t*>(input_params.q_seq_lens_vec.data());
1010+
}
1011+
10011012
for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
10021013
CHECK_THROW(node.inTensors.at(i) == nullptr,
10031014
model_name_ << " inTensor " << i << " is NULL");

xllm/models/llm/qwen3_moe.h

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
156156
model_args.rope_theta(),
157157
options);
158158

159-
max_seq_len_ = model_args.max_position_embeddings();
160-
int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
159+
int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1;
161160
attn_mask_ = layer::AttentionMask(options.device(),
162161
options.dtype().toScalarType(),
163162
/*mask_value=*/mask_value);
@@ -258,11 +257,30 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
258257
}
259258

260259
torch::Tensor attn_mask;
261-
if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
262-
attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
263-
} else {
264-
attn_mask = attn_mask_.gen_free_mask(
265-
num_speculative_tokens_ + 1, dtype_, device_);
260+
max_seq_len_ = FLAGS_enable_chunked_prefill
261+
? std::max(input_params.kv_max_seq_len, max_seq_len_)
262+
: 128;
263+
if (FLAGS_enable_chunked_prefill) {
264+
attn_mask = attn_mask_.get_attn_mask(
265+
max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device());
266+
267+
int batch_size = input_params.q_seq_lens_vec.size();
268+
if (batch_size > 0) {
269+
std::vector<torch::Tensor> req_mask_vec;
270+
req_mask_vec.reserve(batch_size);
271+
272+
for (int j = 0; j < batch_size; j++) {
273+
int start =
274+
input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j];
275+
int end = input_params.kv_seq_lens_vec[j];
276+
277+
auto req_mask_slice = attn_mask.slice(0, start, end);
278+
req_mask_vec.emplace_back(req_mask_slice);
279+
}
280+
attn_mask = torch::cat(req_mask_vec, 0);
281+
}
282+
} else if (input_params.global_empty_kv_cache) {
283+
attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_);
266284
}
267285
auto deep_stacks = input_params.deep_stacks;
268286
int deep_stack_size = deep_stacks.size();

0 commit comments

Comments (0)