@@ -605,3 +605,79 @@ def test_spec_decode_logprobs(
             )
             assert ref_logprob.rank == spec_logprob.rank
             assert ref_logprob.decoded_token == spec_logprob.decoded_token
+
+
+def test_prompt_logprobs_with_chunking_and_preemption():
+    """Test that prompt logprobs are correctly returned when using
+    both chunked prefill and preemption.
+
+    This test ensures that the num_prompt_logprobs tracking persists
+    across preemptions and prefill chunks.
+    """
+
+    # Create prompts that will trigger chunking and preemption
+    prompts = [
+        "The following numbers of the sequence "
+        + ", ".join(str(i) for i in range(10))
+        + " are:",
+        "In one word, the capital of France is ",
+    ] + [f"Tell me about the number {i}: " for i in range(32)]
+
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=40,
+        min_tokens=20,
+        prompt_logprobs=2,  # Request prompt logprobs
+    )
+
+    with VllmRunner(
+        "Qwen/Qwen3-0.6B",
+        max_model_len=512,
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=48,  # Force prefill chunking
+        num_gpu_blocks_override=32,  # Force preemptions
+        disable_log_stats=False,
+        gpu_memory_utilization=0.25,
+    ) as vllm_model:
+        metrics_before = vllm_model.llm.get_metrics()
+
+        # Generate with prompt logprobs using generate_w_logprobs, which with
+        # include_prompt_token_ids=True returns tuples of (output_ids, output_str,
+        # output_logprobs, prompt_token_ids, prompt_logprobs)
+        outputs = vllm_model.generate_w_logprobs(
+            prompts, sampling_params=sampling_params, include_prompt_token_ids=True
+        )
+
+        # Verify that all outputs have prompt logprobs
+        for i, output in enumerate(outputs):
+            _, _, _, prompt_token_ids, prompt_logprobs = output
+            assert prompt_logprobs is not None and len(prompt_logprobs) > 0, (
+                f"Output {i} missing prompt logprobs"
+            )
+            assert len(prompt_logprobs) == len(prompt_token_ids), (
+                "Unexpected number of prompt logprob positions"
+            )
+
+            # Each position should have the requested number of logprobs
+            for pos, logprobs_dict in enumerate(prompt_logprobs):
+                if logprobs_dict is not None:  # First token may be None
+                    assert (
+                        sampling_params.prompt_logprobs
+                        <= len(logprobs_dict)
+                        <= sampling_params.prompt_logprobs + 1
+                    ), (
+                        f"Output {i} position {pos} has {len(logprobs_dict)} "
+                        f"logprobs, expected {sampling_params.prompt_logprobs}"
+                    )
+
+        # Check that we actually had preemptions
+        metrics_after = vllm_model.llm.get_metrics()
+        preemptions_before = next(
+            (m.value for m in metrics_before if m.name == "vllm:num_preemptions"), 0
+        )
+        preemptions_after = next(
+            (m.value for m in metrics_after if m.name == "vllm:num_preemptions"), 0
+        )
+        preemptions = preemptions_after - preemptions_before
+        assert preemptions > 0, "Test did not trigger any preemptions"
+
+        print(f"Test passed with {preemptions} preemptions")