diff --git a/tests/test_multiturn_env.py b/tests/test_multiturn_env.py
index 218372141..a49213ec6 100644
--- a/tests/test_multiturn_env.py
+++ b/tests/test_multiturn_env.py
@@ -520,3 +520,118 @@ async def test_responses_stored_in_state(self, mock_multiturn_env):
         for step in state["trajectory"]:
             assert hasattr(step["response"], "choices")
             assert len(step["response"].choices) > 0
+
+    @pytest.mark.asyncio
+    async def test_stop_condition_triggered_by_env_response(
+        self, mock_openai_client, sample_chat_dataset
+    ):
+        """Test that rollout stops gracefully when env_response triggers a stop condition.
+
+        This tests the edge case where get_prompt_messages() -> env_response() causes
+        a stop condition to become true before get_model_response() is called.
+        The rollout should add a trajectory step with empty completion and return.
+        """
+
+        class EnvResponseStopsEnv(MultiTurnEnv):
+            @stop
+            async def env_triggered_stop(self, state: State) -> bool:
+                return state.get("env_says_stop", False)
+
+            async def env_response(
+                self, messages: Messages, state: State, **kwargs
+            ) -> Messages:
+                # On second turn, env_response triggers stop condition
+                state["env_says_stop"] = True
+                return [{"role": "user", "content": "This triggers stop"}]
+
+        env = EnvResponseStopsEnv(
+            client=mock_openai_client,
+            model="test-model",
+            dataset=sample_chat_dataset,
+            max_turns=5,
+            parser=Parser(),
+            rubric=Rubric(),
+        )
+
+        # First response - normal
+        mock_openai_client.add_chat_response(
+            messages=[{"role": "user", "content": "Start"}],
+            response="First response",
+        )
+
+        prompt: Messages = [{"role": "user", "content": "Start"}]
+        state = await env.rollout(
+            input=RolloutInput(
+                prompt=prompt,
+                answer="test",
+                example_id=0,
+            ),
+            client=mock_openai_client,
+            model="test-model",
+        )
+
+        # Should have 2 trajectory steps:
+        # 1. First turn with actual model response
+        # 2. Second turn where env_response triggered stop (empty completion)
+        assert len(state["trajectory"]) == 2
+        assert state["is_completed"] is True
+        assert state["stop_condition"] == "env_triggered_stop"
+
+        # First step should have normal completion
+        assert state["trajectory"][0]["completion"] != []
+        assert state["trajectory"][0]["response"] is not None
+
+        # Second step should have empty completion (no model was called)
+        assert state["trajectory"][1]["completion"] == []
+        assert state["trajectory"][1]["response"] is None
+
+    @pytest.mark.asyncio
+    async def test_stop_condition_triggered_by_env_response_completion_mode(
+        self, mock_openai_client
+    ):
+        """Test early exit in completion mode uses empty string, not empty list."""
+
+        class EnvResponseStopsEnv(MultiTurnEnv):
+            def __init__(self, **kwargs):
+                super().__init__(message_type="completion", **kwargs)
+
+            @stop
+            async def env_triggered_stop(self, state: State) -> bool:
+                return state.get("env_says_stop", False)
+
+            async def env_response(
+                self, messages: Messages, state: State, **kwargs
+            ) -> Messages:
+                state["env_says_stop"] = True
+                return " [stop]"
+
+        dataset = Dataset.from_dict({"prompt": ["Start:"], "answer": ["test"]})
+        env = EnvResponseStopsEnv(
+            client=mock_openai_client,
+            model="test-model",
+            dataset=dataset,
+            max_turns=5,
+        )
+
+        mock_openai_client.add_text_response("Start:", "First response")
+
+        state = await env.rollout(
+            input=RolloutInput(
+                prompt="Start:",
+                answer="test",
+                example_id=0,
+            ),
+            client=mock_openai_client,
+            model="test-model",
+        )
+
+        assert len(state["trajectory"]) == 2
+        assert state["is_completed"] is True
+        assert state["stop_condition"] == "env_triggered_stop"
+
+        # Second step should have empty string completion (not empty list)
+        assert state["trajectory"][1]["completion"] == ""
+        assert state["trajectory"][1]["response"] is None
+
+        # Final completion should be a string
+        assert isinstance(state["completion"], str)
diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py
index 93792a564..82633d4e4 100644
--- a/verifiers/envs/multiturn_env.py
+++ b/verifiers/envs/multiturn_env.py
@@ -67,16 +67,25 @@ async def add_model_response(
         self,
         state: State,
         prompt_messages: Messages,
-        response: ModelResponse,
+        response: ModelResponse | None,
     ):
-        completion_messages = await parse_response_messages(response, self.message_type)
-        response_is_truncated = await parse_is_truncated(response, self.message_type)
-        tokens = await parse_response_tokens(
-            response, self.message_type, self.max_seq_len
-        )
-        is_truncated = response_is_truncated or (
-            tokens is not None and bool(tokens.get("is_truncated"))
-        )
+        if response is not None:
+            completion_messages = await parse_response_messages(
+                response, self.message_type
+            )
+            response_is_truncated = await parse_is_truncated(
+                response, self.message_type
+            )
+            tokens = await parse_response_tokens(
+                response, self.message_type, self.max_seq_len
+            )
+            is_truncated = response_is_truncated or (
+                tokens is not None and bool(tokens.get("is_truncated"))
+            )
+        else:
+            completion_messages = "" if self.message_type == "completion" else []
+            tokens = None
+            is_truncated = False
         trajectory_step = TrajectoryStep(
             prompt=prompt_messages,
             completion=completion_messages,
@@ -108,6 +117,10 @@ async def rollout(
         while not await self.is_completed(state):
             try:
                 prompt_messages = await self.get_prompt_messages(state)
+                if await self.is_completed(state):
+                    await self.add_model_response(state, prompt_messages, None)
+                    await self._render_completion(state)
+                    return state
                 response = await self.get_model_response(state, prompt_messages)
                 await self.add_model_response(state, prompt_messages, response)
             except vf.Error as e:
diff --git a/verifiers/envs/tool_env.py b/verifiers/envs/tool_env.py
index 6acd4cca6..b8fa12e4e 100644
--- a/verifiers/envs/tool_env.py
+++ b/verifiers/envs/tool_env.py
@@ -51,7 +51,10 @@ def remove_tool(self, tool: Callable):
     async def no_tools_called(self, state: vf.State) -> bool:
         if len(state["trajectory"]) == 0:
             return False
-        last_message = state["trajectory"][-1]["completion"][-1]
+        last_completion = state["trajectory"][-1]["completion"]
+        if not last_completion:
+            return False
+        last_message = last_completion[-1]
         is_assistant_message = last_message["role"] == "assistant"
         no_tool_calls = (
             "tool_calls" not in last_message or last_message["tool_calls"] is None
diff --git a/verifiers/types.py b/verifiers/types.py
index f9b986aea..dea69fc4e 100644
--- a/verifiers/types.py
+++ b/verifiers/types.py
@@ -63,7 +63,7 @@ class TrajectoryStepTokens(TypedDict):
 class TrajectoryStep(TypedDict):
     prompt: Messages
     completion: Messages
-    response: ModelResponse
+    response: ModelResponse | None
     tokens: TrajectoryStepTokens | None
     reward: float | None
     advantage: float | None