From 4e151319d2cd8790ba507fbc8181f184368a10db Mon Sep 17 00:00:00 2001 From: Anisha Mazumder Date: Thu, 29 Jan 2026 18:42:22 -0800 Subject: [PATCH] Add support for stop strings in vLLM sampler and rollout. PiperOrigin-RevId: 862995259 --- tunix/generate/vllm_sampler.py | 4 ++++ tunix/rl/rollout/vllm_rollout.py | 1 + 2 files changed, 5 insertions(+) diff --git a/tunix/generate/vllm_sampler.py b/tunix/generate/vllm_sampler.py index 6ac25d9fb..4b6a5ce27 100644 --- a/tunix/generate/vllm_sampler.py +++ b/tunix/generate/vllm_sampler.py @@ -370,6 +370,7 @@ def __call__( return_logits: bool = True, echo: bool = False, pad_output: bool = False, + stop_strings: Optional[List[str]] = None, ) -> base_sampler.SamplerOutput: """The entry point API for vLLM Sampler""" if not isinstance(input_strings, List): @@ -389,6 +390,7 @@ def __call__( max_tokens=max_generation_steps, ignore_eos=False, temperature=temperature, + stop=stop_strings, ) else: if self._driver is not None: @@ -414,6 +416,8 @@ def __call__( sampling_params.top_p = top_p if top_k is not None: sampling_params.top_k = top_k + if stop_strings is not None: + sampling_params.stop = stop_strings self.sampling_params = sampling_params diff --git a/tunix/rl/rollout/vllm_rollout.py b/tunix/rl/rollout/vllm_rollout.py index 80e2c767e..052567ca1 100644 --- a/tunix/rl/rollout/vllm_rollout.py +++ b/tunix/rl/rollout/vllm_rollout.py @@ -69,6 +69,7 @@ def generate( self, prompts: list[str], rollout_config: base_rollout.RolloutConfig, + stop_strings: Optional[list[str]] = None, **kwargs, ) -> base_rollout.RolloutOutput: """Generates samples from the model."""