@@ -2510,9 +2510,7 @@ def sample_tokens(
             max_gen_len = sampled_token_ids.shape[-1]
             if max_gen_len == 1:
                 # No spec decode tokens. It's a tensor.
-                valid_sampled_token_ids: list[np.ndarray] = [
-                    row for row in sampled_token_ids.cpu().numpy()
-                ]
+                valid_sampled_token_ids = sampled_token_ids.tolist()
             else:
                 # Includes spec decode tokens. It's a numpy array
                 valid_sampled_token_ids = self.rejection_sampler.parse_output(
@@ -2521,7 +2519,7 @@ def sample_tokens(
                 )
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[int(i)] = np.array([])
+                valid_sampled_token_ids[int(i)].clear()
         else:
             valid_sampled_token_ids = []
             invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
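
The first two hunks swap the list-of-ndarray representation of `valid_sampled_token_ids` for plain nested Python lists, so discarded requests are emptied in place with `.clear()` instead of being overwritten with an empty `np.array`. A minimal sketch of the two representations, using made-up tensor values rather than anything from the PR:

```python
import numpy as np
import torch

# Toy stand-in for sampled_token_ids: 3 requests, 1 sampled token each.
sampled_token_ids = torch.tensor([[11], [22], [33]])

# Old representation: a Python list of 1-D numpy rows.
old_rows = [row for row in sampled_token_ids.cpu().numpy()]
old_rows[1] = np.array([])        # discard request 1 by overwriting the row

# New representation: nested Python lists straight from the tensor.
new_rows = sampled_token_ids.tolist()    # [[11], [22], [33]]
new_rows[1].clear()                      # discard request 1 in place

assert old_rows[1].shape[0] == 0
assert new_rows == [[11], [], [33]]
```
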
@@ -2547,17 +2545,16 @@ def sample_tokens(
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         for req_idx in range(num_sampled_tokens):
-            sampled_ids: np.ndarray | None
             if self.use_async_scheduling:
-                sampled_ids = (np.array([-1]) if req_idx
-                               not in invalid_req_indices_set else None)
+                sampled_ids = [-1] * 1 if \
+                    req_idx not in invalid_req_indices_set else None
             else:
                 sampled_ids = valid_sampled_token_ids[req_idx]
-            if sampled_ids is None or sampled_ids.shape[0] == 0:
+            if not sampled_ids:
                 continue

             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
-            end_idx = start_idx + sampled_ids.shape[0]
+            end_idx = start_idx + len(sampled_ids)
             assert end_idx <= self.model_config.max_model_len, (
                 "Sampled token IDs exceed the max model length. "
                 f"Total number of tokens: {end_idx} > max_model_len: "
@@ -2571,7 +2568,7 @@ def sample_tokens(
             self.input_batch.num_tokens[req_idx] = end_idx
             req_id = self.input_batch.req_ids[req_idx]
             req_state = self.requests[req_id]
-            req_state.output_token_ids.extend(sampled_ids.tolist())
+            req_state.output_token_ids.extend(sampled_ids)

         def propose_draft_token_ids(sampled_token_ids):
             assert self.spec_decode_common_attn_metadata is not None
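
With `valid_sampled_token_ids` holding plain lists, the bookkeeping loop above can rely on a single truthiness check and `len()` instead of ndarray-specific tests, and `extend()` no longer needs an intermediate `.tolist()`. A minimal sketch with made-up values (not taken from the PR):

```python
valid_sampled_token_ids = [[101, 102], [], None]  # hypothetical per-request rows
output_token_ids = [[], [], []]

for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
    if not sampled_ids:               # covers both None and an emptied row
        continue
    end_idx = len(output_token_ids[req_idx]) + len(sampled_ids)
    output_token_ids[req_idx].extend(sampled_ids)   # plain lists, no .tolist()
    assert end_idx == len(output_token_ids[req_idx])

assert output_token_ids == [[101, 102], [], []]
```
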
@@ -2935,14 +2932,12 @@ def _dummy_run(
             assert len(num_scheduled_tokens_list) == num_reqs
             num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                             dtype=np.int32)
-            num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

         if not self.in_profile_run and self.dynamic_eplb:
             self.eplb_updator.forward_before()

         with self.maybe_dummy_run_with_lora(self.lora_config,
-                                            num_scheduled_tokens,
-                                            num_sampled_tokens):
+                                            num_scheduled_tokens):
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
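
The `_dummy_run` hunk drops the all-ones `num_sampled_tokens` array, so `maybe_dummy_run_with_lora` is now entered with only the LoRA config and the scheduled-token counts. A toy stand-in (the real helper lives in the model runner; this only illustrates the narrowed call shape):

```python
from contextlib import contextmanager

import numpy as np


@contextmanager
def maybe_dummy_run_with_lora(lora_config, num_scheduled_tokens):
    # Stand-in only: set up dummy LoRA state from the scheduled-token
    # counts, run the body, then tear the state down again.
    try:
        yield
    finally:
        pass  # cleanup would go here


num_scheduled_tokens = np.array([1, 1, 1], dtype=np.int32)
with maybe_dummy_run_with_lora(None, num_scheduled_tokens):
    pass  # dummy forward pass would run here
```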