Fix prompt lookup decoding generating past max_length #46993
Open
Sunt-ing wants to merge 1 commit into
PromptLookupCandidateGenerator capped the candidate slice against self.max_length, but start_idx/end_idx index into the past input_ids while max_length bounds the output length, so the cap never bound and prompt-lookup decoding generated past max_new_tokens / max_length. Offset it by the current length so the candidate respects the remaining budget, matching AssistedCandidateGenerator.
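The index-versus-budget distinction above can be seen in a toy prompt lookup. This is a standalone sketch, not the library implementation: `propose_candidates` and its parameters are hypothetical names, and the `- 1` in the budget is an assumption (one slot reserved for the token the main model appends itself after verification).

```python
def propose_candidates(tokens, ngram_size, num_output_tokens, max_length):
    """Toy prompt lookup: find the trailing n-gram earlier in `tokens`
    and propose the tokens that followed that earlier occurrence."""
    input_length = len(tokens)
    # Assumption: reserve one slot for the token the main model adds itself.
    remaining = max_length - input_length - 1
    if remaining <= 0:
        return []  # zero budget: nothing to propose

    tail = tokens[-ngram_size:]
    # Scan earlier windows, newest first, skipping the trailing n-gram itself.
    for idx in range(input_length - ngram_size - 1, -1, -1):
        if tokens[idx:idx + ngram_size] == tail:
            # start_idx / end_idx are positions in the PAST tokens ...
            start_idx = idx + ngram_size
            end_idx = min(start_idx + num_output_tokens, input_length)
            candidates = tokens[start_idx:end_idx]
            # ... while the budget is a bound on the OUTPUT length.
            return candidates[:remaining]
    return []


history = [1, 2, 3, 4, 5, 6, 7, 8, 1, 2]
print(propose_candidates(history, ngram_size=2, num_output_tokens=5, max_length=30))
print(propose_candidates(history, ngram_size=2, num_output_tokens=5, max_length=13))
```

With a generous bound the toy proposes the full continuation `[3, 4, 5, 6, 7]`; with `max_length=13` only two slots remain and it proposes `[3, 4]`. The slice indices are the same in both calls, which is why comparing them directly against `max_length` cannot limit the output.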
What does this PR do?
`generate(..., prompt_lookup_num_tokens=N)` can return more new tokens than `max_new_tokens` / `max_length`, and the overshoot grows with `N`. Plain greedy decoding stops exactly at the limit, and the overlapping tokens are identical to greedy, so this is not an acceptance or rollback bug, only a missing length cap.

The cause is a wrong coordinate system in the candidate length cap, in `PromptLookupCandidateGenerator.get_candidates`: `start_idx` / `end_idx` are slice indices into the existing `input_ids`, while `self.max_length` is the absolute output-length bound. During generation `input_length <= self.max_length` always holds, so the `self.max_length` term never binds and the candidate length is effectively capped only by `num_output_tokens`. The bulk-accepted candidates then push the sequence past `max_length` (the stopping criteria run after the append, so they stop the loop but never trim).

`AssistedCandidateGenerator` already does this correctly by subtracting the current length. This PR aligns `PromptLookupCandidateGenerator` with it, capping the candidate to the remaining budget (one line).

With enough budget the new term exceeds `num_output_tokens`, so behavior is unchanged; near the limit it tightens so the total length stops exactly at `max_length`. The `max_length == input_length + 1` case (zero budget) is already handled by the early return at the top of the method.

Proof setup: 4 dense models, `ModelTester` configs, CPU fp32, `prompt_len=10`, `max_new_tokens=20` so `max_length=30`; greedy with `min_new_tokens=max_new_tokens` as the oracle that fills exactly to the limit. The tokens that overlap with greedy stay identical before and after, so accepted content is unchanged; only the illegal overshoot is removed.
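The coordinate mismatch described above can be illustrated with a few lines of arithmetic. This is a standalone sketch, not the code from the diff: the function names are hypothetical, and the `- 1` in the remaining budget is an assumption (one slot reserved for the token the main model appends after verifying candidates), chosen to agree with the zero-budget case `max_length == input_length + 1`.

```python
def capped_len_absolute(start_idx, input_length, max_length, num_output_tokens):
    """Cap the slice END against max_length directly (the mismatched form)."""
    end_idx = min(start_idx + num_output_tokens, input_length, max_length)
    return max(end_idx - start_idx, 0)


def capped_len_remaining(start_idx, input_length, max_length, num_output_tokens):
    """Cap the slice LENGTH against the remaining output budget."""
    # Assumption: reserve one slot for the token the main model adds itself.
    remaining = max_length - input_length - 1
    end_idx = min(start_idx + num_output_tokens, input_length, start_idx + remaining)
    return max(end_idx - start_idx, 0)


# Near the limit: 28 tokens so far, max_length=30, matched continuation starts at 5.
near = dict(start_idx=5, input_length=28, max_length=30, num_output_tokens=10)
# Far from the limit: 12 tokens so far, same bound.
far = dict(start_idx=5, input_length=12, max_length=30, num_output_tokens=10)

for name, args in (("near", near), ("far", far)):
    old = capped_len_absolute(**args)
    new = capped_len_remaining(**args)
    # Worst case: every candidate is accepted, plus one token from the main model.
    print(name, "absolute cap:", args["input_length"] + old + 1,
          "remaining cap:", args["input_length"] + new + 1)
```

In the near-limit case the absolute cap allows 10 candidates, so the sequence can reach 39 tokens against a bound of 30, while the remaining-budget cap allows 1 candidate and the sequence ends at 30. In the far case both forms allow 7 candidates, matching the claim that behavior is unchanged when the budget is ample.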
Environment, reproduction script, and full before/after output
Environment:
Run from the repo root with `CUDA_VISIBLE_DEVICES="" PYTHONPATH=src:. python repro.py`.
Before the fix:
After the fix:
A regression test `test_prompt_lookup_decoding_respects_max_length` is added next to `test_prompt_lookup_decoding_stops_at_eos`; it builds a `PromptLookupCandidateGenerator` whose budget is smaller than `num_output_tokens` and asserts the proposed candidate count respects the remaining budget. It fails on the unpatched code and passes with this change. The `prompt_lookup` and assisted decoding tests in `tests/generation/test_utils.py` pass for llama, gpt_neox, granite and mistral.
Code Agent Policy
Before submitting
Who can review?
@gante @zucchini-nlp