Skip to content

Stop sampler decode one step earlier#1631

Open
fallintoplace wants to merge 1 commit into
google:mainfrom
fallintoplace:fix/sampler-decode-off-by-one
Open

Stop sampler decode one step earlier#1631
fallintoplace wants to merge 1 commit into
google:mainfrom
fallintoplace:fix/sampler-decode-off-by-one

Conversation

@fallintoplace

Copy link
Copy Markdown

What changed

This stops tunix.generate.sampler.Sampler._decode_fn() at the last writable token slot instead of allowing one extra decode iteration after prefill.

It also adds a regression test that exercises max_generation_steps=1 and verifies that _decode_fn() leaves the sampler state unchanged after prefill has already produced the single generated token.

Why

init_sample_state() starts decoding_step at the last prompt token.
_prefill_fn() immediately calls _sample(), which writes the first generated token at decoding_step + 1 and advances the step.

That means by the time decode starts, prefill has already consumed one generation step. The old decode condition still allowed another iteration when decoding_step == total_sampling_steps - 1, even though the next write would target decoding_step + 1 beyond the intended writable range.

Impact

max_generation_steps=1 now performs exactly one generation step instead of running an extra decode iteration. The same boundary is respected for longer generations as well.

Validation

  • uv run --extra test python tests/generate/sampler_test.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants