Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 93 additions & 45 deletions tests/rl/agentic/trajectory/trajectory_collect_engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,32 @@ def setUp(self):
self.mock_env = mock.create_autospec(
base_environment.BaseTaskEnv, instance=True
)

self.mock_env.max_steps = 10

self.mock_model_call = mock.Mock()
self.mock_final_reward_fn = mock.Mock(
return_value=reward_types.RewardOutput(reward=0.5)
)
self.mock_tokenizer = mock.Mock()
self.mock_tokenizer.encode.return_value = [1, 2, 3]
self.mock_chat_parser = mock.Mock()

# Configure mock agent
self.trajectory = agent_types.Trajectory()
self.mock_agent.trajectory = self.trajectory
self.mock_agent.chat_completions = []
self.current_step = None

self._chat_history = []
self.mock_agent.chat_completions = self._chat_history

self.current_step = None

def _update_from_model(resp):
  """Mock-agent hook: record a model response as a new trajectory step.

  Args:
    resp: The raw model response string produced by the engine's model call.

  Returns:
    The newly created ``agent_types.Step`` (also appended to the trajectory).
  """
  self.current_step = agent_types.Step(
      model_response=resp, action=agent_types.Action(action=['action'])
  )
  self.trajectory.steps.append(self.current_step)
  # NOTE: ``self.mock_agent.chat_completions`` aliases ``self._chat_history``
  # (assigned in setUp), so append through exactly ONE reference — appending
  # via both names would record every assistant message twice.
  self._chat_history.append({'role': 'assistant', 'content': resp})
  return self.current_step

def _update_from_env(observation, reward, done, info):
Expand All @@ -64,40 +69,36 @@ def _update_from_env(observation, reward, done, info):
self.current_step.reward = reward
self.current_step.done = done
self.current_step.info = info
self.mock_agent.chat_completions.append(
{'role': 'user', 'content': observation}
)

def _get_current_state():
return self.current_step
self._chat_history.append({'role': 'user', 'content': observation})

def _reset_agent():
  """Mock-agent hook: wipe trajectory steps, chat history and step state."""
  self.trajectory.steps.clear()
  # ``self.mock_agent.chat_completions`` is the same list object as
  # ``self._chat_history`` (assigned in setUp), so a single clear() resets
  # both views of the conversation — clearing both names is redundant.
  self._chat_history.clear()
  self.current_step = None

self.mock_agent.update_from_model.side_effect = _update_from_model
self.mock_agent.update_from_env.side_effect = _update_from_env
self.mock_agent.get_current_state.side_effect = _get_current_state
self.mock_agent.reset.side_effect = _reset_agent
self.mock_agent.get_current_state.side_effect = lambda: self.current_step

# Configure mock env
self.mock_env.reset.return_value = ('initial_obs', {})
# Let it run for 2 steps then done
self.mock_env.step.side_effect = [
('obs1', 1.0, False, {}),
('obs2', 2.0, True, {}),
]
self.mock_env.task = {'some': 'task'}
self.mock_env.extra_kwargs = {'group_id': 1}
self.mock_env.extra_kwargs = {}
self.trajectory.task = self.mock_env.task

# Configure mock model call
self.mock_model_call.side_effect = ['response1', 'response2']
self.mock_model_call.side_effect = ['response1', 'response2', 'response3', 'response4', 'response5']

async def _run_collect(self, engine, mode='Trajectory'):
return await engine.collect(mode=mode)

def test_collect_trajectory_mode(self):
self.mock_env.max_steps = 5
engine = trajectory_collect_engine.TrajectoryCollectEngine(
agent=self.mock_agent,
env=self.mock_env,
Expand Down Expand Up @@ -134,6 +135,7 @@ def test_collect_conversation_mode(self):
agent=self.mock_agent,
env=self.mock_env,
model_call=self.mock_model_call,
max_context_limit=1024,
)
conversation = asyncio.run(self._run_collect(engine, mode='Conversation'))

Expand Down Expand Up @@ -161,6 +163,7 @@ def test_collect_with_tokenization(self, mock_convert):
model_call=self.mock_model_call,
tokenizer=self.mock_tokenizer,
chat_parser=self.mock_chat_parser,
max_context_limit=1024,
)
token_data = asyncio.run(self._run_collect(engine, mode='Token'))
expected_tokens = {
Expand All @@ -176,8 +179,9 @@ def test_collect_with_tokenization(self, mock_convert):
'conversation_masks': [1, 1, 1, 1, 1, 1, 1, 1],
'trajectory_reward': 3.0, # 1.0 + 2.0
'policy_version': None,
'original_input': None,
'group_id': 1,
'original_input': {'some': 'task'},
'group_id': None,
'status': 'SUCCEEDED',
}
self.assertEqual(token_data, expected_tokens)

Expand All @@ -196,8 +200,7 @@ def test_collect_with_tokenization(self, mock_convert):
)
self.assertFalse(
mock_convert.call_args_list[0].kwargs['contains_generation_msg'],
'contains_generation_msg should be False for initial prompt'
' tokenization',
'contains_generation_msg should be False for initial prompt tokenization',
)

# Verify that tokenization for model responses and environment observations
Expand Down Expand Up @@ -250,29 +253,6 @@ def test_collect_with_incomplete_tokenizer_config_skips_tokenization(
asyncio.run(self._run_collect(engine))
mock_tokenize.assert_not_called()

def test_collect_timeout(self):
with mock.patch.object(trajectory_collect_engine.time, 'time') as mock_time:
mock_time.side_effect = [
100.0, # start time in _reset
100.05, # time check in _one_step (1st call)
100.11, # time check in _one_step (2nd call) -> timeout
100.12, # time access in logging.warning
]
engine = trajectory_collect_engine.TrajectoryCollectEngine(
agent=self.mock_agent,
env=self.mock_env,
model_call=self.mock_model_call,
timeout=0.1,
)
result_traj = asyncio.run(self._run_collect(engine, mode='Trajectory'))

# Should run for two steps, with the second one timing out and marked as
# done
self.assertLen(result_traj.steps, 2)
self.assertFalse(result_traj.steps[0].done)
self.assertTrue(result_traj.steps[1].done)
self.assertEqual(self.mock_env.step.call_count, 2)

async def _run_collect_multiple(self, engine_args, pairs):
results = []
async for (
Expand Down Expand Up @@ -328,6 +308,8 @@ def _reset_agent():
env1.reset.return_value = ('initial1', {})
env1.step.return_value = ('obs1', 1.0, True, {})
env1.task = {}
env1.extra_kwargs = {}
env1.max_steps = 5

agent2 = configure_mock_agent('initial2')
env2 = mock.create_autospec(base_environment.BaseTaskEnv, instance=True)
Expand All @@ -337,11 +319,12 @@ def _reset_agent():
('obs2b', 2.1, True, {}),
]
env2.task = {}
env2.extra_kwargs = {}
env2.max_steps = 5

pairs = [(agent1, env1), (agent2, env2)]
mock_model_call = mock.Mock(side_effect=['resp1', 'resp2a', 'resp2b'])
engine_args = {
'model_call': mock_model_call,
'model_call': self.mock_model_call,
'mode': 'Conversation',
}

Expand All @@ -356,6 +339,71 @@ def _reset_agent():
# Pair 2: reset_obs, resp1, obs1, resp2, obs2 -> 5 messages
self.assertLen(results[1][1], 5)

@mock.patch.object(utils, 'tokenize_and_generate_masks')
def test_status_max_context_limit_reached(self, mock_convert):
  """A 150-token context budget is blown after one 100-token step."""
  # Arrange: every tokenization call yields 100 tokens with all-ones masks.
  hundred_tokens = ([1] * 100, [1] * 100)
  mock_convert.side_effect = [
      hundred_tokens,  # initial prompt
      hundred_tokens,  # first assistant turn
      hundred_tokens,  # first environment observation
  ]
  self.mock_env.max_steps = 5
  self.mock_chat_parser.parse.return_value = 'mock_parsed_text'

  engine = trajectory_collect_engine.TrajectoryCollectEngine(
      agent=self.mock_agent,
      env=self.mock_env,
      model_call=self.mock_model_call,
      tokenizer=self.mock_tokenizer,
      chat_parser=self.mock_chat_parser,
      max_context_limit=150,
  )
  result_traj = asyncio.run(self._run_collect(engine, mode='Trajectory'))

  # Prompt (100) + first step (100) = 200 tokens > 150 limit, so collection
  # must halt after a single step and report the context-limit status.
  self.assertEqual(
      result_traj.status,
      agent_types.TrajectoryStatus.MAX_CONTEXT_LIMIT_REACHED,
  )
  self.assertLen(result_traj.steps, 1)

def test_collect_max_steps_reached(self):
  """An episode ending on its single allowed step finishes as SUCCEEDED."""
  # Environment permits exactly one step and terminates on it (done=True).
  self.mock_env.max_steps = 1
  self.mock_env.step.side_effect = [('obs1', 1.0, True, {})]

  engine = trajectory_collect_engine.TrajectoryCollectEngine(
      agent=self.mock_agent,
      env=self.mock_env,
      model_call=self.mock_model_call,
  )
  result_traj = asyncio.run(self._run_collect(engine, mode='Trajectory'))

  # The episode finished naturally rather than being truncated, so the
  # status is SUCCEEDED and exactly one step was recorded.
  self.assertEqual(result_traj.status, agent_types.TrajectoryStatus.SUCCEEDED)
  self.assertLen(result_traj.steps, 1)


def test_collect_timeout(self):
  """A 0.1s wall-clock budget expiring mid-episode yields TIMEOUT."""
  self.mock_env.max_steps = 10
  with mock.patch.object(time, 'time') as mock_time:
    # Simulated clock readings: the third value (100.15) is past
    # start + timeout (100.0 + 0.1) — presumably the engine's per-step
    # deadline check; the final value covers one more time() access.
    mock_time.side_effect = [100.0, 100.05, 100.15, 100.15]

    engine = trajectory_collect_engine.TrajectoryCollectEngine(
        agent=self.mock_agent,
        env=self.mock_env,
        model_call=self.mock_model_call,
        max_context_limit=1024,
        timeout=0.1,
    )
    result_traj = asyncio.run(self._run_collect(engine, mode='Trajectory'))

    # The step that crossed the deadline is force-marked done and the
    # trajectory carries the TIMEOUT status.
    self.assertTrue(result_traj.steps[-1].done)
    self.assertEqual(
        result_traj.status, agent_types.TrajectoryStatus.TIMEOUT
    )


if __name__ == '__main__':
absltest.main()
9 changes: 3 additions & 6 deletions tests/rl/experimental/agentic_grpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,19 +1246,16 @@ def _initial_observation(self) -> Any:
return "Initial prompt."

def _step_impl(self, action: Any) -> EnvStepResult:
  """Advance the mock environment by one step.

  The merged text contained a dead if/else whose ``reward``/``done`` values
  were immediately overwritten by the assignments below; only the effective
  logic is kept.

  Args:
    action: The agent action; unused by this mock beyond protocol compliance.

  Returns:
    An ``EnvStepResult`` whose ``done`` flag is set once the configured step
    budget is exhausted, with reward 1.0 for every non-terminal step and 0.0
    on the terminal one.
  """
  # Episode terminates once the step counter reaches the configured budget.
  done = self.step_count >= self.max_steps
  reward = 1.0 if not done else 0.0
  return EnvStepResult(
      observation=f"Observation after step {self.step_count}",
      reward=reward,
      done=done,
      info={"max_steps": self.max_steps},
  )


class MockAgent(ConversationAgentBase):

def __init__(self, system_prompt: str):
Expand Down
29 changes: 28 additions & 1 deletion tunix/rl/agentic/agents/agent_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

from collections.abc import Hashable
import dataclasses
from typing import Any, Dict, Optional
from enum import Enum, auto
from typing import Any, Dict, List, Optional


@dataclasses.dataclass(kw_only=True)
Expand Down Expand Up @@ -57,6 +58,10 @@ class Step:
reward: Immediate reward signal from environment for this step.
done: Terminal state flag - True if episode has ended.
mc_return: Monte Carlo return from this step to episode end.
assistant_tokens: Token IDs generated by the assistant for this step.
assistant_masks: Masks for assistant tokens.
env_tokens: Token IDs generated by the environment for this step.
env_masks: Masks for environment tokens.
"""

chat_completions: list[dict[str, str]] = dataclasses.field(
Expand All @@ -70,6 +75,25 @@ class Step:
reward: float = 0.0
done: bool = False
mc_return: float = 0.0
assistant_tokens: Optional[List[int]] = None
assistant_masks: Optional[List[int]] = None
env_tokens: Optional[List[int]] = None
env_masks: Optional[List[int]] = None


class TrajectoryStatus(Enum):
  """Lifecycle states a collected trajectory can finish in."""

  # Normal lifecycle: completed cleanly, or still being collected.
  SUCCEEDED = 1
  RUNNING = 2

  # Truncation caused by agent-side constraints (each mirrors the engine
  # keyword argument named in the trailing comment).
  MAX_STEPS_REACHED = 3  # `max_steps`
  MAX_CONTEXT_LIMIT_REACHED = 4  # `max_context_limit`
  TIMEOUT = 5  # `timeout`

  # Unrecoverable system-level failure.
  FAILED = 6


@dataclasses.dataclass(kw_only=True)
Expand All @@ -84,11 +108,13 @@ class Trajectory:
task: Task description, initial prompt, or episode specification.
steps: Chronologically ordered sequence of interaction steps.
reward: Total episode reward (cumulative or final environment score).
status: Lifecycle status of the trajectory — a TrajectoryStatus member such as RUNNING, SUCCEEDED, TIMEOUT, or MAX_CONTEXT_LIMIT_REACHED.
"""

task: Any = None
steps: list[Step] = dataclasses.field(default_factory=list)
reward: float = 0.0
status: TrajectoryStatus = TrajectoryStatus.RUNNING

def to_dict(self) -> dict[str, Any]:
"""Convert trajectory to dictionary format for serialization.
Expand All @@ -103,6 +129,7 @@ def to_dict(self) -> dict[str, Any]:
"task": self.task,
"steps": [dataclasses.asdict(step) for step in self.steps],
"reward": float(self.reward),
"status": self.status.name,
}


Expand Down
Loading