
Commit 6cf592d

WoosukKwon authored and charlotte12l committed
[Model Runner V2] Refactor prefill token preparation (vllm-project#29712)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
1 parent 7a537cf commit 6cf592d

File tree (5 files changed, +83 −78 lines):
  vllm/v1/worker/gpu/cudagraph_utils.py
  vllm/v1/worker/gpu/input_batch.py
  vllm/v1/worker/gpu/model_runner.py
  vllm/v1/worker/gpu/spec_decode/eagle.py
  vllm/v1/worker/gpu/states.py

vllm/v1/worker/gpu/cudagraph_utils.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def capture_graph(
         kv_cache_config: KVCacheConfig,
     ) -> None:
         num_reqs = min(num_tokens, self.max_num_reqs)
-        input_ids = input_buffers.input_ids.gpu[:num_tokens]
+        input_ids = input_buffers.input_ids[:num_tokens]
         positions = input_buffers.positions[:num_tokens]
         attn_metadata = prepare_inputs_to_capture(
             num_reqs,
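
Note: the dropped ".gpu" accessor reflects the change below in input_batch.py, where input_ids becomes a plain CUDA tensor instead of a paired host/device staging buffer. A minimal sketch of the difference, assuming a _make_buffer-style helper that pairs a pinned NumPy view with a device tensor (HostDeviceBuffer is an illustrative stand-in, not vLLM's actual class):

import torch

class HostDeviceBuffer:
    # Illustrative stand-in for what _make_buffer() returns: a pinned CPU
    # array plus a device tensor that must be synced with an explicit copy.
    def __init__(self, n: int, dtype=torch.int32, device="cuda"):
        self.cpu = torch.zeros(n, dtype=dtype, pin_memory=True)
        self.np = self.cpu.numpy()
        self.gpu = torch.zeros(n, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Explicit host-to-device copy on every step.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]

# After this commit, input_ids is just a device tensor that the Triton kernel
# in input_batch.py writes directly, so slicing replaces ".gpu[...]" and the
# per-step copy_to_gpu() call disappears.
input_ids = torch.zeros(8192, dtype=torch.int32, device="cuda")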

vllm/v1/worker/gpu/input_batch.py

Lines changed: 52 additions & 34 deletions
@@ -3,7 +3,6 @@
 from dataclasses import dataclass
 from typing import Any
 
-import numba
 import numpy as np
 import torch

@@ -30,15 +29,12 @@ def __init__(
         self.pin_memory = pin_memory
 
         self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
-        self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
+        self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
         self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
         self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
         self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
         self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
 
-        # Spec decoding.
-        self.next_prefill_tokens = self._make_buffer(max_num_reqs, dtype=torch.int32)
-
         # Structured outputs.
         self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
         self.grammar_bitmask = self._make_buffer(
@@ -120,7 +116,7 @@ def make_dummy(
     input_buffers.seq_lens[num_reqs:] = 0
     seq_lens = input_buffers.seq_lens[:num_reqs]
 
-    input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
+    input_ids = input_buffers.input_ids[:num_tokens]
     positions = input_buffers.positions[:num_tokens]
     # attn_metadata = defaultdict(lambda: None)
     logits_indices = query_start_loc[1:] - 1
@@ -146,41 +142,63 @@ def make_dummy(
     )
 
 
-@numba.njit(cache=True)
-def _prepare_prefill_inputs(
-    idx_mapping: np.ndarray,  # [B]
-    query_lens: np.ndarray,  # [B]
-    query_start_loc: np.ndarray,  # [B + 1]
-    prefill_token_ids: np.ndarray,  # [N, max_model_len]
-    num_computed_prefill_tokens: np.ndarray,  # [N]
-    input_ids: np.ndarray,  # [num_input_tokens]
-) -> None:
-    num_reqs = idx_mapping.shape[0]
-    query_starts = query_start_loc[:num_reqs]
-    query_ends = query_start_loc[1 : num_reqs + 1]
-    starts = num_computed_prefill_tokens[idx_mapping]
-    ends = starts + query_lens
-    for i in range(num_reqs):
-        input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[
-            idx_mapping[i], starts[i] : ends[i]
-        ]
+@triton.jit
+def _prepare_prefill_inputs_kernel(
+    input_ids_ptr,
+    next_prefill_tokens_ptr,
+    idx_mapping_ptr,
+    query_start_loc_ptr,
+    prefill_token_ids_ptr,
+    prefill_token_ids_stride,
+    prefill_lens_ptr,
+    num_computed_tokens_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
+    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
+    if num_computed >= prefill_len:
+        # Not prefill.
+        return
+
+    query_start = tl.load(query_start_loc_ptr + batch_idx)
+    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
+    query_len = query_end - query_start
+
+    prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
+    for i in range(0, query_len, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        mask = block < query_len
+        tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
+        tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
+
+    next_pos = num_computed + query_len
+    if next_pos < prefill_len:
+        next_token = tl.load(prefill_ptr + next_pos)
+        tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
 
 
 def prepare_prefill_inputs(
-    idx_mapping: np.ndarray,
-    num_scheduled_tokens: np.ndarray,
-    query_start_loc: np.ndarray,
-    prefill_token_ids: np.ndarray,
-    num_computed_prefill_tokens: np.ndarray,
-    input_ids: np.ndarray,
+    input_ids: torch.Tensor,
+    next_prefill_tokens: torch.Tensor,
+    idx_mapping: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    prefill_token_ids: torch.Tensor,
+    prefill_len: torch.Tensor,
+    num_computed_tokens: torch.Tensor,
 ) -> None:
-    _prepare_prefill_inputs(
+    num_reqs = idx_mapping.shape[0]
+    _prepare_prefill_inputs_kernel[(num_reqs,)](
+        input_ids,
+        next_prefill_tokens,
         idx_mapping,
-        num_scheduled_tokens,
         query_start_loc,
         prefill_token_ids,
-        num_computed_prefill_tokens,
-        input_ids,
+        prefill_token_ids.stride(0),
+        prefill_len,
+        num_computed_tokens,
+        BLOCK_SIZE=1024,
     )
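
For readers skimming the new kernel: each Triton program handles one batch slot, copies the scheduled slice of that request's prompt into the flat input_ids buffer, and records the first not-yet-scheduled prompt token for speculative decoding. A rough loop-based PyTorch sketch of the same semantics (illustrative only; argument names mirror the kernel, which vectorizes the inner copy in BLOCK_SIZE chunks on the GPU):

import torch

def prepare_prefill_inputs_reference(
    input_ids: torch.Tensor,            # [num_input_tokens]
    next_prefill_tokens: torch.Tensor,  # [max_num_reqs]
    idx_mapping: torch.Tensor,          # [num_reqs], batch slot -> request state slot
    query_start_loc: torch.Tensor,      # [num_reqs + 1]
    prefill_token_ids: torch.Tensor,    # [max_num_reqs, max_model_len]
    prefill_len: torch.Tensor,          # [max_num_reqs]
    num_computed_tokens: torch.Tensor,  # [max_num_reqs]
) -> None:
    num_reqs = idx_mapping.shape[0]
    for batch_idx in range(num_reqs):
        req = int(idx_mapping[batch_idx])
        computed = int(num_computed_tokens[req])
        if computed >= int(prefill_len[req]):
            continue  # decode-only request: nothing to copy
        start = int(query_start_loc[batch_idx])
        end = int(query_start_loc[batch_idx + 1])
        qlen = end - start
        # Copy the scheduled slice of prompt tokens into the flat input buffer.
        input_ids[start:end] = prefill_token_ids[req, computed : computed + qlen]
        # Remember the first not-yet-scheduled prompt token for spec decode.
        nxt = computed + qlen
        if nxt < int(prefill_len[req]):
            next_prefill_tokens[req] = prefill_token_ids[req, nxt]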

vllm/v1/worker/gpu/model_runner.py

Lines changed: 17 additions & 29 deletions
@@ -104,11 +104,9 @@ def __init__(
         if self.use_async_scheduling:
             self.input_prep_event = torch.cuda.Event()
             self.structured_outputs_event = torch.cuda.Event()
-            self.spec_decode_event = torch.cuda.Event()
         else:
             self.input_prep_event = None
             self.structured_outputs_event = None
-            self.spec_decode_event = None
 
         if self.speculative_config is not None:
             self.do_spec_decode = True
@@ -412,9 +410,6 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
                 cu_num_new_blocks[i].append(x + len(block_ids))
                 new_block_ids[i].extend(block_ids)
                 overwrite.append(True)
-        # Update the GPU tensors for request states.
-        if scheduler_output.scheduled_new_reqs:
-            self.req_states.prefill_len.copy_to_gpu()
 
         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -507,16 +502,16 @@ def prepare_inputs(
         query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
         query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
 
-        # Copy prefill tokens from CPU to GPU.
+        # Get prefill tokens.
         prepare_prefill_inputs(
-            idx_mapping_np,
-            num_scheduled_tokens,
-            query_start_loc_np,
-            self.req_states.prefill_token_ids.np,
-            self.req_states.num_computed_prefill_tokens,
-            self.input_buffers.input_ids.np,
+            self.input_buffers.input_ids,
+            self.req_states.next_prefill_tokens,
+            idx_mapping,
+            query_start_loc_gpu,
+            self.req_states.prefill_token_ids.gpu,
+            self.req_states.prefill_len.gpu,
+            self.req_states.num_computed_tokens,
         )
-        self.input_buffers.input_ids.copy_to_gpu(num_tokens)
 
         # Prepare positions and seq_lens.
         prepare_pos_seq_lens(
@@ -531,7 +526,7 @@ def prepare_inputs(
         # Some input token ids are directly read from the last sampled tokens
        # and draft tokens. Also, get the logits indices to sample tokens from.
         logits_indices = combine_sampled_and_draft_tokens(
-            self.input_buffers.input_ids.gpu,
+            self.input_buffers.input_ids,
             idx_mapping,
             self.req_states.last_sampled_tokens,
             query_start_loc_gpu,
@@ -572,7 +567,7 @@ def prepare_inputs(
             kv_cache_config=self.kv_cache_config,
         )
 
-        input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
+        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
         positions = self.input_buffers.positions[:num_tokens_after_padding]
         return InputBatch(
             req_ids=req_ids,
@@ -782,28 +777,21 @@ def propose_draft(
         num_sampled: torch.Tensor,
         num_rejected: torch.Tensor,
     ) -> torch.Tensor:
-        num_reqs = input_batch.num_reqs
-        idx_mapping_np = input_batch.idx_mapping_np
-        with async_barrier(self.spec_decode_event):
-            self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
-                self.req_states.prefill_token_ids.np[
-                    idx_mapping_np,
-                    self.req_states.num_computed_prefill_tokens[idx_mapping_np],
-                ]
-            )
-            next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
-                num_reqs
-            )
-
         assert self.speculator is not None
+        last_sampled_tokens = self.req_states.last_sampled_tokens[
+            input_batch.idx_mapping
+        ]
+        next_prefill_tokens = self.req_states.next_prefill_tokens[
+            input_batch.idx_mapping
+        ]
         draft_tokens = self.speculator.propose(
             input_batch,
             sampling_metadata,
             last_hidden_states,
             aux_hidden_states,
             num_sampled,
             num_rejected,
-            self.req_states.last_sampled_tokens,
+            last_sampled_tokens,
             next_prefill_tokens,
         )
         return draft_tokens
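
The deleted block staged next_prefill_tokens through a pinned CPU buffer and an explicit copy_to_gpu() under spec_decode_event; now that the prefill kernel writes req_states.next_prefill_tokens on the device, propose_draft only needs a batch-order gather. A minimal sketch of that gather with placeholder sizes and dtypes (the real tensors live on req_states and input_batch):

import torch

max_num_reqs = 64
# Persistent per-request state, indexed by request-state slot.
last_sampled_tokens = torch.zeros(max_num_reqs, dtype=torch.int64, device="cuda")
next_prefill_tokens = torch.zeros(max_num_reqs, dtype=torch.int32, device="cuda")

# idx_mapping maps batch position -> request-state slot for this step.
idx_mapping = torch.tensor([7, 2, 41], dtype=torch.int64, device="cuda")

# One advanced-indexing gather per tensor replaces the old CPU staging path,
# producing batch-ordered [num_reqs] tensors for the speculator.
last_sampled = last_sampled_tokens[idx_mapping]
next_prefill = next_prefill_tokens[idx_mapping]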

vllm/v1/worker/gpu/spec_decode/eagle.py

Lines changed: 8 additions & 11 deletions
@@ -121,7 +121,7 @@ def run_model(
             num_tokens_across_dp=num_tokens_across_dp,
         ):
             ret_hidden_states = self.model(
-                input_ids=self.input_buffers.input_ids.gpu[:num_tokens],
+                input_ids=self.input_buffers.input_ids[:num_tokens],
                 positions=self.input_buffers.positions[:num_tokens],
                 hidden_states=self.hidden_states[:num_tokens],
             )
@@ -194,7 +194,7 @@ def propose(
         num_sampled: torch.Tensor,
         # [num_reqs]
         num_rejected: torch.Tensor,
-        # [max_num_reqs, 1]
+        # [num_reqs]
         last_sampled: torch.Tensor,
         # [num_reqs]
         next_prefill_tokens: torch.Tensor,
@@ -316,7 +316,6 @@ def _prepare_eagle_inputs_kernel(
     eagle_positions_ptr,
     target_input_ids_ptr,
     target_positions_ptr,
-    idx_mapping_ptr,
     last_sampled_ptr,
     next_prefill_tokens_ptr,
     num_sampled_ptr,
@@ -335,8 +334,7 @@ def _prepare_eagle_inputs_kernel(
 
     num_sampled = tl.load(num_sampled_ptr + batch_idx)
     if num_sampled > 0:
-        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
-        next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
+        next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
     else:
         # Chunked prefilling.
         # Get the next prefill token.
@@ -368,9 +366,9 @@ def prepare_eagle_inputs(
     num_sampled: torch.Tensor,
     # [num_reqs]
     num_rejected: torch.Tensor,
-    # [max_num_reqs, 1]
+    # [num_reqs]
     last_sampled: torch.Tensor,
-    # [max_num_reqs]
+    # [num_reqs]
     next_prefill_tokens: torch.Tensor,
 ) -> torch.Tensor:
     num_reqs = input_batch.num_reqs
@@ -381,11 +379,10 @@ def prepare_eagle_inputs(
     )
     _prepare_eagle_inputs_kernel[(num_reqs,)](
         last_token_indices,
-        input_buffers.input_ids.gpu,
+        input_buffers.input_ids,
         input_buffers.positions,
         input_batch.input_ids,
         input_batch.positions,
-        input_batch.idx_mapping,
         last_sampled,
         next_prefill_tokens,
         num_sampled,
@@ -485,7 +482,7 @@ def prepare_eagle_decode(
         last_token_indices,
         target_seq_lens,
         num_rejected,
-        input_buffers.input_ids.gpu,
+        input_buffers.input_ids,
         input_buffers.positions,
         input_hidden_states,
         input_hidden_states.stride(0),
@@ -553,7 +550,7 @@ def update_eagle_inputs(
 ):
     num_reqs, hidden_size = output_hidden_states.shape
     _update_eagle_inputs_kernel[(num_reqs,)](
-        input_buffers.input_ids.gpu,
+        input_buffers.input_ids,
         input_buffers.positions,
         hidden_states,
         hidden_states.stride(0),
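
Because the runner now hands prepare_eagle_inputs last_sampled and next_prefill_tokens already gathered into batch order, the kernel indexes both with batch_idx and idx_mapping_ptr can be dropped. A scalar Python sketch of the per-request branch shown above (illustrative; the real logic runs inside the Triton kernel):

import torch

def next_eagle_input_token(
    batch_idx: int,
    num_sampled: torch.Tensor,          # [num_reqs]
    last_sampled: torch.Tensor,         # [num_reqs], gathered by idx_mapping upstream
    next_prefill_tokens: torch.Tensor,  # [num_reqs], gathered by idx_mapping upstream
) -> int:
    if int(num_sampled[batch_idx]) > 0:
        # Decode or verified speculation: feed the last sampled token.
        return int(last_sampled[batch_idx])
    # Chunked prefill: feed the next not-yet-computed prompt token.
    return int(next_prefill_tokens[batch_idx])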

vllm/v1/worker/gpu/states.py

Lines changed: 5 additions & 3 deletions
@@ -117,8 +117,7 @@ def __init__(
         self.prefill_token_ids = UvaBuffer(
             self.max_num_reqs, self.max_model_len, dtype=torch.int32
         )
-        self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
-
+        self.prefill_len = UvaBuffer(self.max_num_reqs, dtype=torch.int32)
         # Number of computed tokens.
         self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
         self.num_computed_tokens = torch.zeros(
@@ -140,6 +139,9 @@ def __init__(
             dtype=torch.int64,
             device=device,
         )
+        self.next_prefill_tokens = torch.zeros(
+            self.max_num_reqs, dtype=torch.int32, device=device
+        )
 
         # LoRA.
         self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
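
prefill_len joins prefill_token_ids as a UvaBuffer so the prefill kernel can read both directly from the GPU while the scheduler keeps writing them on the CPU, and next_prefill_tokens is a plain device tensor because only GPU code touches it. A small sketch of the assumed access pattern (the zero-copy host/device mapping is handled inside vLLM's UvaBuffer; the explicit copies below merely stand in for that mapping in this sketch):

import numpy as np
import torch

max_num_reqs, max_model_len = 4, 16

# Host-side views (.np): written when a request is added or resumed.
prefill_token_ids_np = np.zeros((max_num_reqs, max_model_len), dtype=np.int32)
prefill_len_np = np.zeros(max_num_reqs, dtype=np.int32)

prompt = [101, 7592, 2088, 102]
prefill_token_ids_np[0, : len(prompt)] = prompt
prefill_len_np[0] = len(prompt)

# Device-side views (.gpu): read by _prepare_prefill_inputs_kernel each step.
prefill_token_ids_gpu = torch.from_numpy(prefill_token_ids_np).cuda()
prefill_len_gpu = torch.from_numpy(prefill_len_np).cuda()

# Written by the prefill kernel and read by EAGLE, entirely on the GPU.
next_prefill_tokens = torch.zeros(max_num_reqs, dtype=torch.int32, device="cuda")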
@@ -380,13 +382,13 @@ def _expand_sampling_metadata_kernel(
     expanded_top_p_ptr,
     top_k_ptr,
     expanded_top_k_ptr,
-    seeds_ptr,
     rep_penalty_ptr,
     expanded_rep_penalty_ptr,
     freq_penalty_ptr,
     expanded_freq_penalty_ptr,
     pres_penalty_ptr,
     expanded_pres_penalty_ptr,
+    seeds_ptr,
     expanded_seeds_ptr,
     cu_num_logits_ptr,
     BLOCK_SIZE: tl.constexpr,

0 commit comments
