Commit d7556d5

fix sample_tokens
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent ec14697 commit d7556d5

File tree

1 file changed: +8 -13 lines changed


vllm_ascend/worker/model_runner_v1.py

Lines changed: 8 additions & 13 deletions
@@ -2510,9 +2510,7 @@ def sample_tokens(
             max_gen_len = sampled_token_ids.shape[-1]
             if max_gen_len == 1:
                 # No spec decode tokens. It's a tensor.
-                valid_sampled_token_ids: list[np.ndarray] = [
-                    row for row in sampled_token_ids.cpu().numpy()
-                ]
+                valid_sampled_token_ids = sampled_token_ids.tolist()
             else:
                 # Includes spec decode tokens. It's a numpy array
                 valid_sampled_token_ids = self.rejection_sampler.parse_output(
@@ -2521,7 +2519,7 @@ def sample_tokens(
                 )
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[int(i)] = np.array([])
+                valid_sampled_token_ids[int(i)].clear()
         else:
             valid_sampled_token_ids = []
             invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
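
Note on the two hunks above: a standalone sketch (not the vllm-ascend code; the tensor values and the discard index are made up for illustration) of why torch.Tensor.tolist() plus an in-place clear() replaces the per-row numpy conversion and the np.array([]) assignment. The result is a nested list of plain Python ints that the downstream bookkeeping can test and extend directly.

    import torch

    # One sampled token per request (the max_gen_len == 1 case).
    sampled_token_ids = torch.tensor([[101], [102], [103]])

    # Old approach: a list of 1-D numpy arrays, one per request.
    old_rows = [row for row in sampled_token_ids.cpu().numpy()]

    # New approach: a nested list of plain Python ints.
    valid_sampled_token_ids = sampled_token_ids.tolist()  # [[101], [102], [103]]

    # Discarded requests are emptied in place instead of being replaced with
    # np.array([]); an empty list is falsy, so a later "if not sampled_ids"
    # skips them without any shape check.
    discard_sampled_tokens_req_indices = [1]
    for i in discard_sampled_tokens_req_indices:
        valid_sampled_token_ids[int(i)].clear()

    print(valid_sampled_token_ids)  # [[101], [], [103]]
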
@@ -2547,17 +2545,16 @@ def sample_tokens(
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         for req_idx in range(num_sampled_tokens):
-            sampled_ids: np.ndarray | None
             if self.use_async_scheduling:
-                sampled_ids = (np.array([-1]) if req_idx
-                               not in invalid_req_indices_set else None)
+                sampled_ids = [-1] * 1 if \
+                    req_idx not in invalid_req_indices_set else None
             else:
                 sampled_ids = valid_sampled_token_ids[req_idx]
-            if sampled_ids is None or sampled_ids.shape[0] == 0:
+            if not sampled_ids:
                 continue

             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
-            end_idx = start_idx + sampled_ids.shape[0]
+            end_idx = start_idx + len(sampled_ids)
             assert end_idx <= self.model_config.max_model_len, (
                 "Sampled token IDs exceed the max model length. "
                 f"Total number of tokens: {end_idx} > max_model_len: "
@@ -2571,7 +2568,7 @@ def sample_tokens(
             self.input_batch.num_tokens[req_idx] = end_idx
             req_id = self.input_batch.req_ids[req_idx]
             req_state = self.requests[req_id]
-            req_state.output_token_ids.extend(sampled_ids.tolist())
+            req_state.output_token_ids.extend(sampled_ids)

         def propose_draft_token_ids(sampled_token_ids):
             assert self.spec_decode_common_attn_metadata is not None
@@ -2935,14 +2932,12 @@ def _dummy_run(
         assert len(num_scheduled_tokens_list) == num_reqs
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
-        num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

         if not self.in_profile_run and self.dynamic_eplb:
             self.eplb_updator.forward_before()

         with self.maybe_dummy_run_with_lora(self.lora_config,
-                                            num_scheduled_tokens,
-                                            num_sampled_tokens):
+                                            num_scheduled_tokens):
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
