
Fix prompt lookup decoding generating past max_length #46993

Open
Sunt-ing wants to merge 1 commit into huggingface:main from Sunt-ing:1

@Sunt-ing Sunt-ing commented Jul 1, 2026

What does this PR do?

generate(..., prompt_lookup_num_tokens=N) can return more new tokens than max_new_tokens / max_length, and the overshoot grows with N. Plain greedy decoding stops exactly at the limit, and the overlapping tokens are identical to greedy, so this is not an acceptance or rollback bug, only a missing length cap.

The cause is a wrong coordinate system in the candidate length cap, in PromptLookupCandidateGenerator.get_candidates:

start_idx = idx + ngram_size            # index into the past `input_ids`
end_idx = start_idx + self.num_output_tokens
end_idx = min(end_idx, input_length, self.max_length)

start_idx / end_idx are slice indices into the existing input_ids, while self.max_length is the absolute output-length bound. During generation input_length <= self.max_length always holds, so the self.max_length term never binds and the candidate length is effectively capped only by num_output_tokens. The bulk-accepted candidates then push the sequence past max_length (the stopping criteria run after the append, so they stop the loop but never trim).
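A minimal sketch of why the original cap is inert, using plain integers instead of tensors. The function name `old_candidate_len` and the concrete numbers are hypothetical and are not taken from the repro runs. The sketch also assumes, as the `- 1` in the length caps quoted in this description suggests, that a fully accepted candidate is followed by one extra token from the main model.

```python
def old_candidate_len(input_length, start_idx, num_output_tokens, max_length):
    """Candidate length under the original cap (simplified stand-in, not library code)."""
    end_idx = start_idx + num_output_tokens
    # input_length <= max_length during generation, so max_length never wins this min().
    end_idx = min(end_idx, input_length, max_length)
    return end_idx - start_idx


# Hypothetical state: 27 tokens so far, limit of 30, match continuation starts at index 4.
input_length, max_length = 27, 30
cand = old_candidate_len(input_length, start_idx=4, num_output_tokens=5, max_length=max_length)

# Assumption: all candidates are accepted and the main model appends one token of its own.
final_len = input_length + cand + 1
print(cand, final_len)  # 5 33
```

In this illustrative state only three positions remain before the limit, yet five candidates are proposed, because the cap compares a slice index against an absolute length.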

AssistedCandidateGenerator already does this correctly by subtracting the current length:

max_new_tokens = min(int(self.num_assistant_tokens), self.main_model_max_length - new_cur_len - 1)

This PR aligns PromptLookupCandidateGenerator with it, capping the candidate to the remaining budget (one line):

end_idx = min(end_idx, input_length, start_idx + self.max_length - input_length - 1)

With enough budget, the remaining budget (self.max_length - input_length - 1) is at least num_output_tokens, so the new term does not bind and behavior is unchanged; near the limit it shortens the candidate so that the total length stops at max_length instead of overshooting. The max_length == input_length + 1 case (zero budget) is already handled by the early return at the top of the method.
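The effect of the new cap can be checked with the same kind of integer arithmetic. This is a sketch under stated assumptions: `fixed_candidate_len` is a hypothetical helper that mirrors only the one-line change, the numbers are illustrative, and the final length assumes every candidate is accepted plus one token from the main model.

```python
def fixed_candidate_len(input_length, start_idx, num_output_tokens, max_length):
    """Candidate length under the patched cap (simplified stand-in, not library code)."""
    end_idx = start_idx + num_output_tokens
    # The third term converts the remaining budget into a slice index.
    end_idx = min(end_idx, input_length, start_idx + max_length - input_length - 1)
    return end_idx - start_idx


max_length = 30
for input_length in (20, 27, 28):
    cand = fixed_candidate_len(input_length, start_idx=4, num_output_tokens=5, max_length=max_length)
    # Worst case: every candidate accepted, plus one token from the main model.
    print(input_length, cand, input_length + cand + 1)
# 20 5 26
# 27 2 30
# 28 1 30
```

Far from the limit the candidate keeps its full length of 5; close to it, the candidate shrinks so that the worst-case length lands on 30.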

Proof (4 dense models, ModelTester configs, CPU fp32, prompt_len=10, max_new_tokens=20 so max_length=30; greedy with min_new_tokens=max_new_tokens as the oracle that fills exactly to the limit):

| total length | greedy | N=2 | N=3 | N=5 |
|---|---|---|---|---|
| before | 30 | 30 | 31 | 33 (32 for granite) |
| after | 30 | 30 | 30 | 30 |

The tokens that overlap with greedy stay identical before and after, so accepted content is unchanged; only the illegal overshoot is removed.

Environment, reproduction script, and full before/after output

Environment:

transformers 5.13.0.dev0 (this branch)
torch 2.8.0+cu128
python 3.12.3
Linux x86_64, CPU fp32 (CUDA_VISIBLE_DEVICES="")

Run from the repo root with CUDA_VISIBLE_DEVICES="" PYTHONPATH=src:. python repro.py:

import os, importlib
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import torch
from transformers import AutoModelForCausalLM

torch.manual_seed(0)

def build(tester_mod, tester_cls):
    mod = importlib.import_module(tester_mod)
    Tester = getattr(mod, tester_cls)
    class Dummy:
        def __getattr__(self, n): return lambda *a, **k: None
    t = Tester(Dummy())
    cfg = t.get_config()
    cfg.vocab_size = max(cfg.vocab_size, 50)
    torch.manual_seed(0)
    return AutoModelForCausalLM.from_config(cfg).to(torch.float32).eval()

def gen(model, input_ids, max_new, **extra):
    with torch.no_grad():
        torch.manual_seed(0)
        return model.generate(
            input_ids, attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new, min_new_tokens=max_new,  # force greedy to fill exactly K
            do_sample=False, num_beams=1,
            eos_token_id=2, pad_token_id=0, **extra,
        )[0].tolist()

MAXNEW = 20
prompt = torch.tensor([[5, 6, 7, 8, 9, 10, 5, 6, 7, 8]], dtype=torch.long)
PL = prompt.shape[1]

for name, tmod, tcls in [
    ("gpt_neox", "tests.models.gpt_neox.test_modeling_gpt_neox", "GPTNeoXModelTester"),
    ("mistral",  "tests.models.mistral.test_modeling_mistral",  "MistralModelTester"),
    ("granite",  "tests.models.granite.test_modeling_granite",  "GraniteModelTester"),
    ("llama",    "tests.models.llama.test_modeling_llama",      "LlamaModelTester"),
]:
    model = build(tmod, tcls)
    g = gen(model, prompt, MAXNEW)
    print(f"\n=== {name}  prompt_len={PL}  max_new_tokens={MAXNEW}  max_length={PL+MAXNEW} ===")
    print(f"  greedy:                      total={len(g)}  new={len(g)-PL}")
    for N in (2, 3, 5):
        p = gen(model, prompt, MAXNEW, prompt_lookup_num_tokens=N)
        ov = len(p) - (PL + MAXNEW)
        k = min(len(g), len(p))
        print(f"  prompt_lookup_num_tokens={N}:  total={len(p)}  new={len(p)-PL}  "
              f"exceeds_max_length=+{ov}  overlap_eq_greedy={g[:k]==p[:k]}")

Before the fix:

=== gpt_neox  prompt_len=10  max_new_tokens=20  max_length=30 ===
  greedy:                      total=30  new=20
  prompt_lookup_num_tokens=2:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=3:  total=31  new=21  exceeds_max_length=+1  overlap_eq_greedy=True
  prompt_lookup_num_tokens=5:  total=33  new=23  exceeds_max_length=+3  overlap_eq_greedy=True
=== mistral  prompt_len=10  max_new_tokens=20  max_length=30 ===
  greedy:                      total=30  new=20
  prompt_lookup_num_tokens=2:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=3:  total=31  new=21  exceeds_max_length=+1  overlap_eq_greedy=True
  prompt_lookup_num_tokens=5:  total=33  new=23  exceeds_max_length=+3  overlap_eq_greedy=True
=== granite  prompt_len=10  max_new_tokens=20  max_length=30 ===
  greedy:                      total=30  new=20
  prompt_lookup_num_tokens=2:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=3:  total=31  new=21  exceeds_max_length=+1  overlap_eq_greedy=True
  prompt_lookup_num_tokens=5:  total=32  new=22  exceeds_max_length=+2  overlap_eq_greedy=True
=== llama  prompt_len=10  max_new_tokens=20  max_length=30 ===
  greedy:                      total=30  new=20
  prompt_lookup_num_tokens=2:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=3:  total=31  new=21  exceeds_max_length=+1  overlap_eq_greedy=True
  prompt_lookup_num_tokens=5:  total=33  new=23  exceeds_max_length=+3  overlap_eq_greedy=True

After the fix:

=== gpt_neox  prompt_len=10  max_new_tokens=20  max_length=30 ===
  greedy:                      total=30  new=20
  prompt_lookup_num_tokens=2:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=3:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=5:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
=== mistral  prompt_len=10  max_new_tokens=20  max_length=30 ===
  greedy:                      total=30  new=20
  prompt_lookup_num_tokens=2:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=3:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=5:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
=== granite  prompt_len=10  max_new_tokens=20  max_length=30 ===
  greedy:                      total=30  new=20
  prompt_lookup_num_tokens=2:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=3:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=5:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
=== llama  prompt_len=10  max_new_tokens=20  max_length=30 ===
  greedy:                      total=30  new=20
  prompt_lookup_num_tokens=2:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=3:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True
  prompt_lookup_num_tokens=5:  total=30  new=20  exceeds_max_length=+0  overlap_eq_greedy=True

A regression test test_prompt_lookup_decoding_respects_max_length is added next to test_prompt_lookup_decoding_stops_at_eos; it builds a PromptLookupCandidateGenerator whose budget is smaller than num_output_tokens and asserts the proposed candidate count respects the remaining budget. It fails on the unpatched code and passes with this change. The prompt_lookup and assisted decoding tests in tests/generation/test_utils.py pass for llama, gpt_neox, granite and mistral.
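The idea behind the regression test can be sketched without the library. The function `toy_prompt_lookup` below is a hypothetical pure-Python stand-in that works on lists of ints; it is not the `PromptLookupCandidateGenerator` implementation and not the test added in this PR. It only mirrors the n-gram match, the early return, and the patched cap described above.

```python
def toy_prompt_lookup(input_ids, num_output_tokens, max_length, max_ngram_size=2):
    """Propose a continuation by matching the trailing n-gram earlier in input_ids."""
    input_length = len(input_ids)
    if max_length == input_length + 1:  # zero budget: nothing to propose
        return []
    for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
        ngram = input_ids[-ngram_size:]
        # Scan earlier windows only; the trailing n-gram itself is excluded.
        for idx in range(input_length - ngram_size):
            if input_ids[idx:idx + ngram_size] == ngram:
                start_idx = idx + ngram_size
                end_idx = start_idx + num_output_tokens
                # Patched cap: limit the slice to the remaining budget.
                end_idx = min(end_idx, input_length, start_idx + max_length - input_length - 1)
                if start_idx < end_idx:
                    return input_ids[start_idx:end_idx]
    return []


prompt = [5, 6, 7, 8, 9, 10, 5, 6, 7, 8]  # same token ids as the repro prompt

# Budget (13 - 10 - 1 = 2) is smaller than num_output_tokens (5).
tight = toy_prompt_lookup(prompt, num_output_tokens=5, max_length=13)
print(tight)  # [9, 10]

# Ample budget: the full five-token continuation is proposed.
loose = toy_prompt_lookup(prompt, num_output_tokens=5, max_length=30)
print(loose)  # [9, 10, 5, 6, 7]
```

The property being asserted is that the number of proposed tokens never exceeds `max_length - len(input_ids) - 1`, which is the remaining budget once one token is reserved for the main model.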

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline and the Pull Request checks?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes according to the guidelines?
  • Did you write any new necessary tests?

Who can review?

@gante @zucchini-nlp

PromptLookupCandidateGenerator capped the candidate slice against
self.max_length, but start_idx/end_idx index into the past input_ids
while max_length bounds the output length, so the cap never bound and
prompt-lookup decoding generated past max_new_tokens / max_length.
Offset it by the current length so the candidate respects the remaining
budget, matching AssistedCandidateGenerator.

github-actions Bot commented Jul 1, 2026

CI recap

Dashboard: View test results in Grafana
Latest run: 28498833441:2
Result: failure | Jobs: 14 | Tests: 52,095 | Failures: 0 | Duration: 15h 48m
