diff --git a/tunix/generate/vllm_sampler.py b/tunix/generate/vllm_sampler.py index 6ac25d9fb..4b6a5ce27 100644 --- a/tunix/generate/vllm_sampler.py +++ b/tunix/generate/vllm_sampler.py @@ -370,6 +370,7 @@ def __call__( return_logits: bool = True, echo: bool = False, pad_output: bool = False, + stop_strings: Optional[List[str]] = None, ) -> base_sampler.SamplerOutput: """The entry point API for vLLM Sampler""" if not isinstance(input_strings, List): @@ -389,6 +390,7 @@ def __call__( max_tokens=max_generation_steps, ignore_eos=False, temperature=temperature, + stop=stop_strings, ) else: if self._driver is not None: @@ -414,6 +416,8 @@ def __call__( sampling_params.top_p = top_p if top_k is not None: sampling_params.top_k = top_k + if stop_strings is not None: + sampling_params.stop = stop_strings self.sampling_params = sampling_params diff --git a/tunix/rl/rollout/vllm_rollout.py b/tunix/rl/rollout/vllm_rollout.py index 80e2c767e..052567ca1 100644 --- a/tunix/rl/rollout/vllm_rollout.py +++ b/tunix/rl/rollout/vllm_rollout.py @@ -69,6 +69,7 @@ def generate( self, prompts: list[str], rollout_config: base_rollout.RolloutConfig, + stop_strings: Optional[list[str]] = None, **kwargs, ) -> base_rollout.RolloutOutput: """Generates samples from the model."""