
Commit 6b34e20

feat: support deepseek mtp on mlu.
1 parent: fbf4515

File tree

10 files changed: +471 −36 lines changed

xllm/core/framework/model/model_args.h

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,8 @@ struct ModelArgs {
   PROPERTY(int32_t, v_head_dim) = 0;
   PROPERTY(int32_t, q_lora_rank) = 0;
   PROPERTY(int32_t, kv_lora_rank) = 0;
+  // deepseek v3/v3.2 MTP
+  PROPERTY(int32_t, num_nextn_predict_layers) = 0;
 
   // deepseek v3.2 indexer
   PROPERTY(int32_t, index_head_dim) = 0;

xllm/core/framework/model/model_input_params.h

Lines changed: 11 additions & 0 deletions
@@ -161,6 +161,17 @@ struct ModelInputParams {
     LOG(INFO) << "ModelInputParams: dp_global_token_nums is "
               << dp_global_token_nums;
   }
+
+  int32_t get_q_seq_len(int32_t seq_idx) const {
+#if defined(USE_NPU)
+    CHECK(seq_idx < q_seq_lens_vec.size()) << "seq_idx out of range";
+    return q_seq_lens_vec[seq_idx];
+#else
+    CHECK(seq_idx < q_seq_lens_vec.size() - 1) << "seq_idx out of range";
+    return q_seq_lens_vec[seq_idx + 1] - q_seq_lens_vec[seq_idx];
+#endif
+  }
+
   // whether the kv-cache is empty for all sequences.
   bool empty_kv_cache = true;
   BatchForwardType batch_forward_type;
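
The two layouts that get_q_seq_len bridges can be illustrated with a standalone sketch (made-up lengths, not part of the commit): on NPU the vector stores per-sequence lengths directly, elsewhere it stores prefix sums, and the accessor returns the same value from either.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Direct format (NPU): element i is the query length of sequence i.
  std::vector<int32_t> direct = {3, 2};
  // Cumulative format (GPU/MLU): prefix sums, with one extra leading element.
  std::vector<int32_t> cumulative = {0, 3, 5};
  // get_q_seq_len recovers the same per-sequence length from either layout.
  assert(direct[0] == cumulative[1] - cumulative[0]);  // 3
  assert(direct[1] == cumulative[2] - cumulative[1]);  // 2
  return 0;
}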

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 137 additions & 16 deletions
@@ -58,6 +58,101 @@ int32_t kv_cache_slot_id(int32_t position,
   return block_id * block_size + block_offset;
 }
 
+// Convert tensor to int64 for the MLU platform (temporary workaround).
+// MLU will support int32 for masked_scatter in the future.
+torch::Tensor ensure_int64_for_certain_platform(torch::Tensor tensor) {
+#if defined(USE_MLU)
+  return tensor.to(torch::kInt64);
+#else
+  return tensor;
+#endif
+}
+
+// Push a cumulative sum onto the vector (used for the cumulative format).
+void push_cumsum(std::vector<int32_t>& vec, int32_t len) {
+  if (vec.empty()) {
+    vec.emplace_back(0);
+  }
+  vec.emplace_back(vec.back() + len);
+}
+
+// Batch expansion strategy for validation: process the validation sequence
+// lengths for each token (used in prepare_validate_inputs).
+// For NPU without ATB: append direct values for each token.
+// For MLU: append cumulative values for each token.
+void batch_expansion_process_seq_lens(
+    std::vector<int32_t>& kv_seq_lens_vec,
+    std::vector<int32_t>& q_seq_lens_vec,
+    std::vector<std::vector<int32_t>>& block_tables_vec,
+    const Slice<int32_t>& kv_seq_lens_slice,
+    const Slice<int32_t>& block_table_slice,
+    int32_t seq_id,
+    int32_t position_offset,
+    int32_t num_val_tokens) {
+  for (int32_t offset = position_offset;
+       offset < num_val_tokens + position_offset;
+       ++offset) {
+#if defined(USE_MLU)
+    // process kv length and q length in the cumulative-lengths style;
+    // we use the batch expansion strategy for validation, so q_len is always 1
+    int32_t kv_len =
+        kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + offset;
+    int32_t q_len = 1;
+    push_cumsum(kv_seq_lens_vec, kv_len);
+    push_cumsum(q_seq_lens_vec, q_len);
+#else
+    // For NPU without ATB: direct format
+    q_seq_lens_vec.emplace_back(1);
+    kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
+#endif
+    block_tables_vec.emplace_back(block_table_slice);
+  }
+}
+
+// Update kv_seq_lens_vec based on the platform type.
+// For NPU: directly append kv_seq_lens_slice[seq_id] + offset.
+// For others: build the cumulative format.
+void update_kv_seq_lens_vec(std::vector<int32_t>& kv_seq_lens_vec,
+                            const Slice<int32_t>& kv_seq_lens_slice,
+                            int32_t seq_id,
+                            int32_t offset) {
+#if defined(USE_NPU)
+  kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
+#else
+  // build the cumulative format for kv_seq_lens
+  int32_t offset_kv_len =
+      kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + offset;
+  push_cumsum(kv_seq_lens_vec, offset_kv_len);
+#endif
+}
+
+// For GPU and MLU, kv_seq_lens_vec uses the cumulative format, so the maximum
+// sequence length is the largest difference between consecutive elements.
+// For NPU, kv_seq_lens_vec is in direct format (actual lengths), so we simply
+// return the maximum value.
+int32_t get_kv_max_seq_len(std::vector<int32_t>& kv_seq_lens_vec) {
+#if defined(USE_NPU)
+  // NPU: kv_seq_lens_vec is in direct format, return the maximum value
+  // directly.
+  return *std::max_element(kv_seq_lens_vec.begin(), kv_seq_lens_vec.end());
+#else
+  // GPU/MLU: kv_seq_lens_vec is in cumulative format; the maximum sequence
+  // length is the maximum difference between consecutive
+  // elements.
+  if (kv_seq_lens_vec.size() < 2) {
+    return 0;
+  }
+  int32_t max_seq_len = 0;
+  for (size_t i = 1; i < kv_seq_lens_vec.size(); ++i) {
+    int32_t len = kv_seq_lens_vec[i] - kv_seq_lens_vec[i - 1];
+    if (len > max_seq_len) {
+      max_seq_len = len;
+    }
+  }
+  return max_seq_len;
+#endif
+}
+
 }  // namespace
 
 SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
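
A minimal standalone sketch of the cumulative format used by these helpers (made-up lengths; it mirrors the non-NPU branch of get_kv_max_seq_len):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

void push_cumsum(std::vector<int32_t>& vec, int32_t len) {
  if (vec.empty()) {
    vec.emplace_back(0);
  }
  vec.emplace_back(vec.back() + len);
}

int main() {
  std::vector<int32_t> kv;
  push_cumsum(kv, 7);  // kv = {0, 7}
  push_cumsum(kv, 4);  // kv = {0, 7, 11}
  push_cumsum(kv, 9);  // kv = {0, 7, 11, 20}
  // The largest gap between consecutive prefix sums recovers the longest
  // sequence length, which is what get_kv_max_seq_len computes on GPU/MLU.
  int32_t max_len = 0;
  for (size_t i = 1; i < kv.size(); ++i) {
    max_len = std::max(max_len, kv[i] - kv[i - 1]);
  }
  assert(max_len == 9);
  return 0;
}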
@@ -68,6 +163,11 @@ SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
   runtime_options.enable_schedule_overlap(false);
   impl_ =
       std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
+  // Here we set num_speculative_tokens to 0 so the worker can tell that it is
+  // running the draft model when enable_speculative_decode is on.
+  //
+  // NOTE: If you modify this part, make sure you also check the usage of
+  // num_speculative_tokens in the draft model.
   runtime_options.num_decoding_tokens(1).num_speculative_tokens(0);
   draft_impl_ =
       std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
@@ -194,13 +294,15 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
 
   // prepare input for draft model
   auto& embeddings = output.sample_output.embeddings;
-  auto next_tokens = safe_to(output.sample_output.next_tokens, torch::kInt);
+  auto next_tokens = ensure_int64_for_certain_platform(
+      safe_to(output.sample_output.next_tokens, torch::kInt));
 
   if (embeddings.defined()) {
     prefill_input.input_params.input_embedding = embeddings.clone();
   }
   if (next_tokens.defined()) {
     auto& token_ids = prefill_input.token_ids;
+    token_ids = ensure_int64_for_certain_platform(token_ids);
     auto mask = (token_ids == -1);
     token_ids.masked_scatter_(mask, next_tokens);
   }
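
The placeholder-filling pattern above can be reproduced in isolation with plain libtorch (made-up token values; everything in int64, which is why the MLU workaround casts first):

#include <torch/torch.h>

int main() {
  // -1 marks the slots that the freshly sampled tokens should fill.
  auto token_ids = torch::tensor({5, -1, 7, -1}, torch::kInt64);
  auto next_tokens = torch::tensor({11, 13}, torch::kInt64);
  auto mask = (token_ids == -1);
  token_ids.masked_scatter_(mask, next_tokens);
  // token_ids is now {5, 11, 7, 13}.
  return 0;
}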
@@ -257,7 +359,7 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
   new_token_ids.reserve(input.token_ids.numel());
   for (size_t i = 0; i < input_params.num_sequences; ++i) {
     int32_t q_len = 0;
-    q_len = input_params.q_seq_lens_vec[i];
+    q_len = input_params.get_q_seq_len(i);
     Slice<int32_t> tokens_ids_slice_i =
         tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
     start_idx += q_len;
@@ -314,9 +416,10 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
 
   for (int i = 0; i < options_.num_speculative_tokens(); ++i) {
     ForwardOutput draft_output = draft_outputs[i];
-    auto next_tokens =
-        safe_to(draft_output.sample_output.next_tokens, torch::kInt);
+    auto next_tokens = ensure_int64_for_certain_platform(
+        safe_to(draft_output.sample_output.next_tokens, torch::kInt));
     auto& token_ids = validate_input.token_ids;
+    token_ids = ensure_int64_for_certain_platform(token_ids);
     auto mask = (token_ids == -1 * (i + 1));
     token_ids.masked_scatter_(mask, next_tokens);
   }
@@ -381,7 +484,7 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(const ForwardInput& input,
 
   for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
     new_positions.emplace_back(positions_slice[seq_id] + offset);
-    kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
+    update_kv_seq_lens_vec(kv_seq_lens_vec, kv_seq_lens_slice, seq_id, offset);
     torch::Tensor block_table = block_tables[seq_id];
     Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
                                         block_table.numel()};
@@ -451,17 +554,21 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
 
     // process kv length and q length
     if (FLAGS_enable_atb_spec_kernel) {
+      // expand the number of decode tokens for each sequence in the batch
+      // for validation
       kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] +
                                    num_speculative_tokens + position_offset);
       q_seq_lens_vec.emplace_back(num_val_tokens);
     } else {
-      for (int32_t offset = position_offset;
-           offset < num_val_tokens + position_offset;
-           ++offset) {
-        q_seq_lens_vec.emplace_back(1);
-        kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
-        block_tables_vec.emplace_back(block_table_slice);
-      }
+      // expand the batch size for validation
+      batch_expansion_process_seq_lens(kv_seq_lens_vec,
+                                       q_seq_lens_vec,
+                                       block_tables_vec,
+                                       kv_seq_lens_slice,
+                                       block_table_slice,
+                                       seq_id,
+                                       position_offset,
+                                       num_val_tokens);
     }
 
     // process slot id
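
To make the expansion concrete, here is a hypothetical single-sequence walkthrough of the MLU branch (kv prefix sums {0, 10}, three validation tokens, position_offset 0; push_cumsum is copied locally for self-containment):

#include <cstdint>
#include <vector>

void push_cumsum(std::vector<int32_t>& vec, int32_t len) {
  if (vec.empty()) {
    vec.emplace_back(0);
  }
  vec.emplace_back(vec.back() + len);
}

int main() {
  std::vector<int32_t> kv_in = {0, 10};  // one sequence of kv length 10
  std::vector<int32_t> kv_out, q_out;
  for (int32_t offset = 0; offset < 3; ++offset) {
    push_cumsum(kv_out, kv_in[1] - kv_in[0] + offset);  // 10, 11, 12
    push_cumsum(q_out, 1);  // batch expansion: q_len is always 1
  }
  // kv_out = {0, 10, 21, 33}, q_out = {0, 1, 2, 3}: each validation token
  // becomes its own row of query length 1 in the expanded batch.
  return 0;
}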
@@ -571,6 +678,7 @@ SampleOutput SpeculativeWorkerImpl::validate(
   size_t num_draft_tokens = num_target_tokens - batch_size;
   COUNTER_ADD(speculative_num_draft_tokens_total, num_draft_tokens);
   COUNTER_ADD(speculative_num_accepted_tokens_total, num_draft_tokens - count);
+
   return sample_output;
 }
 
@@ -589,11 +697,14 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
   torch::Tensor positions = safe_to(inputs.positions, torch::kCPU);
   Slice<int32_t> positions_slice = {positions.data_ptr<int32_t>(),
                                     positions.numel()};
+  // Get the tokens generated in the last step (flattened for easier indexing)
   torch::Tensor last_token_ids = safe_to(
       last_step_output_.sample_output.next_tokens.flatten(), torch::kCPU);
   Slice<int64_t> last_tokens_ids_slice = {last_token_ids.data_ptr<int64_t>(),
                                           last_token_ids.numel()};
 
+  // Determine how many tokens were decoded in the last step.
+  // If the output is 2D, multiple tokens were generated per sequence.
   int32_t last_step_decode_num = 1;
   if (last_step_output_.sample_output.next_tokens.dim() == 2) {
     last_step_decode_num = last_step_output_.sample_output.next_tokens.size(1);
@@ -611,13 +722,20 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
   kv_seq_lens_vec.reserve(num_sequences);
   new_token_slot_ids.reserve(num_sequences);
 
-  // get right token id and position
+  // Process each sequence to get the correct token ID and position for the
+  // next step.
   for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
     int32_t postion_offset = 0;
     int32_t last_step_token_id = 0;
+
+    // If the token ID is non-negative, it's a direct token ID (not a
+    // placeholder)
     if (tokens_ids_slice[seq_id] >= 0) {
       last_step_token_id = tokens_ids_slice[seq_id];
     } else {
+      // Negative token IDs are placeholders that need to be resolved from
+      // last_step_output_. The absolute value minus 1 gives the index into
+      // the last step's output.
       int32_t last_step_index = -1 * tokens_ids_slice[seq_id] - 1;
       last_step_index = last_step_index * last_step_decode_num;
       postion_offset = -1;
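
The placeholder arithmetic in the else branch, shown in isolation with hypothetical values: a token ID of -k refers to entry k - 1 of the last step's output, scaled by how many tokens each sequence decoded:

#include <cassert>
#include <cstdint>

int main() {
  int32_t placeholder = -2;          // placeholder token ID from the batch
  int32_t last_step_decode_num = 2;  // tokens decoded per sequence last step
  int32_t last_step_index = (-1 * placeholder - 1) * last_step_decode_num;
  // Index 2 is the first token of the second sequence in the flattened
  // last-step output.
  assert(last_step_index == 2);
  return 0;
}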
@@ -632,8 +750,11 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
 
     new_token_ids.emplace_back(last_step_token_id);
     new_positions.emplace_back(positions_slice[seq_id] + postion_offset);
-    kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + postion_offset);
+    update_kv_seq_lens_vec(
+        kv_seq_lens_vec, kv_seq_lens_slice, seq_id, postion_offset);
 
+    // Calculate the new cache slot ID based on the position offset.
+    // This handles cases where we need to move to a different block.
     torch::Tensor block_table = block_tables[seq_id];
     Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
                                         block_table.numel()};
@@ -642,12 +763,12 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
     new_token_slot_ids.emplace_back(slot_id);
   }
 
+  // Create new tensors with the updated values
   torch::TensorOptions int_options = inputs.token_ids.options();
   new_inputs.token_ids = torch::tensor(new_token_ids, int_options);
   new_inputs.positions = torch::tensor(new_positions, int_options);
   // update the input_params
-  input_params.kv_max_seq_len =
-      *std::max_element(kv_seq_lens_vec.begin(), kv_seq_lens_vec.end());
+  input_params.kv_max_seq_len = get_kv_max_seq_len(kv_seq_lens_vec);
   input_params.kv_seq_lens_vec = std::move(kv_seq_lens_vec);
   input_params.kv_seq_lens =
       torch::tensor(input_params.kv_seq_lens_vec, int_options);

xllm/core/runtime/worker_impl.cpp

Lines changed: 17 additions & 0 deletions
@@ -600,9 +600,26 @@ bool WorkerImpl::init_model(const std::string& model_weights_path) {
     }
   }
 
+#if defined(USE_NPU)
   if (options_.enable_speculative_decode() && FLAGS_enable_atb_spec_kernel) {
     args.num_speculative_tokens(options_.num_speculative_tokens());
   }
+#else
+  if (options_.enable_speculative_decode()) {
+    args.num_speculative_tokens(options_.num_speculative_tokens());
+    // When running speculative decoding, the draft worker reuses the same
+    // checkpoint as the target DeepSeek V3/V3.2 model. The draft worker needs
+    // to instantiate the MTP variant, so override the model_type here without
+    // mutating the original config.
+    if (options_.num_speculative_tokens() == 0 &&
+        (args.model_type() == "deepseek_v3" ||
+         args.model_type() == "deepseek_v32")) {
+      LOG(INFO) << "Overriding draft model_type from " << args.model_type()
+                << " to deepseek_mtp for speculative decoding";
+      args.model_type("deepseek_mtp");
+    }
+  }
+#endif
 
   // create model context
   dtype_ = dtype;
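
The override rule reduces to a small pure function. This is a sketch, not the committed API; it assumes that num_speculative_tokens == 0 identifies the draft worker, as arranged in SpeculativeWorkerImpl's constructor:

#include <cassert>
#include <cstdint>
#include <string>

std::string resolve_model_type(const std::string& model_type,
                               int32_t num_speculative_tokens) {
  // The draft worker shares the DeepSeek V3/V3.2 checkpoint but must build
  // the MTP variant of the model.
  if (num_speculative_tokens == 0 &&
      (model_type == "deepseek_v3" || model_type == "deepseek_v32")) {
    return "deepseek_mtp";
  }
  return model_type;
}

int main() {
  assert(resolve_model_type("deepseek_v3", 0) == "deepseek_mtp");  // draft
  assert(resolve_model_type("deepseek_v3", 2) == "deepseek_v3");   // target
  return 0;
}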

xllm/models/llm/llm_model_base.h

Lines changed: 3 additions & 2 deletions
@@ -422,8 +422,9 @@ class LlmForCausalLMImplBase : public torch::nn::Module {
 #endif
   }
 
-  void load_model(std::unique_ptr<ModelLoader> loader,
-                  std::string prefix = "model." /*llm model weight prefix*/) {
+  virtual void load_model(
+      std::unique_ptr<ModelLoader> loader,
+      std::string prefix = "model." /*llm model weight prefix*/) {
     for (const auto& state_dict : loader->get_state_dicts()) {
       model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
       if (tie_word_embeddings) {
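
Making load_model virtual lets an MTP draft model customize weight loading. A toy illustration of the dispatch this enables (simplified stand-in types, not xllm's):

#include <iostream>
#include <memory>
#include <string>

struct ModelLoaderStub {};  // stand-in for xllm's ModelLoader

struct BaseLM {
  virtual ~BaseLM() = default;
  virtual void load_model(std::unique_ptr<ModelLoaderStub> loader,
                          std::string prefix = "model.") {
    std::cout << "base load, prefix=" << prefix << "\n";
  }
};

struct MtpDraftLM : BaseLM {
  void load_model(std::unique_ptr<ModelLoaderStub> loader,
                  std::string prefix = "model.") override {
    // A draft-model subclass can remap prefixes or load extra layers here.
    std::cout << "draft load, prefix=" << prefix << "\n";
  }
};

int main() {
  std::unique_ptr<BaseLM> lm = std::make_unique<MtpDraftLM>();
  lm->load_model(std::make_unique<ModelLoaderStub>());  // draft load
  return 0;
}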
