@@ -58,6 +58,101 @@ int32_t kv_cache_slot_id(int32_t position,
5858 return block_id * block_size + block_offset;
5959}
6060
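`kv_cache_slot_id` maps a token position to a flat slot in the paged KV cache: each sequence owns a list of fixed-size blocks, and a position splits into a block lookup plus an in-block offset. A minimal standalone sketch of that arithmetic, assuming (hypothetically) that the block id comes from the sequence's block table; the function's real parameter list is elided by the hunk header:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch of the slot arithmetic used by kv_cache_slot_id.
int32_t slot_for_position(const std::vector<int32_t>& block_table,
                          int32_t position,
                          int32_t block_size) {
  int32_t block_id = block_table[position / block_size];  // physical block
  int32_t block_offset = position % block_size;           // offset inside it
  return block_id * block_size + block_offset;
}
// e.g. block_size = 16, block_table = {7, 3}: position 20 falls in logical
// block 1 -> physical block 3, offset 4, so slot = 3 * 16 + 4 = 52.
```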
61+ // Convert tensor to int64 for the MLU platform (temporary workaround).
62+ // MLU will support int32 for masked_scatter in the future.
63+ torch::Tensor ensure_int64_for_certain_platform(torch::Tensor tensor) {
64+ #if defined(USE_MLU)
65+ return tensor.to(torch::kInt64);
66+ #else
67+ return tensor;
68+ #endif
69+ }
70+
71+ // Push a cumulative sum onto the vector (used for the cumulative format).
72+ void push_cumsum(std::vector<int32_t>& vec, int32_t len) {
73+ if (vec.empty()) {
74+ vec.emplace_back(0);
75+ }
76+ vec.emplace_back(vec.back() + len);
77+ }
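For illustration, this is how `push_cumsum` turns per-sequence lengths {3, 5, 2} into the cumulative (prefix-sum) form consumed by the GPU/MLU kernels:

```cpp
std::vector<int32_t> cu_lens;
push_cumsum(cu_lens, 3);  // cu_lens = {0, 3}
push_cumsum(cu_lens, 5);  // cu_lens = {0, 3, 8}
push_cumsum(cu_lens, 2);  // cu_lens = {0, 3, 8, 10}
// Sequence i spans [cu_lens[i], cu_lens[i + 1]); its length is the difference.
```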
78+
79+ // Batch expansion strategy for validation: process validation sequence
80+ // lengths for each token (used in prepare_validate_inputs).
81+ // For NPU without ATB: append direct values for each token.
82+ // For MLU: append cumulative values for each token.
83+ void batch_expansion_process_seq_lens(
84+ std::vector<int32_t>& kv_seq_lens_vec,
85+ std::vector<int32_t>& q_seq_lens_vec,
86+ std::vector<std::vector<int32_t>>& block_tables_vec,
87+ const Slice<int32_t>& kv_seq_lens_slice,
88+ const Slice<int32_t>& block_table_slice,
89+ int32_t seq_id,
90+ int32_t position_offset,
91+ int32_t num_val_tokens) {
92+ for (int32_t offset = position_offset;
93+ offset < num_val_tokens + position_offset;
94+ ++offset) {
95+ #if defined(USE_MLU)
96+ // process kv length and q length in the cumulative-lengths style;
97+ // we use the batch expansion strategy for validation, so q_len is always 1
98+ int32_t kv_len =
99+ kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + offset;
100+ int32_t q_len = 1;
101+ push_cumsum(kv_seq_lens_vec, kv_len);
102+ push_cumsum(q_seq_lens_vec, q_len);
103+ #else
104+ // For NPU without ATB: direct format
105+ q_seq_lens_vec.emplace_back(1);
106+ kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
107+ #endif
108+ block_tables_vec.emplace_back(block_table_slice);
109+ }
110+ }
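A worked trace of the MLU branch, with illustrative values and assuming `kv_seq_lens_slice` is itself in cumulative form:

```cpp
// kv_seq_lens_slice = {0, 7} (one sequence, kv length 7),
// seq_id = 0, position_offset = 1, num_val_tokens = 3:
//   offset = 1: kv_len = 7 - 0 + 1 = 8  -> kv_seq_lens_vec = {0, 8}
//   offset = 2: kv_len = 9              -> kv_seq_lens_vec = {0, 8, 17}
//   offset = 3: kv_len = 10             -> kv_seq_lens_vec = {0, 8, 17, 27}
//   q_seq_lens_vec ends as {0, 1, 2, 3} (q_len is always 1 under expansion)
// block_tables_vec receives one copy of the block table per expanded token.
```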
111+
112+ // Update kv_seq_lens_vec based on platform type.
113+ // For NPU: directly append kv_seq_lens_slice[seq_id] + offset.
114+ // For others: build the cumulative format.
115+ void update_kv_seq_lens_vec(std::vector<int32_t>& kv_seq_lens_vec,
116+ const Slice<int32_t>& kv_seq_lens_slice,
117+ int32_t seq_id,
118+ int32_t offset) {
119+ #if defined(USE_NPU)
120+ kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
121+ #else
122+ // build cumulative format for kv_seq_lens
123+ int32_t offset_kv_len =
124+ kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + offset;
125+ push_cumsum(kv_seq_lens_vec, offset_kv_len);
126+ #endif
127+ }
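To make the platform split concrete, here is the same sequence (kv length 7, offset 1) under the two formats, values illustrative:

```cpp
// NPU (direct):         kv_seq_lens_slice = {7}
//   -> kv_seq_lens_vec appends 7 + 1 = 8
// GPU/MLU (cumulative): kv_seq_lens_slice = {0, 7}
//   -> offset_kv_len = 7 - 0 + 1 = 8, push_cumsum yields {0, 8}
```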
128+
129+ // For GPU and MLU, kv_seq_lens_vec uses the cumulative format (prefix
130+ // sums), so the maximum sequence length is the largest difference between
131+ // consecutive elements. For NPU, kv_seq_lens_vec is in direct format (actual
132+ // lengths), so we simply return the maximum value.
133+ int32_t get_kv_max_seq_len(const std::vector<int32_t>& kv_seq_lens_vec) {
134+ #if defined(USE_NPU)
135+ // NPU: kv_seq_lens_vec is in direct format, return the maximum value
136+ // directly.
137+ return *std::max_element(kv_seq_lens_vec.begin(), kv_seq_lens_vec.end());
138+ #else
139+ // GPU/MLU: kv_seq_lens_vec is in cumulative format.
140+ // The maximum sequence length is the maximum difference between consecutive
141+ // elements.
142+ if (kv_seq_lens_vec.size() < 2) {
143+ return 0;
144+ }
145+ int32_t max_seq_len = 0;
146+ for (size_t i = 1; i < kv_seq_lens_vec.size(); ++i) {
147+ int32_t len = kv_seq_lens_vec[i] - kv_seq_lens_vec[i - 1];
148+ if (len > max_seq_len) {
149+ max_seq_len = len;
150+ }
151+ }
152+ return max_seq_len;
153+ #endif
154+ }
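Example of both formats for sequences of lengths 3, 5, and 2:

```cpp
// GPU/MLU build: kv_seq_lens_vec = {0, 3, 8, 10} (cumulative)
//   consecutive differences 3, 5, 2 -> get_kv_max_seq_len returns 5
// NPU build:     kv_seq_lens_vec = {3, 5, 2}     (direct)
//   *std::max_element(...)         -> returns 5
```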
155+
61156} // namespace
62157
63158SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
@@ -68,6 +163,11 @@ SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
68163 runtime_options.enable_schedule_overlap(false);
69164 impl_ =
70165 std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
166+ // Here we set num_speculative_tokens to 0 to signal to the worker that
167+ // this is the draft model when enable_speculative_decode is on.
168+ // NOTE: If you want to modify this part, make sure you also check the
169+ // usage of num_speculative_tokens in the draft model.
71171 runtime_options.num_decoding_tokens(1).num_speculative_tokens(0);
72172 draft_impl_ =
73173 std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
@@ -194,13 +294,15 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
194294
195295 // prepare input for draft model
196296 auto& embeddings = output.sample_output.embeddings;
197- auto next_tokens = safe_to(output.sample_output.next_tokens, torch::kInt);
297+ auto next_tokens = ensure_int64_for_certain_platform(
298+ safe_to(output.sample_output.next_tokens, torch::kInt));
198299
199300 if (embeddings.defined()) {
200301 prefill_input.input_params.input_embedding = embeddings.clone();
201302 }
202303 if (next_tokens.defined()) {
203304 auto& token_ids = prefill_input.token_ids;
305+ token_ids = ensure_int64_for_certain_platform(token_ids);
204306 auto mask = (token_ids == -1);
205307 token_ids.masked_scatter_(mask, next_tokens);
206308 }
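The `-1` entries in `token_ids` are placeholders for the token the target model sampled this step; `masked_scatter_` fills them in order of appearance. A self-contained sketch of the mechanism with illustrative values:

```cpp
#include <torch/torch.h>

torch::Tensor token_ids = torch::tensor({-1, 5, -1}, torch::kInt64);
torch::Tensor next_tokens = torch::tensor({9, 11}, torch::kInt64);
auto mask = (token_ids == -1);
token_ids.masked_scatter_(mask, next_tokens);
// token_ids is now {9, 5, 11}: masked positions are filled from
// next_tokens in order of appearance.
```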
@@ -257,7 +359,7 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
257359 new_token_ids.reserve(input.token_ids.numel());
258360 for (size_t i = 0; i < input_params.num_sequences; ++i) {
259361 int32_t q_len = 0;
260- q_len = input_params.q_seq_lens_vec[i];
362+ q_len = input_params.get_q_seq_len(i);
261363 Slice<int32_t> tokens_ids_slice_i =
262364 tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
263365 start_idx += q_len;
@@ -314,9 +416,10 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
314416
315417 for (int i = 0; i < options_.num_speculative_tokens(); ++i) {
316418 ForwardOutput draft_output = draft_outputs[i];
317- auto next_tokens =
318- safe_to(draft_output.sample_output.next_tokens, torch::kInt);
419+ auto next_tokens = ensure_int64_for_certain_platform(
420+ safe_to(draft_output.sample_output.next_tokens, torch::kInt));
319421 auto& token_ids = validate_input.token_ids;
422+ token_ids = ensure_int64_for_certain_platform(token_ids);
320423 auto mask = (token_ids == -1 * (i + 1));
321424 token_ids.masked_scatter_(mask, next_tokens);
322425 }
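Here each draft step `i` owns its own placeholder value `-(i + 1)`, so successive draft tokens land in distinct slots of the validation input:

```cpp
// validate_input.token_ids per sequence (illustrative layout):
//   {last_target_token, -1, -2, -3}
// draft step 0 scatters its sampled token into the -1 slot,
// step 1 into the -2 slot, step 2 into the -3 slot.
```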
@@ -381,7 +484,7 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(const ForwardInput& input,
381484
382485 for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
383486 new_positions.emplace_back(positions_slice[seq_id] + offset);
384- kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
487+ update_kv_seq_lens_vec(kv_seq_lens_vec, kv_seq_lens_slice, seq_id, offset);
385488 torch::Tensor block_table = block_tables[seq_id];
386489 Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
387490 block_table.numel()};
@@ -451,17 +554,21 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
451554
452555 // process kv length and q length
453556 if (FLAGS_enable_atb_spec_kernel) {
557+ // expand the number of decode tokens for each sequence in the batch
558+ // for validation
454559 kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] +
455560 num_speculative_tokens + position_offset);
456561 q_seq_lens_vec.emplace_back(num_val_tokens);
457562 } else {
458- for (int32_t offset = position_offset;
459- offset < num_val_tokens + position_offset;
460- ++offset) {
461- q_seq_lens_vec.emplace_back(1);
462- kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
463- block_tables_vec.emplace_back(block_table_slice);
464- }
563+ // expand the batch size for validation
564+ batch_expansion_process_seq_lens(kv_seq_lens_vec,
565+ q_seq_lens_vec,
566+ block_tables_vec,
567+ kv_seq_lens_slice,
568+ block_table_slice,
569+ seq_id,
570+ position_offset,
571+ num_val_tokens);
465572 }
466573
467574 // process slot id
@@ -571,6 +678,7 @@ SampleOutput SpeculativeWorkerImpl::validate(
571678 size_t num_draft_tokens = num_target_tokens - batch_size;
572679 COUNTER_ADD(speculative_num_draft_tokens_total, num_draft_tokens);
573680 COUNTER_ADD(speculative_num_accepted_tokens_total, num_draft_tokens - count);
681+
574682 return sample_output;
575683}
576684
@@ -589,11 +697,14 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
589697 torch::Tensor positions = safe_to(inputs.positions, torch::kCPU);
590698 Slice<int32_t> positions_slice = {positions.data_ptr<int32_t>(),
591699 positions.numel()};
700+ // Get the tokens generated in the last step (flattened for easier indexing)
592701 torch::Tensor last_token_ids = safe_to(
593702 last_step_output_.sample_output.next_tokens.flatten(), torch::kCPU);
594703 Slice<int64_t> last_tokens_ids_slice = {last_token_ids.data_ptr<int64_t>(),
595704 last_token_ids.numel()};
596705
706+ // Determine how many tokens were decoded in the last step
707+ // If the output is 2D, it means multiple tokens were generated per sequence
597708 int32_t last_step_decode_num = 1;
598709 if (last_step_output_.sample_output.next_tokens.dim() == 2) {
599710 last_step_decode_num = last_step_output_.sample_output.next_tokens.size(1);
@@ -611,13 +722,20 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
611722 kv_seq_lens_vec.reserve(num_sequences);
612723 new_token_slot_ids.reserve(num_sequences);
613724
614- // get right token id and position
725+ // Process each sequence to get the correct token ID and position for the next
726+ // step
615727 for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
616728 int32_t postion_offset = 0;
617729 int32_t last_step_token_id = 0;
730+
731+ // If the token ID is non-negative, it is a direct token ID (not a
732+ // placeholder).
618733 if (tokens_ids_slice[seq_id] >= 0) {
619734 last_step_token_id = tokens_ids_slice[seq_id];
620735 } else {
736+ // Negative token IDs are placeholders that need to be resolved from
737+ // last_step_output_. The absolute value minus 1 gives the index into
738+ // the last step's output.
621739 int32_t last_step_index = -1 * tokens_ids_slice[seq_id] - 1;
622740 last_step_index = last_step_index * last_step_decode_num;
623741 postion_offset = -1;
@@ -632,8 +750,11 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
632750
633751 new_token_ids.emplace_back(last_step_token_id);
634752 new_positions.emplace_back(positions_slice[seq_id] + postion_offset);
635- kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + postion_offset);
753+ update_kv_seq_lens_vec(
754+ kv_seq_lens_vec, kv_seq_lens_slice, seq_id, postion_offset);
636755
756+ // Calculate the new cache slot ID based on the position offset
757+ // This handles cases where we need to move to a different block
637758 torch::Tensor block_table = block_tables[seq_id];
638759 Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
639760 block_table.numel()};
@@ -642,12 +763,12 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
642763 new_token_slot_ids.emplace_back(slot_id);
643764 }
644765
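A worked example of the placeholder resolution above, with illustrative values:

```cpp
// Suppose tokens_ids_slice[seq_id] == -3 and the last step produced two
// tokens per sequence (last_step_decode_num == 2, next_tokens flattened):
//   last_step_index = -1 * (-3) - 1 = 2   // placeholder -> output row
//   last_step_index = 2 * 2 = 4           // row -> flattened offset
//   postion_offset  = -1                  // position rolls back one step
// The token is then read from the last step's flattened output at offset 4.
```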
766+ // Create new tensors with updated values
645767 torch::TensorOptions int_options = inputs.token_ids.options();
646768 new_inputs.token_ids = torch::tensor(new_token_ids, int_options);
647769 new_inputs.positions = torch::tensor(new_positions, int_options);
648770 // update the input_params
649- input_params.kv_max_seq_len =
650- *std::max_element(kv_seq_lens_vec.begin(), kv_seq_lens_vec.end());
771+ input_params.kv_max_seq_len = get_kv_max_seq_len(kv_seq_lens_vec);
651772 input_params.kv_seq_lens_vec = std::move(kv_seq_lens_vec);
652773 input_params.kv_seq_lens =
653774 torch::tensor(input_params.kv_seq_lens_vec, int_options);