From dbd79b91b48f9ad3b683dabbe39510c932ca4ac8 Mon Sep 17 00:00:00 2001 From: Ting Sun Date: Tue, 23 Jun 2026 22:52:07 +0800 Subject: [PATCH] Fix prompt lookup decoding generating past max_length PromptLookupCandidateGenerator capped the candidate slice against self.max_length, but start_idx/end_idx index into the past input_ids while max_length bounds the output length, so the cap never bound and prompt-lookup decoding generated past max_new_tokens / max_length. Offset it by the current length so the candidate respects the remaining budget, matching AssistedCandidateGenerator. --- .../generation/candidate_generator.py | 3 ++- tests/generation/test_utils.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index bbc0e3bfa931..dc992b6d3c0c 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -1098,7 +1098,8 @@ def get_candidates(self, input_ids: torch.LongTensor, **kwargs) -> tuple[torch.L for idx in match_indices: start_idx = idx + ngram_size end_idx = start_idx + self.num_output_tokens - end_idx = min(end_idx, input_length, self.max_length) + # Offset the output-length cap by the current length so candidates respect the remaining budget. + end_idx = min(end_idx, input_length, start_idx + self.max_length - input_length - 1) if start_idx < end_idx: chosen_ids = input_ids[0, start_idx:end_idx] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b24c212ae57d..ae004eb9cdeb 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -919,6 +919,25 @@ def test_prompt_lookup_decoding_stops_at_eos(self): # PLD shouldn't propose any new tokens based on eos-match self.assertTrue(output_prompt_lookup.shape[-1] == 10) + @pytest.mark.generate + def test_prompt_lookup_decoding_respects_max_length(self): + # `end_idx` indexes into `input_ids`, while `max_length` bounds the output length, so the candidate cap + # must be offset by the current length; otherwise PLD proposes tokens that push generation past `max_length`. + + # The opening bigram is repeated at the end, so the trailing ngram matches early and yields a long continuation. + input_ids = torch.tensor([[10, 11, 12, 13, 14, 15, 16, 17, 10, 11]], device=torch_device) + + # Only `max_length - cur_len - 1` (= 2) tokens of budget remain, although `num_output_tokens` asks for 5. + candidate_generator = PromptLookupCandidateGenerator( + eos_token_id=torch.tensor([0], device=torch_device), + num_output_tokens=5, + max_matching_ngram_size=2, + max_length=input_ids.shape[-1] + 3, + ) + candidates = candidate_generator.get_candidates(input_ids)[0] + num_proposed = candidates.shape[-1] - input_ids.shape[-1] + self.assertLessEqual(num_proposed, candidate_generator.max_length - input_ids.shape[-1] - 1) + @pytest.mark.generate def test_left_padding_compatibility( self, unpadded_custom_inputs: dict | None = None, padded_custom_inputs: dict | None = None