142 changes: 100 additions & 42 deletions tests/rl/experimental/agentic_grpo_learner_test.py
@@ -41,16 +41,16 @@
from tunix.generate import tokenizer_adapter
from tunix.rl import function_registry
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.sft import metrics_logger
from tunix.rl.agentic.agents.agent_types import Action, Step
from tunix.rl.agentic.agents.base_agent import ConversationAgentBase
from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult
from tunix.rl.experimental import agentic_grpo_learner
from tunix.rl.queue import data_queue as queue_lib
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from tunix.tests import test_common
from tunix.utils import trajectory_logger
from typing_extensions import override
from tunix.rl.agentic.agents.base_agent import ConversationAgentBase
from tunix.rl.agentic.agents.agent_types import Action, Step
from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
Mesh = sharding.Mesh
@@ -229,7 +229,8 @@ async def _orchestrator_producer(
i += 1

algo_config = agentic_grpo_learner.GRPOConfig(
num_generations=2, num_iterations=2
num_generations=2,
num_iterations=2,
)
trainer = _MockTrainer(algo_config)

@@ -292,7 +293,6 @@ def test_num_iterations_greater_than_1(self):
train_micro_batch_size=1, # to control calls to update_actor
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=256,
kv_cache_size=1024,
),
@@ -308,6 +308,7 @@ def test_num_iterations_greater_than_1(self):
num_generations=2,
num_iterations=2, # > 1
loss_algo="grpo",
max_response_length=10,
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
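
Note on the recurring change in this diff: `max_tokens_to_generate` is removed from `base_rollout.RolloutConfig` and a `max_response_length` field is added to `agentic_grpo_learner.GRPOConfig` instead. A minimal before/after sketch of the pattern, using only field names and values that appear in the hunks (unrelated arguments elided; whether the remaining RolloutConfig defaults suffice is an assumption):

from tunix.rl.experimental import agentic_grpo_learner
from tunix.rl.rollout import base_rollout

# Before: the generation cap lived on the rollout config.
rollout_config = base_rollout.RolloutConfig(
    max_tokens_to_generate=10,
    max_prompt_length=256,
    kv_cache_size=1024,
)
grpo_config = agentic_grpo_learner.GRPOConfig(
    num_generations=2,
    num_iterations=1,
)

# After: the response cap is owned by the algorithm config.
rollout_config = base_rollout.RolloutConfig(
    max_prompt_length=256,
    kv_cache_size=1024,
)
grpo_config = agentic_grpo_learner.GRPOConfig(
    num_generations=2,
    num_iterations=1,
    max_response_length=10,
)
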
@@ -441,7 +442,6 @@ def create_learner(
checkpoint_root_directory=ckpt_dir,
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=32,
kv_cache_size=256,
temperature=0.5,
@@ -457,6 +457,7 @@ def create_learner(
grpo_config = agentic_grpo_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
max_response_length=10,
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
@@ -542,7 +543,6 @@ def create_learner(
train_micro_batch_size=train_micro_batch_size,
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=32,
kv_cache_size=256,
temperature=0.5,
@@ -558,6 +558,7 @@ def create_learner(
grpo_config = agentic_grpo_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
max_response_length=10,
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
@@ -639,7 +640,6 @@ def create_learner(
train_micro_batch_size=train_micro_batch_size,
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=32,
kv_cache_size=256,
temperature=0.5,
@@ -655,6 +655,7 @@ def create_learner(
grpo_config = agentic_grpo_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
max_response_length=10,
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
@@ -735,7 +736,6 @@ def create_learner(
offload_to_cpu=False,
training_config=training_config,
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=32,
kv_cache_size=256,
temperature=0.5,
@@ -751,6 +751,7 @@ def create_learner(
grpo_config = agentic_grpo_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
max_response_length=10,
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
@@ -814,7 +815,6 @@ def test_exception_handling(self):
eval_every_n_steps=10,
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=32,
kv_cache_size=256,
),
@@ -825,7 +825,7 @@ def test_exception_handling(self):
tokenizer=tokenizer,
cluster_config=cluster_config,
)
grpo_config = agentic_grpo_learner.GRPOConfig()
grpo_config = agentic_grpo_learner.GRPOConfig(max_response_length=10)
learner = _LearnerWithException(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
@@ -885,7 +885,6 @@ def test_grpo_learner(self, reward_fns, loss_algo):
gradient_accumulation_steps=None,
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=256,
kv_cache_size=1024,
),
@@ -902,6 +901,7 @@ def test_grpo_learner(self, reward_fns, loss_algo):
num_generations=2,
num_iterations=1,
loss_algo=loss_algo,
max_response_length=10,
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
@@ -1019,7 +1019,6 @@ def test_on_off_policy_training(self, offpolicy_steps):
gradient_accumulation_steps=None,
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=256,
kv_cache_size=1024,
),
@@ -1036,6 +1035,7 @@ def test_on_off_policy_training(self, offpolicy_steps):
num_iterations=1,
loss_algo="grpo",
off_policy_steps=offpolicy_steps,
max_response_length=10,
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
@@ -1146,7 +1146,6 @@ def test_trajectory_logging(self):
),
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=256,
kv_cache_size=1024,
),
@@ -1162,6 +1161,7 @@ def test_trajectory_logging(self):
num_generations=2,
num_iterations=1,
loss_algo="grpo",
max_response_length=10,
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
Expand All @@ -1172,10 +1172,9 @@ def test_trajectory_logging(self):
)
train_ds = _dummy_dataset(MySource(data=["1"], repeat=1), batch_size=1)

with mock.patch.object(
trajectory_logger, "log_item"
) as mock_log_item, mock.patch.object(
rl_cluster, "generate", side_effect=_mock_generate
with (
mock.patch.object(trajectory_logger, "log_item") as mock_log_item,
mock.patch.object(rl_cluster, "generate", side_effect=_mock_generate),
):
grpo_learner.train(train_ds)
if grpo_learner._trajectory_logger:
@@ -1187,9 +1186,7 @@ def test_trajectory_logging(self):
traj = mock_log_item.call_args_list[i][0][1]
self.assertIn("conversation_text", traj)
conversation = traj["conversation_text"]
assistant_msgs = [
m for m in conversation if m["role"] == "assistant"
]
assistant_msgs = [m for m in conversation if m["role"] == "assistant"]
self.assertNotEmpty(assistant_msgs)
self.assertIn(assistant_msgs[0]["content"], _MOCK_RESPONSES)
self.assertEqual(traj.get("policy_version"), 0)
@@ -1240,7 +1237,6 @@ def test_grpo_with_lora_model(self):
max_steps=10,
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=256,
kv_cache_size=1024,
),
@@ -1254,6 +1250,7 @@ def test_grpo_with_lora_model(self):
grpo_config = agentic_grpo_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
max_response_length=10,
)

grpo_learner = agentic_grpo_learner.GRPOLearner(
@@ -1285,7 +1282,6 @@ def test_grpo_with_lora_model(self):
)

def test_customized_agent_env(self):

class MockEnv(BaseTaskEnv):

def __init__(self, entry: dict[str, str], max_steps: int, **kwargs):
@@ -1296,7 +1292,7 @@ def _initial_observation(self) -> Any:
return "Initial prompt."

def _step_impl(self, action: Any) -> EnvStepResult:
if self.step_count < self.max_steps - 1:
if self.step_count <= self.max_steps:
reward = 1.0
done = False
else:
@@ -1320,28 +1316,66 @@ def _observation_to_messages(self, observation, reward, done, info):
if max_steps is not None:
remaining_steps = max_steps - self.step - 1
if remaining_steps > 0:
observation += f"\nSteps Remaining: {remaining_steps}"
observation += f" Steps Remaining: {remaining_steps}"
else:
observation += "\nYou have reached the maximum number of steps."
observation += " You have reached the maximum number of steps."
self._messages.append({"role": "user", "content": observation})
self.cur_step = Step(observation=observation)
step = self.get_current_state()
if step:
step.observation = observation

def update_from_model(self, response, **kwargs):
self._trajectory.steps.append(self.cur_step)
cur_step = self._trajectory.steps[-1]
cur_step.model_response = response
cur_step.action = f"Model action: {response}"
step = Step(model_response=response, action=f"Model action: {response}")
self._trajectory.steps.append(step)

self._messages.append({"role": "assistant", "content": response})
return Action(action=cur_step.action)
self.step += 1
return Action(action=step.action)

unique_words = {word for line in _MOCK_RESPONSES for word in line.split()}
words = [
"<pad>",
"<s>",
"</s>",
"System:",
"User:",
"Assistant:",
"Initial",
"prompt.",
"System",
"Observation",
"after",
"step",
"Steps",
"Remaining:",
"You",
"have",
"reached",
"the",
"maximum",
"number",
"of",
"steps.",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"10",
]
words.extend(sorted(unique_words))
mapping_text_to_id = {word: i for i, word in enumerate(words)}

vocab = test_common.MockVocab()
vocab = test_common.MockVocab(mapping_text_to_id=mapping_text_to_id)
tokenizer = tokenizer_adapter.TokenizerAdapter(vocab)
model = test_common.ToyTransformer(
config=test_common.ModelConfig(vocab_size=vocab.GetPieceSize()),
rngs=nnx.Rngs(0),
)
original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param))
ref_model = test_common.ToyTransformer(
config=test_common.ModelConfig(vocab_size=vocab.GetPieceSize()),
rngs=nnx.Rngs(0),
@@ -1364,7 +1398,7 @@ def update_from_model(self, response, **kwargs):
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
max_prompt_length=256,
max_prompt_length=32,
kv_cache_size=1024,
),
)
@@ -1380,6 +1414,8 @@ def update_from_model(self, response, **kwargs):
num_generations=2,
num_iterations=1,
loss_algo="grpo",
max_response_length=64,
max_concurrency=1, # so the output is deterministic.
)
grpo_learner = agentic_grpo_learner.GRPOLearner(
rl_cluster=rl_cluster,
@@ -1393,8 +1429,7 @@ def update_from_model(self, response, **kwargs):
env_kwargs={"max_steps": 3},
)

agents = []
envs = []
agents, envs = [], []

original_fn = grpo_learner._create_agent_env_pair

@@ -1404,7 +1439,23 @@ def _patch_create_agent_env_pair(single_example, group_id):
envs.append(env)
return agent, env

original_process_results = grpo_learner._process_results
processed_results = []

def _patch_process_results(
trajectories,
training_input,
mode,
expected_step,
):
res = original_process_results(
trajectories, training_input, mode, expected_step
)
processed_results.append(res)
return res

grpo_learner._create_agent_env_pair = _patch_create_agent_env_pair
grpo_learner._process_results = _patch_process_results

self.assertFalse(grpo_learner.should_sync_weights)
train_ds = _dummy_dataset(MySource(repeat=10), batch_size=2)
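
Both patches above use the same wrap-and-record spy pattern: capture the original bound method, swap in a wrapper that delegates and records, then assert on the recorded values after training runs. A self-contained sketch of the pattern with a toy class (hypothetical `Learner`, not the real `GRPOLearner`):

class Learner:
  """Stand-in; only the spied method matters for the sketch."""

  def process(self, batch):
    return [x * 2 for x in batch]


learner = Learner()
original_process = learner.process  # bound method, captured before patching
processed = []


def _spy_process(batch):
  result = original_process(batch)  # delegate to the real implementation
  processed.append(result)  # record for later assertions
  return result


learner.process = _spy_process  # instance attribute shadows the class method
learner.process([1, 2, 3])
assert processed == [[2, 4, 6]]
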
@@ -1413,12 +1464,19 @@ def _patch_create_agent_env_pair(single_example, group_id):
with mock.patch.object(rl_cluster, "generate", side_effect=_mock_generate):
grpo_learner.train(train_ds, eval_ds)

variables = nnx.state(model, nnx.Param)
jax.tree.map_with_path(
test_common.assert_not_equal, original_variables, variables
traj = agents[0].trajectory

target_mask = []
for step in traj.steps:
# + 1 for extra token from MockChatParser
target_mask.extend([1] * (len(step.model_response.split()) + 1))
target_mask.extend([0] * (len(step.observation.split()) + 1))
target_mask.extend(
[0] * (grpo_config.max_response_length - len(target_mask))
)

# TODO(tsbao): check on generated agents and envs once deepcopy is removed.
res = processed_results[0][0]
np.testing.assert_array_equal(res.completion_mask[0], np.array(target_mask))


if __name__ == "__main__":
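
For reference, the expected `completion_mask` above lays down ones over each turn's model tokens (plus one extra parser token, per the inline comment), zeros over the following observation tokens, and zero-padding out to `max_response_length`. A worked toy example of the same layout (strings and lengths illustrative, not from the test run):

# Toy trajectory: two turns, whitespace-tokenized.
steps = [
    {"model_response": "go left", "observation": "wall ahead"},
    {"model_response": "go right", "observation": "open door"},
]
max_response_length = 16

target_mask = []
for step in steps:
  # +1 mirrors the extra per-turn token from the chat parser.
  target_mask.extend([1] * (len(step["model_response"].split()) + 1))
  target_mask.extend([0] * (len(step["observation"].split()) + 1))
target_mask.extend([0] * (max_response_length - len(target_mask)))

# 3 ones per turn (2 response tokens + 1 parser token), then masked
# observation tokens, then zero padding up to max_response_length.
assert target_mask == [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
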