Commit bee4fd2

njhill authored and devpatelio committed
[BugFix] Fix chunked prompt logprobs + preemption (vllm-project#29071)
1 parent 4ab8c9b commit bee4fd2

6 files changed: +127 −31 lines changed

tests/conftest.py

Lines changed: 23 additions & 4 deletions
@@ -853,6 +853,7 @@ def generate(
     @staticmethod
     def _final_steps_generate_w_logprobs(
         req_outputs: list[RequestOutput],
+        include_prompt_token_ids: bool = False,
     ) -> list[TokensTextLogprobsPromptLogprobs]:
         outputs: list[TokensTextLogprobsPromptLogprobs] = []
         for req_output in req_outputs:
@@ -861,9 +862,26 @@ def _final_steps_generate_w_logprobs(
             output_str = sample.text
             output_ids = list(sample.token_ids)
             output_logprobs = sample.logprobs
-            outputs.append(
-                (output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
-            )
+            if include_prompt_token_ids:
+                outputs.append(
+                    (  # type: ignore[arg-type]
+                        output_ids,
+                        output_str,
+                        output_logprobs,
+                        req_output.prompt_token_ids,
+                        req_output.prompt_logprobs,
+                    )
+                )
+            else:
+                outputs.append(
+                    (
+                        output_ids,
+                        output_str,
+                        output_logprobs,
+                        req_output.prompt_logprobs,
+                    )
+                )
+
         return outputs
 
     def generate_w_logprobs(
@@ -873,6 +891,7 @@ def generate_w_logprobs(
         images: PromptImageInput | None = None,
         audios: PromptAudioInput | None = None,
         videos: PromptVideoInput | None = None,
+        include_prompt_token_ids: bool = False,
         **kwargs: Any,
     ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
         inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
@@ -882,7 +901,7 @@ def generate_w_logprobs(
         )
 
         toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
-            req_outputs
+            req_outputs, include_prompt_token_ids
         )
         # Omit prompt logprobs if not required by sampling params
         return (
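
With include_prompt_token_ids=True, the helper now returns 5-tuples that also carry the prompt token IDs, which lets callers check one prompt-logprob entry per prompt position. A minimal usage sketch (the import path, prompt, and parameter values below are illustrative assumptions, not part of the commit):

# Hedged sketch: consuming the 5-tuple returned when
# include_prompt_token_ids=True (import path and values are assumptions).
from vllm import SamplingParams
from tests.conftest import VllmRunner

params = SamplingParams(temperature=0.0, max_tokens=8, prompt_logprobs=2)

with VllmRunner("Qwen/Qwen3-0.6B", max_model_len=512) as runner:
    results = runner.generate_w_logprobs(
        ["In one word, the capital of France is "],
        sampling_params=params,
        include_prompt_token_ids=True,
    )
    for output_ids, output_str, output_logprobs, prompt_ids, prompt_logprobs in results:
        # One prompt-logprob entry per prompt token (the first is typically None).
        assert len(prompt_logprobs) == len(prompt_ids)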

tests/v1/sample/test_logprobs.py

Lines changed: 76 additions & 0 deletions
@@ -605,3 +605,79 @@ def test_spec_decode_logprobs(
         )
         assert ref_logprob.rank == spec_logprob.rank
         assert ref_logprob.decoded_token == spec_logprob.decoded_token
+
+
+def test_prompt_logprobs_with_chunking_and_preemption():
+    """Test that prompt logprobs are correctly returned when using
+    both chunked prefill and preemption.
+
+    This test ensures that the num_prompt_logprobs tracking persists
+    across preemptions and prefill chunks.
+    """
+
+    # Create prompts that will trigger chunking and preemption
+    prompts = [
+        "The following numbers of the sequence "
+        + ", ".join(str(i) for i in range(10))
+        + " are:",
+        "In one word, the capital of France is ",
+    ] + [f"Tell me about the number {i}: " for i in range(32)]
+
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=40,
+        min_tokens=20,
+        prompt_logprobs=2,  # Request prompt logprobs
+    )
+
+    with VllmRunner(
+        "Qwen/Qwen3-0.6B",
+        max_model_len=512,
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=48,  # Force prefill chunking
+        num_gpu_blocks_override=32,  # Force preemptions
+        disable_log_stats=False,
+        gpu_memory_utilization=0.25,
+    ) as vllm_model:
+        metrics_before = vllm_model.llm.get_metrics()
+
+        # Generate with prompt logprobs using generate_w_logprobs which
+        # returns (output_ids, output_str, output_logprobs, prompt_logprobs)
+        outputs = vllm_model.generate_w_logprobs(
+            prompts, sampling_params=sampling_params, include_prompt_token_ids=True
+        )
+
+        # Verify that all outputs have prompt logprobs
+        for i, output in enumerate(outputs):
+            _, _, _, prompt_token_ids, prompt_logprobs = output
+            assert prompt_logprobs is not None and len(prompt_logprobs) > 0, (
+                f"Output {i} missing prompt logprobs"
+            )
+            assert len(prompt_logprobs) == len(prompt_token_ids), (
+                "Unexpected number of prompt logprob positions"
+            )
+
+            # Each position should have the requested number of logprobs
+            for pos, logprobs_dict in enumerate(prompt_logprobs):
+                if logprobs_dict is not None:  # First token may be None
+                    assert (
+                        sampling_params.prompt_logprobs
+                        <= len(logprobs_dict)
+                        <= sampling_params.prompt_logprobs + 1
+                    ), (
+                        f"Output {i} position {pos} has {len(logprobs_dict)} "
+                        f"logprobs, expected {sampling_params.prompt_logprobs}"
+                    )
+
+        # Check that we actually had preemptions
+        metrics_after = vllm_model.llm.get_metrics()
+        preemptions_before = next(
+            (m.value for m in metrics_before if m.name == "vllm:num_preemptions"), 0
+        )
+        preemptions_after = next(
+            (m.value for m in metrics_after if m.name == "vllm:num_preemptions"), 0
+        )
+        preemptions = preemptions_after - preemptions_before
+        assert preemptions > 0, "Test did not trigger any preemptions"
+
+        print(f"Test passed with {preemptions} preemptions")

vllm/v1/worker/gpu_input_batch.py

Lines changed: 0 additions & 14 deletions
@@ -219,9 +219,6 @@ def __init__(
         self.generators: dict[int, torch.Generator] = {}
 
         self.num_logprobs: dict[str, int] = {}
-        # NOTE(rob): num_prompt_logprobs only includes reqs
-        # that are currently in the prefill phase.
-        self.num_prompt_logprobs: dict[str, int] = {}
 
         # To accumulate prompt logprobs tensor chunks across prefill steps.
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -385,12 +382,6 @@ def add_request(
                 if sampling_params.logprobs == -1
                 else sampling_params.logprobs
             )
-            if sampling_params.prompt_logprobs is not None:
-                self.num_prompt_logprobs[req_id] = (
-                    self.vocab_size
-                    if sampling_params.prompt_logprobs == -1
-                    else sampling_params.prompt_logprobs
-                )
 
             if sampling_params.allowed_token_ids:
                 self.has_allowed_token_ids.add(req_id)
@@ -488,7 +479,6 @@ def remove_request(self, req_id: str) -> int | None:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
         self.has_allowed_token_ids.discard(req_id)
@@ -972,10 +962,6 @@ def no_penalties(self) -> bool:
     def max_num_logprobs(self) -> int | None:
         return max(self.num_logprobs.values()) if self.num_logprobs else None
 
-    @property
-    def no_prompt_logprob(self) -> bool:
-        return not self.num_prompt_logprobs
-
     @property
     def no_allowed_token_ids(self) -> bool:
         return len(self.has_allowed_token_ids) == 0
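
These deletions move prompt-logprob bookkeeping out of the persistent batch: remove_request runs on preemption as well as completion, so a count stored here is lost in exactly the case this commit fixes. The counterpart additions in the model runners key the same information on request ID for the whole request lifetime. A simplified sketch of the ownership split (class and method names are illustrative stand-ins, not vLLM's):

# Hedged sketch of the ownership split, using simplified stand-in classes
# (PersistentBatch and Runner are illustrative names, not vLLM classes).

class PersistentBatch:
    """Rebuilt whenever a request is preempted or resumed; holds per-step state."""

    def __init__(self) -> None:
        self.num_logprobs: dict[str, int] = {}  # sample logprobs stay here

    def remove_request(self, req_id: str) -> None:
        # Runs on preemption as well as completion, so anything stored here
        # does not survive a preemption.
        self.num_logprobs.pop(req_id, None)


class Runner:
    """Lives for the whole request lifetime; now owns prompt-logprob tracking."""

    def __init__(self) -> None:
        self.num_prompt_logprobs: dict[str, int] = {}

    def add_request(self, req_id: str, num_prompt_logprobs: int) -> None:
        self.num_prompt_logprobs[req_id] = num_prompt_logprobs

    def finish_request(self, req_id: str) -> None:
        # Dropped only when the request actually finishes.
        self.num_prompt_logprobs.pop(req_id, None)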

vllm/v1/worker/gpu_model_runner.py

Lines changed: 17 additions & 3 deletions
@@ -393,6 +393,9 @@ def __init__(
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: dict[str, int] = {}
         self.comm_stream = torch.cuda.Stream()
 
         # Input Batch
@@ -687,6 +690,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.num_prompt_logprobs.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -755,6 +759,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             )
             self.requests[req_id] = req_state
 
+            if sampling_params and sampling_params.prompt_logprobs is not None:
+                self.num_prompt_logprobs[req_id] = (
+                    self.input_batch.vocab_size
+                    if sampling_params.prompt_logprobs == -1
+                    else sampling_params.prompt_logprobs
+                )
+
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
                 self._init_mrope_positions(req_state)
@@ -2671,7 +2682,7 @@ def execute_model(
                 scheduler_output, self.vllm_config
             )
         if self.cache_config.kv_sharing_fast_prefill:
-            assert not self.input_batch.num_prompt_logprobs, (
+            assert not self.num_prompt_logprobs, (
                 "--kv-sharing-fast-prefill produces incorrect "
                 "logprobs for prompt tokens, tokens, please disable "
                 "it when the requests need prompt logprobs"
@@ -3436,7 +3447,7 @@ def _get_prompt_logprobs_dict(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: dict[str, int],
     ) -> dict[str, LogprobsTensors | None]:
-        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        num_prompt_logprobs_dict = self.num_prompt_logprobs
         if not num_prompt_logprobs_dict:
             return {}
 
@@ -3447,7 +3458,10 @@ def _get_prompt_logprobs_dict(
         # maintainable loop over optimal performance.
         completed_prefill_reqs = []
         for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
-            num_tokens = num_scheduled_tokens[req_id]
+            num_tokens = num_scheduled_tokens.get(req_id)
+            if num_tokens is None:
+                # This can happen if the request was preempted in prefill stage.
+                continue
 
             # Get metadata for this request.
             request = self.requests[req_id]
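
The change in _get_prompt_logprobs_dict is the second half of the fix: now that prompt-logprob tracking persists across preemption, a tracked request may have no tokens scheduled in the current step, so a direct num_scheduled_tokens[req_id] lookup could raise KeyError. The .get() plus continue skips such a request until it is rescheduled. A simplified sketch of the pattern (function and variable names are stand-ins, not vLLM's):

# Hedged sketch of the preemption-safe lookup (names are hypothetical).
def reqs_to_process(
    num_prompt_logprobs: dict[str, int],
    num_scheduled_tokens: dict[str, int],
) -> list[str]:
    ready: list[str] = []
    for req_id in num_prompt_logprobs:
        num_tokens = num_scheduled_tokens.get(req_id)
        if num_tokens is None:
            # Tracked, but preempted mid-prefill: nothing scheduled this step,
            # so skip and pick the request up again once it is rescheduled.
            continue
        ready.append(req_id)
    return ready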

vllm/v1/worker/tpu_input_batch.py

Lines changed: 0 additions & 10 deletions
@@ -149,9 +149,6 @@ def __init__(
         self.generators: dict[int, torch.Generator] = {}
 
         self.num_logprobs: dict[str, int] = {}
-        # NOTE(rob): num_prompt_logprobs only includes reqs
-        # that are currently in the prefill phase.
-        self.num_prompt_logprobs: dict[str, int] = {}
 
         # To accumulate prompt logprobs tensor chunks across prefill steps.
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -256,8 +253,6 @@ def add_request(
 
         if sampling_params.logprobs is not None:
             self.num_logprobs[req_id] = sampling_params.logprobs
-        if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias
 
@@ -317,7 +312,6 @@ def remove_request(self, req_id: str) -> int | None:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
         # LoRA
@@ -584,10 +578,6 @@ def no_penalties(self) -> bool:
     def max_num_logprobs(self) -> int | None:
         return max(self.num_logprobs.values()) if self.num_logprobs else None
 
-    @property
-    def no_prompt_logprob(self) -> bool:
-        return not self.num_prompt_logprobs
-
     @property
     def no_allowed_token_ids(self) -> bool:
         return len(self.has_allowed_token_ids) == 0

vllm/v1/worker/tpu_model_runner.py

Lines changed: 11 additions & 0 deletions
@@ -247,6 +247,9 @@ def __init__(
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: dict[str, int] = {}
 
         # Initialize input batch early to avoid AttributeError in _update_states
         self.input_batch = InputBatch(
@@ -420,6 +423,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.num_prompt_logprobs.pop(req_id, None)
 
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -477,6 +481,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 lora_request=new_req_data.lora_request,
             )
 
+            if sampling_params and sampling_params.prompt_logprobs is not None:
+                self.num_prompt_logprobs[req_id] = (
+                    self.input_batch.vocab_size
+                    if sampling_params.prompt_logprobs == -1
+                    else sampling_params.prompt_logprobs
+                )
+
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.
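
The TPU runner gains the same tracking as the GPU runner, including the -1 sentinel that asks for logprobs over the entire vocabulary. A small hedged helper expressing that shared convention (the helper and its values are hypothetical; both runners inline the expression rather than calling such a function):

# Hypothetical helper showing the shared -1 sentinel convention.
def resolve_num_prompt_logprobs(prompt_logprobs: int | None, vocab_size: int) -> int | None:
    if prompt_logprobs is None:
        return None  # prompt logprobs were not requested
    # -1 means "return the full distribution", i.e. one logprob per vocab entry.
    return vocab_size if prompt_logprobs == -1 else prompt_logprobs

# Illustrative values:
#   resolve_num_prompt_logprobs(2, 32_000)  -> 2
#   resolve_num_prompt_logprobs(-1, 32_000) -> 32_000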
