diff --git a/tests/cli/base_rl_main_test.py b/tests/cli/base_rl_main_test.py new file mode 100644 index 000000000..8e398705a --- /dev/null +++ b/tests/cli/base_rl_main_test.py @@ -0,0 +1,845 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests that grpo_main dispatches correctly for both training modes. + +Also tests that KV cache / GRPOConfig computation is correct. +""" +import dataclasses +import os +import pathlib +import tempfile +from unittest import mock + +from absl.testing import absltest +import omegaconf +from tunix.cli import base_rl_main +from tunix.rl import algorithm_config as algo_config_lib +from tunix.rl import rl_cluster as rl_cluster_lib + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_REPO_ROOT = pathlib.Path(__file__).resolve().parents[2] + +@dataclasses.dataclass(kw_only=True) +class DummyConfig(algo_config_lib.AlgorithmConfig): + """Dummy RL algorithm config""" + num_generations: int = 2 + num_iterations: int = 1 + episode_timeout: int = 1 + max_response_length: int = 1000 + +class DummyPipeline(base_rl_main.BasePipeline): + def __init__(self, argv: list[str], **kwargs): + self.data_module: types.ModuleType | None = None + super().__init__(argv, **kwargs) + + def _is_agentic_mode(self, mode:str) -> bool: + return mode == "agentic_dummy" + + @property + def _default_training_mode(self): + return "dummy" + + def _create_agentic_dummy_config(self): + cfg = dict(self._config_mapping("agentic_dummy_config")) + + # episode_timeout = per_turn_timeout_secs * max_turns when not explicit + if "episode_timeout" not in cfg: + per_turn = cfg.pop("per_turn_timeout_secs", None) + max_turns = cfg.get("max_turns", 1) + if per_turn is not None: + cfg["episode_timeout"] = per_turn * max_turns + + # max_response_length mirrors rollout_config.total_generation_steps + if "max_response_length" not in cfg: + cfg["max_response_length"] = self._config_mapping("rollout_config").get( + "total_generation_steps", 8192 + ) + + # Strip helper keys that are not GRPOConfig fields + valid = {f.name for f in dataclasses.fields(DummyConfig)} + cfg.pop("max_turns", None) + + return DummyConfig(**{k: v for k, v in cfg.items() if k in valid}) + + def create_rl_cluster(self): + pass + + def compute_params(self): + pass + + def _run(self): + pass + +def _make_pipeline(extra_yaml: str) -> DummyPipeline: + """Write a minimal valid YAML and instantiate DummyPipeline against it.""" + base = """ +model_config: + model_name: "test_model" + model_id: "test/model" + model_source: "huggingface" + model_display: false + rng_seed: 0 + intermediate_ckpt_dir: "/tmp/ckpt" + +actor_model_config: + mesh: + shape: "(1,1)" + axis_names: "('fsdp','tp')" + +reference_model_config: + mesh: + shape: "(1,1)" + axis_names: "('fsdp','tp')" + +rollout_model_config: + mesh: + shape: "(1,1)" + axis_names: "('fsdp','tp')" + +tokenizer_config: + tokenizer_type: "huggingface" + tokenizer_path: "test/model" + add_bos: false + add_eos: false + +rollout_engine: "vanilla" +offload_to_cpu: false + +rollout_config: + max_prompt_length: 256 + total_generation_steps: 512 + temperature: 1.0 + top_p: null + top_k: null + +rl_training_config: + max_steps: 1 + eval_every_n_steps: 1 + mini_batch_size: 1 + train_micro_batch_size: 1 + actor_optimizer_config: + opt_type: "adamw" + learning_rate: 1.0e-6 + schedule_type: "warmup_cosine_decay_schedule" + init_value: 0.0 + end_value: 0.0 + warmup_ratio: 0.1 + b1: 0.9 + b2: 0.99 + weight_decay: 0.01 + max_grad_norm: 1.0 + metrics_logging_options: + log_dir: "/tmp/tb_test" + flush_every_n_steps: 1 + checkpointing_options: + save_interval_steps: 100 + max_to_keep: 1 + checkpoint_root_directory: "/tmp/ckpt_test" + +batch_size: 1 +num_batches: 1 +num_train_epochs: 1 +train_fraction: 1.0 +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(base + extra_yaml) + path = f.name + + # Patch HF_TOKEN so tokenizer validation passes + with mock.patch.dict(os.environ, {"HF_TOKEN": "fake"}): + pipeline = DummyPipeline(["", path]) + os.unlink(path) + return pipeline + + +def _make_pipeline_with_cli_args( + extra_yaml: str, cli_args: list[str] +) -> DummyPipeline: + """Write a minimal valid YAML and instantiate DummyPipeline with CLI args.""" + base = """ +model_config: + model_name: "test_model" + model_id: "test/model" + model_source: "huggingface" + model_display: false + rng_seed: 0 + intermediate_ckpt_dir: "/tmp/ckpt" + +actor_model_config: + mesh: + shape: "(1,1)" + axis_names: "('fsdp','tp')" + +reference_model_config: + mesh: + shape: "(1,1)" + axis_names: "('fsdp','tp')" + +rollout_model_config: + mesh: + shape: "(1,1)" + axis_names: "('fsdp','tp')" + +tokenizer_config: + tokenizer_type: "huggingface" + tokenizer_path: "test/model" + add_bos: false + add_eos: false + +rollout_engine: "vanilla" +offload_to_cpu: false + +rollout_config: + max_prompt_length: 256 + total_generation_steps: 512 + temperature: 1.0 + top_p: null + top_k: null + +rl_training_config: + max_steps: 1 + eval_every_n_steps: 1 + mini_batch_size: 1 + train_micro_batch_size: 1 + actor_optimizer_config: + opt_type: "adamw" + learning_rate: 1.0e-6 + schedule_type: "warmup_cosine_decay_schedule" + init_value: 0.0 + end_value: 0.0 + warmup_ratio: 0.1 + b1: 0.9 + b2: 0.99 + weight_decay: 0.01 + max_grad_norm: 1.0 + metrics_logging_options: + log_dir: "/tmp/tb_test" + flush_every_n_steps: 1 + checkpointing_options: + save_interval_steps: 100 + max_to_keep: 1 + checkpoint_root_directory: "/tmp/ckpt_test" + +batch_size: 1 +num_batches: 1 +num_train_epochs: 1 +train_fraction: 1.0 +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(base + extra_yaml) + path = f.name + + with mock.patch.dict(os.environ, {"HF_TOKEN": "fake"}): + pipeline = DummyPipeline(["", path, *cli_args]) + os.unlink(path) + return pipeline + + +# --------------------------------------------------------------------------- +# Mode dispatch +# --------------------------------------------------------------------------- + + +class DispatchTest(absltest.TestCase): + + def test_agentic_data_module_receives_data_config_for_raw_dataset(self): + extra = """ +training_mode: "agentic_dummy" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline(extra) + fake_module = mock.Mock(batch_fn=mock.sentinel.batch_fn) + + with mock.patch.object( + pipeline, "_get_data_module", return_value=fake_module + ): + with mock.patch.object( + pipeline, + "_get_dataset", + return_value=mock.sentinel.dataset, + ) as get_dataset: + dataset, batch_fn = pipeline._load_raw_dataset(mock.sentinel.tokenizer) + + self.assertIs(dataset, mock.sentinel.dataset) + self.assertIs(batch_fn, mock.sentinel.batch_fn) + get_dataset.assert_called_once_with( + mock.sentinel.tokenizer, + ) + + def test_agentic_nullable_string_can_be_overridden_from_cli(self): + extra = """ +training_mode: "agentic_dummy" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline_with_cli_args( + extra, + ["agent_class_path=examples.deepswe.swe_agent.SWEAgent"], + ) + + self.assertEqual( + pipeline.config["agent_class_path"], + "examples.deepswe.swe_agent.SWEAgent", + ) + + def test_agentic_nullable_dict_can_be_overridden_from_cli(self): + extra = """ +training_mode: "agentic_dummy" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline_with_cli_args( + extra, + [ + "kubernetes_config.node_selector_key=cloud.google.com/gke-nodepool", + "kubernetes_config.node_selector_val=deepswe-cpu-pool", + ], + ) + + self.assertEqual( + pipeline.config["kubernetes_config"], + { + "node_selector_key": "cloud.google.com/gke-nodepool", + "node_selector_val": "deepswe-cpu-pool", + }, + ) + + def test_agentic_nullable_string_can_be_overridden_from_env(self): + extra = """ +training_mode: "agentic_dummy" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + with mock.patch.dict( + os.environ, + {"T_AGENT_CLASS_PATH": "examples.deepswe.swe_agent.SWEAgent"}, + ): + pipeline = _make_pipeline_with_cli_args(extra, []) + + self.assertEqual( + pipeline.config["agent_class_path"], + "examples.deepswe.swe_agent.SWEAgent", + ) + + def test_standard_dummy_dispatches_to_standard(self): + extra = """ +dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 +data_source: "tfds" +dataset_name: "gsm8k" +tfds_download: false +reward_functions: [] +verl_compatible: false +""" + pipeline = _make_pipeline(extra) + self.assertEqual(pipeline.config.get("training_mode", "dummy"), "dummy") + with mock.patch.object(pipeline, "_run") as mock_run: + pipeline.run_trainer() + mock_run.assert_called_once_with(mode="dummy") + + def test_agentic_dummy_dispatches_to_agentic(self): + extra = """ +training_mode: "agentic_dummy" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: + - "tunix/utils/math_rewards.py" +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline(extra) + self.assertEqual(pipeline.config["training_mode"], "agentic_dummy") + with mock.patch.object(pipeline, "_run") as mock_run: + pipeline.run_trainer() + mock_run.assert_called_once_with(mode="agentic_dummy") + +# --------------------------------------------------------------------------- +# KV cache formula +# --------------------------------------------------------------------------- + + +class RolloutConfigTest(absltest.TestCase): + + def _make_agentic_pipeline(self, max_turns): + extra = f""" +training_mode: "agentic_dummy" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {{}} +env_class_path: null +env_kwargs: {{}} +kubernetes_config: null +agentic_dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: {max_turns} +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + return _make_pipeline(extra) + + def test_single_turn_kv_cache(self): + p = self._make_agentic_pipeline(max_turns=1) + cfg = p.create_rollout_config() + # max_prompt=256, max_response=512, single-turn → +256 + self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256) + + def test_multi_turn_kv_cache(self): + p = self._make_agentic_pipeline(max_turns=20) + cfg = p.create_rollout_config() + self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256) + + def test_standard_dummy_kv_cache(self): + extra = """ +dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 +data_source: "tfds" +dataset_name: "gsm8k" +tfds_download: false +reward_functions: [] +verl_compatible: false +""" + p = _make_pipeline(extra) + cfg = p.create_rollout_config() + self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256) + + + def test_vllm_submission_threshold_passed_through(self): + extra = """ +vllm_config: + hbm_utilization: 0.4 + server_mode: true + server_mode_submission_threshold: 3840 + server_mode_submission_timeout_s: 1.5 +""" + p = _make_pipeline_with_cli_args(extra, ["rollout_engine=vllm"]) + role_to_mesh = { + rl_cluster_lib.Role.ROLLOUT: mock.Mock( + devices=mock.Mock(shape=(1, 1)) + ) + } + cfg = p.create_rollout_config(role_to_mesh=role_to_mesh) + self.assertEqual(cfg.rollout_vllm_server_mode_submission_threshold, 3840) + self.assertEqual(cfg.rollout_vllm_server_mode_submission_timeout_s, 1.5) + +# --------------------------------------------------------------------------- +# GRPOConfig construction +# --------------------------------------------------------------------------- + + +class AgenticConfigTest(absltest.TestCase): + + def _base_extra(self, agentic_overrides="", system_prompt='""'): + return f""" +training_mode: "agentic_dummy" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {{}} +env_class_path: null +env_kwargs: {{}} +kubernetes_config: null +agentic_dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.001 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: {system_prompt} + max_concurrency: 1 + off_policy_steps: 0 + {agentic_overrides} +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + + def test_episode_timeout_computed(self): + p = _make_pipeline( + self._base_extra("max_turns: 20\n per_turn_timeout_secs: 300") + ) + algo = p._create_agentic_dummy_config() + self.assertEqual(algo.episode_timeout, 300 * 20) + + def test_max_response_length_from_rollout(self): + p = _make_pipeline(self._base_extra("max_turns: 1")) + algo = p._create_agentic_dummy_config() + # rollout_config.total_generation_steps = 512 + self.assertEqual(algo.max_response_length, 512) + + def test_num_generations_passed_through(self): + p = _make_pipeline(self._base_extra("max_turns: 1")) + algo = p._create_agentic_dummy_config() + self.assertEqual(algo.num_generations, 2) + + def test_cli_empty_system_prompt_stays_empty_string(self): + p = _make_pipeline_with_cli_args( + self._base_extra("max_turns: 1", system_prompt='"base"'), + ['agentic_dummy_config.system_prompt=""'], + ) + self.assertEqual(p.config["agentic_dummy_config"]["system_prompt"], "") + + +class SplitMeshConfigTest(absltest.TestCase): + + def test_split_mesh_uses_explicit_role_meshes(self): + extra = """ +training_mode: "agentic_dummy" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline(extra) + actor_model_config = pipeline.config["actor_model_config"] + if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig): + actor_model_config["mesh"] = { + "shape": "(2,1)", + "axis_names": "('fsdp','tp')", + } + pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"} + rollout_model_config = pipeline.config["rollout_model_config"] + if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig): + rollout_model_config["mesh"] = { + "shape": "(1,2)", + "axis_names": "('fsdp','tp')", + } + + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + self.process_index = 0 + self.slice_index = 0 + self.device_kind = "TPU v5e" + + fake_devices = [ + FakeDevice(0, (0, 0)), + FakeDevice(1, (1, 0)), + FakeDevice(2, (0, 1)), + FakeDevice(3, (1, 1)), + ] + + class FakeMesh: + + def __init__(self, devices, axis_names, axis_types=None): + self.devices = devices + self.axis_names = axis_names + self.axis_types = axis_types + + with mock.patch.object(base_rl_main.jax, "devices", return_value=fake_devices): + with mock.patch.object( + base_rl_main.jax.sharding, "Mesh", side_effect=FakeMesh + ): + role_to_mesh = pipeline.create_role_to_mesh() + + self.assertSequenceEqual( + [ + device.id + for device in ( + role_to_mesh[rl_cluster_lib.Role.ACTOR] + .devices.flatten() + .tolist() + ) + ], + [0, 1], + ) + self.assertSequenceEqual( + [ + device.id + for device in ( + role_to_mesh[rl_cluster_lib.Role.ROLLOUT] + .devices.flatten() + .tolist() + ) + ], + [2, 3], + ) + self.assertEqual( + role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.shape, + (2, 1), + ) + self.assertEqual( + role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape, + (1, 2), + ) + self.assertIs( + role_to_mesh[rl_cluster_lib.Role.REFERENCE], + role_to_mesh[rl_cluster_lib.Role.ACTOR], + ) + + def test_create_role_to_mesh_passes_configured_allocation_policy(self): + extra = """ +training_mode: "agentic_dummy" +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_dummy_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline(extra) + actor_model_config = pipeline.config["actor_model_config"] + if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig): + actor_model_config["mesh"] = { + "shape": "(2,1)", + "axis_names": "('fsdp','tp')", + "allocation_policy": "PERFORMANCE", + } + pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"} + rollout_model_config = pipeline.config["rollout_model_config"] + if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig): + rollout_model_config["mesh"] = { + "shape": "(1,2)", + "axis_names": "('fsdp','tp')", + "allocation_policy": "PERFORMANCE", + } + + fake_devices = list(range(4)) + + with mock.patch.object(base_rl_main.jax, "devices", return_value=fake_devices): + with mock.patch.object( + base_rl_main.mesh_lib, + "allocate_named_mesh_device_slices", + return_value={ + "actor_model_config": [0, 1], + "rollout_model_config": [2, 3], + }, + ) as allocate_mock, mock.patch.object( + base_rl_main.mesh_lib, + "create_mesh", + side_effect=[object(), object()], + ): + pipeline.create_role_to_mesh() + + allocate_mock.assert_called_once_with( + [("actor_model_config", 2), ("rollout_model_config", 2)], + devices=fake_devices, + allocation_policy="PERFORMANCE", + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py index 7af7a3f58..69184898a 100644 --- a/tests/cli/grpo_main_test.py +++ b/tests/cli/grpo_main_test.py @@ -25,6 +25,7 @@ from absl.testing import absltest import omegaconf from tunix.cli import grpo_main +from tunix.cli import base_rl_main from tunix.rl import rl_cluster as rl_cluster_lib # --------------------------------------------------------------------------- @@ -410,7 +411,7 @@ def test_standard_grpo_dispatches_to_standard(self): pipeline = _make_pipeline(extra) self.assertEqual(pipeline.config.get("training_mode", "grpo"), "grpo") with mock.patch.object(pipeline, "_run") as mock_run: - pipeline.run_grpo_trainer() + pipeline.run_trainer() mock_run.assert_called_once_with(mode="grpo") def test_agentic_grpo_dispatches_to_agentic(self): @@ -450,7 +451,7 @@ def test_agentic_grpo_dispatches_to_agentic(self): pipeline = _make_pipeline(extra) self.assertEqual(pipeline.config["training_mode"], "agentic_grpo") with mock.patch.object(pipeline, "_run") as mock_run: - pipeline.run_grpo_trainer() + pipeline.run_trainer() mock_run.assert_called_once_with(mode="agentic_grpo") def test_unknown_mode_raises(self): @@ -499,7 +500,7 @@ def test_unknown_mode_raises(self): with self.assertRaisesRegex( ValueError, "Unsupported training_mode 'bad_mode'" ): - pipeline.run_grpo_trainer() + pipeline.run_trainer() # --------------------------------------------------------------------------- @@ -731,9 +732,9 @@ def __init__(self, devices, axis_names, axis_types=None): self.axis_names = axis_names self.axis_types = axis_types - with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices): + with mock.patch.object(base_rl_main.jax, "devices", return_value=fake_devices): with mock.patch.object( - grpo_main.jax.sharding, "Mesh", side_effect=FakeMesh + base_rl_main.jax.sharding, "Mesh", side_effect=FakeMesh ): role_to_mesh = pipeline.create_role_to_mesh() @@ -817,16 +818,16 @@ def test_create_role_to_mesh_passes_configured_allocation_policy(self): fake_devices = list(range(4)) - with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices): + with mock.patch.object(base_rl_main.jax, "devices", return_value=fake_devices): with mock.patch.object( - grpo_main.mesh_lib, + base_rl_main.mesh_lib, "allocate_named_mesh_device_slices", return_value={ "actor_model_config": [0, 1], "rollout_model_config": [2, 3], }, ) as allocate_mock, mock.patch.object( - grpo_main.mesh_lib, + base_rl_main.mesh_lib, "create_mesh", side_effect=[object(), object()], ): diff --git a/tunix/cli/base_rl_main.py b/tunix/cli/base_rl_main.py new file mode 100644 index 000000000..e1c286d17 --- /dev/null +++ b/tunix/cli/base_rl_main.py @@ -0,0 +1,789 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from collections.abc import MutableMapping +import dataclasses +import importlib +import os +import types +from typing import Any + +from absl import flags +from absl import logging +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np +from tunix.cli import config +from tunix.cli.utils import data as data_lib +from tunix.cli.utils import model as model_lib +from tunix.examples.data import math_dataset as example_data +from tunix.models.gemma import model as gemma_lib +from tunix.perf import export as perf_export +from tunix.perf import metrics as perf_metrics +from tunix.perf.experimental import export as perf_export_v2 +from tunix.rl import rl_cluster as rl_cluster_lib +from tunix.rl.rollout import base_rollout +from tunix.utils import mesh as mesh_lib + +PATHWAYS_BNS = flags.DEFINE_string( + "pathways_bns", None, "BNS address of the Pathways server." +) + + +class BasePipeline(abc.ABC, config.HyperParameters): + def __init__(self, argv: list[str], **kwargs): + self.data_module: types.ModuleType | None = None + super().__init__(argv, **kwargs) + + @property + @abc.abstractmethod + def _default_training_mode(self): + pass + + # ------------------------------------------------------------------ + # Mesh + # ------------------------------------------------------------------ + _ROLE_TO_MODEL_KEY = { + rl_cluster_lib.Role.ACTOR: "actor_model_config", + rl_cluster_lib.Role.CRITIC: "critic_model_config", + rl_cluster_lib.Role.REFERENCE: "reference_model_config", + rl_cluster_lib.Role.REWARD: "reward_model_config", + rl_cluster_lib.Role.ROLLOUT: "rollout_model_config", + } + _SPLIT_ROLE_ALIASES = { + "actor": rl_cluster_lib.Role.ACTOR, + "critic": rl_cluster_lib.Role.CRITIC, + "reference": rl_cluster_lib.Role.REFERENCE, + "reward": rl_cluster_lib.Role.REWARD, + "rollout": rl_cluster_lib.Role.ROLLOUT, + } + + def _is_agentic_mode(self, mode: str) -> bool: + """Checks if the given mode is agentic.""" + return mode == "agentic_ppo" or "agentic_grpo" + + def _resolve_split_role(self, role_name: str) -> rl_cluster_lib.Role: + normalized = role_name.strip().lower() + if normalized not in self._SPLIT_ROLE_ALIASES: + valid_roles = sorted(self._SPLIT_ROLE_ALIASES) + raise ValueError( + f"Unknown role name {role_name!r}. Expected one of {valid_roles}." + ) + return self._SPLIT_ROLE_ALIASES[normalized] + + def _get_same_mesh_as_map( + self, + ) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]: + same_mesh_as = {} + for role, model_key in self._ROLE_TO_MODEL_KEY.items(): + model_cfg = self.config.get(model_key, {}) or {} + target_name = model_cfg.get("same_mesh_as") + if target_name is None: + continue + target_role = self._resolve_split_role(str(target_name)) + if role == rl_cluster_lib.Role.ACTOR: + raise ValueError("Actor must own its mesh.") + same_mesh_as[role] = target_role + + return same_mesh_as + + def _is_role_active(self, role: rl_cluster_lib.Role) -> bool: + if role in ( + rl_cluster_lib.Role.ACTOR, + rl_cluster_lib.Role.REFERENCE, + rl_cluster_lib.Role.ROLLOUT, + ): + return True + model_key = self._ROLE_TO_MODEL_KEY[role] + return model_key in self.config + + def _resolve_mesh_owners( + self, + ) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]: + same_mesh_as = self._get_same_mesh_as_map() + base_owners = {} + for role, model_key in self._ROLE_TO_MODEL_KEY.items(): + if not self._is_role_active(role) and role not in same_mesh_as: + continue + + model_config = self.config.get(model_key, {}) + has_mesh = model_config is not None and bool( + model_config.get("mesh") + ) + base_owners[role] = ( + role + if role == rl_cluster_lib.Role.ACTOR or has_mesh + else rl_cluster_lib.Role.ACTOR + ) + + def resolve_owner( + role: rl_cluster_lib.Role, + seen: set[rl_cluster_lib.Role], + ) -> rl_cluster_lib.Role: + if role in seen: + raise ValueError("same_mesh_as contains a cycle.") + if role not in same_mesh_as: + return base_owners[role] + seen.add(role) + target_role = same_mesh_as[role] + if target_role not in base_owners: + raise ValueError( + f"Role {target_role.value!r} is not active in this config." + ) + return resolve_owner(target_role, seen) + + role_to_owner = {} + for role, model_key in self._ROLE_TO_MODEL_KEY.items(): + if role not in base_owners: + continue + + model_config = self.config.get(model_key, {}) + has_mesh = isinstance(model_config, dict) and bool( + model_config.get("mesh") + ) + if role in same_mesh_as: + if has_mesh: + raise ValueError( + f"{model_key}.mesh is specified, so it must own a separate mesh " + "and cannot also use same_mesh_as." + ) + else: + role_to_owner[role] = resolve_owner(role, set()) + continue + role_to_owner[role] = resolve_owner(role, set()) + return role_to_owner + + def create_role_to_mesh(self): + """Builds the role-to-mesh mapping for execution. + + Any role with an explicit ``*.mesh`` config gets a dedicated device slice. + Roles without a mesh share the actor mesh by default, or can point at + another role via ``same_mesh_as``. + + All mesh owners participating in the same allocation pass must agree on + one ``mesh.allocation_policy`` value. That policy is then passed to the + mesh allocator so users can choose between compact packing and + performance-oriented cubical packing from config. + + Returns: + A mapping from logical role to the concrete JAX mesh it should use. + + Raises: + ValueError: If mesh ownership resolution is invalid or if mesh owners + request conflicting allocation policies. + """ + devices = list(jax.devices()) + role_to_owner = self._resolve_mesh_owners() + owner_order = [] + for role in self._ROLE_TO_MODEL_KEY: + if role not in role_to_owner: + continue + owner = role_to_owner[role] + if owner not in owner_order: + owner_order.append(owner) + + mesh_requirements = [] + allocation_policy = None + for owner in owner_order: + model_key = self._ROLE_TO_MODEL_KEY[owner] + axis_shapes, _ = self.parse_mesh_config(model_key) + owner_policy = self._parse_mesh_allocation_policy(model_key) + if allocation_policy is None: + allocation_policy = owner_policy + elif owner_policy != allocation_policy: + raise ValueError( + "All owned meshes must use the same mesh.allocation_policy, got " + f"{allocation_policy!r} and {owner_policy!r}." + ) + mesh_requirements.append((model_key, int(np.prod(axis_shapes)))) + + allocated_devices = mesh_lib.allocate_named_mesh_device_slices( + mesh_requirements, + devices=devices, + allocation_policy=allocation_policy + or mesh_lib.normalize_allocation_policy(None), + ) + + owner_to_mesh = {} + for owner in owner_order: + model_key = self._ROLE_TO_MODEL_KEY[owner] + axis_shapes, axis_names = self.parse_mesh_config(model_key) + assigned_devices = allocated_devices[model_key] + owner_to_mesh[owner] = mesh_lib.create_mesh( + axis_shapes, axis_names, devices=assigned_devices + ) + return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()} + + # ------------------------------------------------------------------ + # Rollout config + # ------------------------------------------------------------------ + def create_rollout_config( + self, + role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh] | None = None, + ) -> base_rollout.RolloutConfig: + """Build RolloutConfig from YAML. + + Standard mode: pass rollout_config fields through with kv_cache_size = + max_prompt_length + total_generation_steps + 256. + + Agentic mode: same base. Same kv_cache_size calculation. + + Engine-specific extras (sglang_jax_config, vllm_config) are also applied. + + Args: + role_to_mesh: Optional mapping from logical role to JAX mesh. + + Returns: + The constructed RolloutConfig. + """ + rollout_cfg = self._config_mapping("rollout_config") + mode = self._config_string("training_mode", self._default_training_mode) + engine = self._config_string("rollout_engine", "vanilla") + + valid_fields = { + f.name for f in dataclasses.fields(base_rollout.RolloutConfig) + } + + # Base pass-through (same as original create_rollout_config) + filtered = {k: v for k, v in rollout_cfg.items() if k in valid_fields} + if "total_generation_steps" in rollout_cfg: + filtered["max_tokens_to_generate"] = rollout_cfg["total_generation_steps"] + + max_prompt = rollout_cfg.get("max_prompt_length", 0) + max_response = rollout_cfg.get("total_generation_steps", 0) + + kv_cache_size = 0 + if self._is_agentic_mode(mode): + agentic_cfg = self._config_mapping(f"{mode}_config") + kv_cache_size = max_prompt + max_response + 256 + filtered["kv_cache_size"] = kv_cache_size + logging.info("kv_cache_size: %d", kv_cache_size) + + max_running_requests = agentic_cfg.get("max_concurrency", 16) + else: + rl_algm_cfg = self._config_mapping(f"{mode}_config") + # Standard: kv_cache_size = max_prompt + max_response + 256 + if max_prompt and max_response: + kv_cache_size = max_prompt + max_response + 256 + filtered["kv_cache_size"] = kv_cache_size + # Defaults to global batch size * num_generations to allow full + # concurrency. + max_running_requests = self.config.get("batch_size", 1) * rl_algm_cfg.get( + "num_generations", 1 + ) + + # Engine-specific extras + extra = self._rollout_engine_extra( + engine, + kv_cache_size, + max_running_requests, + role_to_mesh=role_to_mesh, + ) + filtered.update({k: v for k, v in extra.items() if k in valid_fields}) + return base_rollout.RolloutConfig(**filtered) + + def _rollout_engine_extra( + self, + engine: str, + kv_cache_size: int, + max_running_requests: int, + role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh] | None = None, + ) -> dict[str, Any]: + """Return engine-specific RolloutConfig fields for agentic mode.""" + model_id = self._config_mapping("actor_model_config").get("model_id", "") + + if engine == "sglang_jax": + sg = self._config_mapping("sglang_jax_config") + return dict( + rollout_sglang_jax_model_version=sg.get("model_version", model_id), + rollout_sglang_jax_mem_fraction_static=sg.get( + "mem_fraction_static", 0.8 + ), + rollout_sglang_jax_init_with_random_weights=sg.get( + "init_with_random_weights", True + ), + rollout_sglang_jax_disable_radix_cache=sg.get( + "disable_radix_cache", True + ), + rollout_sglang_jax_enable_deterministic_sampling=sg.get( + "enable_deterministic_sampling", False + ), + rollout_sglang_jax_chunked_prefill_size=sg.get( + "chunked_prefill_size", 2048 + ), + rollout_sglang_jax_max_running_requests=sg.get( + "max_running_requests", + max_running_requests, + ), + rollout_sglang_jax_page_size=sg.get("page_size", 128), + rollout_sglang_jax_use_sort_for_toppk_minp=sg.get( + "use_sort_for_toppk_minp", False + ), + ) + + if engine == "vllm": + vllm = self._config_mapping("vllm_config") + if role_to_mesh is None: + raise ValueError( + "role_to_mesh must be provided for vllm rollout config." + ) + rollout_shape = role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape + rollout_cfg = self._config_mapping("rollout_config") + max_num_seqs = rollout_cfg.get( + "rollout_vllm_max_num_seqs", + vllm.get("max_num_seqs", 768), + ) + max_batched_tokens = rollout_cfg.get( + "rollout_vllm_max_num_batched_tokens", + vllm.get( + "max_num_batched_tokens", + (max_num_seqs * kv_cache_size) // 4, + ), + ) + submission_threshold = rollout_cfg.get( + "rollout_vllm_server_mode_submission_threshold", + vllm.get("server_mode_submission_threshold", 0), + ) + submission_timeout_s = rollout_cfg.get( + "rollout_vllm_server_mode_submission_timeout_s", + vllm.get("server_mode_submission_timeout_s", 0.0), + ) + os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" + return dict( + rollout_vllm_model_version=vllm.get("model_version", model_id), + rollout_vllm_hbm_utilization=vllm.get("hbm_utilization", 0.4), + rollout_vllm_tpu_backend_type=vllm.get("tpu_backend_type", "jax"), + rollout_vllm_server_mode=vllm.get("server_mode", True), + rollout_vllm_server_mode_submission_threshold=submission_threshold, + rollout_vllm_server_mode_submission_timeout_s=submission_timeout_s, + rollout_vllm_async_scheduling=vllm.get("async_scheduling", True), + tensor_parallel_size=( + rollout_shape[1] if len(rollout_shape) > 1 else 1 + ), + data_parallel_size=rollout_shape[0], + rollout_vllm_max_num_seqs=max_num_seqs, + rollout_vllm_max_num_batched_tokens=max_batched_tokens, + rollout_vllm_kwargs=vllm.get( + "kwargs", + { + "kv_cache_metrics": True, + "disable_log_stats": False, + "enable_prefix_caching": True, + }, + ), + ) + + return {} + + # ------------------------------------------------------------------ + # Standard helpers (unchanged) + # ------------------------------------------------------------------ + + def create_cluster_config( + self, + *, + role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh], + rollout_config: base_rollout.RolloutConfig | None = None, + ): + if rollout_config is None: + rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh) + return rl_cluster_lib.ClusterConfig( + role_to_mesh=role_to_mesh, + rollout_engine=self._config_string("rollout_engine"), + offload_to_cpu=self._config_bool("offload_to_cpu"), + training_config=self.create_rl_training_config(), + rollout_config=rollout_config, + ) + + def create_rl_training_config(self): + base_key = "rl_training_config" + constructed_rl_training_config = self.obtain_training_config_dict(base_key) + + base_config = self._config_mapping(base_key) + if base_config.get("actor_optimizer_config"): + constructed_rl_training_config["actor_optimizer"] = self.create_optimizer( + base_key, "actor_optimizer_config" + ) + if base_config.get("critic_optimizer_config"): + constructed_rl_training_config["critic_optimizer"] = ( + self.create_optimizer(base_key, "critic_optimizer_config") + ) + + return rl_cluster_lib.RLTrainingConfig(**constructed_rl_training_config) + + def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig): + perf_metrics_options = cluster_config.training_config.perf_metrics_options + if not perf_metrics_options: + return None + + perf_config = perf_metrics.PerfMetricsConfig() + + if perf_metrics_options.enable_perf_v1: + custom_export_fn_path = perf_metrics_options.custom_export_fn_path + if custom_export_fn_path: + perf_config.custom_export_fn = self._get_function_from_path( + custom_export_fn_path + ) + if perf_config.custom_export_fn is None: + raise ValueError( + "Could not load custom export function from" + f" {custom_export_fn_path}" + ) + else: + perf_config.custom_export_fn = ( + perf_export.PerfMetricsExport.from_cluster_config(cluster_config) + ) + + if perf_metrics_options.enable_perf_v2: + custom_export_fn_path_v2 = perf_metrics_options.custom_export_fn_path_v2 + if custom_export_fn_path_v2: + perf_config.custom_export_fn_v2 = self._get_function_from_path( + custom_export_fn_path_v2 + ) + if perf_config.custom_export_fn_v2 is None: + raise ValueError( + "Could not load custom export function v2 from" + f" {custom_export_fn_path_v2}" + ) + else: + perf_config.custom_export_fn_v2 = ( + perf_export_v2.PerfMetricsExport.from_cluster_config( + cluster_config=cluster_config, + enable_trace_writer=perf_metrics_options.enable_trace_writer, + trace_dir=perf_metrics_options.trace_dir, + ).export_metrics + ) + return perf_config + + def create_rl_cluster(self, tokenizer): + role_to_mesh = self.create_role_to_mesh() + rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh) + reference_model_config = self._mutable_config_mapping( + "reference_model_config" + ) + actor_model_config = self._mutable_config_mapping("actor_model_config") + tokenizer_config = self._config_mapping("tokenizer_config") + # Should not use LoRA for reference model. + if reference_model_config.get("lora_config"): + logging.warning( + "LoRA config is not supported for the reference model. Disabling" + " LoRA." + ) + del reference_model_config["lora_config"] + reference_model, _ = model_lib.create_model( + dict(reference_model_config), + tokenizer_config, + role_to_mesh[rl_cluster_lib.Role.REFERENCE], + ) + if actor_model_config.get("lora_config", None): + actor_model = model_lib.apply_lora_to_model( + reference_model, + role_to_mesh[rl_cluster_lib.Role.ACTOR], + actor_model_config["lora_config"], + ) + else: + graph_def, params = nnx.split(reference_model) + actor_model = nnx.merge( + graph_def, + jax.tree.map(jnp.copy, params), + ) + + critic_model = None + critic_model_config = self._mutable_config_mapping("critic_model_config") + if critic_model_config: + critic_model, _ = model_lib.create_model( + dict(critic_model_config), + tokenizer_config, + role_to_mesh[rl_cluster_lib.Role.CRITIC], + ) + + if critic_model_config.get("lora_config", None): + critic_model = model_lib.apply_lora_to_model( + critic_model, + role_to_mesh[rl_cluster_lib.Role.CRITIC], + critic_model_config["lora_config"], + ) + + rngs = nnx.Rngs( + params=jax.random.key(critic_model_config.get("rng_seed", 0)) + ) + + # TODO (yatla2): Support all critic model types, not just Gemma + critic_model = gemma_lib.GemmaWithScoreHead(critic_model, rngs=rngs) + + cluster_config = self.create_cluster_config( + role_to_mesh=role_to_mesh, + rollout_config=rollout_config, + ) + perf_config = self.create_perf_config(cluster_config) + return rl_cluster_lib.RLCluster( + actor=actor_model, + critic=critic_model, + reference=reference_model, + tokenizer=tokenizer, + cluster_config=cluster_config, + perf_config=perf_config, + ) + + def _compute_max_steps(self, dataset, rl_training_config): + max_steps = None + if rl_training_config.get("max_steps"): + max_steps = rl_training_config.get("max_steps") + elif not hasattr(dataset, "__len__"): + raise ValueError( + "max_steps must be specified since the dataset length cannot be" + " determined." + ) + + dataset_length = len(dataset) + + batch_size = self.config.get("batch_size", 1) + num_batches = self.config.get("num_batches") + if not num_batches: + num_batches = dataset_length // batch_size + self.config["num_batches"] = num_batches + logging.info( + "Dynamically computed num_batches=%d with batch_size=%d", + num_batches, + batch_size, + ) + self.config["num_batches"] = num_batches + num_train_epochs = self.config.get("num_train_epochs") + if not num_train_epochs: + num_train_epochs = 1 + + train_fraction = self.config.get("train_fraction") + if not train_fraction: + train_fraction = 0.8 + elif train_fraction <= 0.0 and train_fraction > 1.0: + logging.warning( + "train_fraction %.2f out of expected range. Setting to 0.8", + train_fraction, + ) + train_fraction = 0.8 + + allowed_max_steps = int(num_batches * num_train_epochs * train_fraction) + if not max_steps: + max_steps = allowed_max_steps + elif max_steps > allowed_max_steps: + raise ValueError( + f"Maximum allowed value for max_steps is {allowed_max_steps}, but" + f" {max_steps} is specified." + ) + + logging.info( + "Dynamically computed max_steps=%d based on dataset length %d", + max_steps, + dataset_length, + ) + + return max_steps + + def compute_params(self, dataset): + rl_training_config = self._mutable_config_mapping("rl_training_config") + max_steps = self._compute_max_steps(dataset, rl_training_config) + + rl_training_config["max_steps"] = max_steps + self._apply_optimizer_step_limits( + rl_training_config, "actor_optimizer_config", max_steps + ) + self._apply_optimizer_step_limits( + rl_training_config, "critic_optimizer_config", max_steps + ) + + def _apply_optimizer_step_limits( + self, + rl_training_config: MutableMapping[str, Any], + optimizer_key: str, + max_steps: int, + ): + opt: MutableMapping[str, Any] | None = None + opt_value = rl_training_config.get(optimizer_key) + + if isinstance(opt_value, MutableMapping): + opt = opt_value + elif opt_value is not None: + raise ValueError(f"rl_training_config.{optimizer_key} must be a dict.") + + if opt and not opt.get("decay_steps"): + opt["decay_steps"] = max_steps + if opt and not opt.get("warmup_steps"): + warmup_ratio = self.config.get("warmup_ratio", 0.1) + warmup_steps = self.config.get("warmup_steps", warmup_ratio * max_steps) + opt["warmup_steps"] = warmup_steps + + # ------------------------------------------------------------------ + # Standard training + # ------------------------------------------------------------------ + + def _get_tokenizer(self): + model_config = self.config.get("actor_model_config") or self.config.get( + "model_config" + ) + return model_lib.create_tokenizer( + self.config["tokenizer_config"], + self.config["tokenizer_config"]["tokenizer_path"], + model_config=model_config, + ) + + def _get_data_module( + self, + ): + if self.data_module is None: + self.data_module = importlib.import_module(self.config["data_module"]) + return self.data_module + + def _get_dataset(self, tokenizer): + apply_chat_template_to_dataset = self.config.get( + "apply_chat_template_to_dataset" + ) + if apply_chat_template_to_dataset is None: + raise ValueError("apply_chat_template_to_dataset must be set.") + + if self.config.get("data_module", None): + data_module = self._config_string("data_module") + dataset = data_lib.get_dataset_from_module( + data_module, + tokenizer, + apply_chat_template_to_dataset=apply_chat_template_to_dataset, + **(self.config.get("data_config") or {}), + ) + elif self.config["data_source"] == "local": + dataset = example_data.create_dataset( + data_source=self.config["data_source"], + dataset=self.config["data_directory"], + tokenizer=tokenizer, + apply_chat_template_to_dataset=apply_chat_template_to_dataset, + ) + elif self.config["data_source"] == "tfds": + dataset = example_data.create_dataset( + data_source=self.config["data_source"], + dataset=self.config["dataset_name"], + tfds_download=self.config["tfds_download"], + split=self.config.get( + "train_split", self.config.get("split", "train") + ), + apply_chat_template_to_dataset=apply_chat_template_to_dataset, + ) + elif self.config["data_source"] == "huggingface": + dataset = example_data.create_dataset( + data_source=self.config["data_source"], + dataset=self.config["dataset_name"], + tokenizer=tokenizer, + split=self.config.get( + "train_split", self.config.get("split", "train") + ), + apply_chat_template_to_dataset=apply_chat_template_to_dataset, + ) + else: + raise ValueError(f"Unsupported data_source {self.config['data_source']}") + + return dataset + + # ------------------------------------------------------------------ + # Agentic helpers + # ------------------------------------------------------------------ + + def _create_chat_parser(self, tokenizer: Any) -> Any: + """Instantiate a chat parser based on chat_parser_config.type.""" + from tunix.rl.agentic.parser.chat_template_parser import parser as chat_parser_lib # pylint: disable=g-import-not-at-top + + parser_type = self._config_mapping("chat_parser_config").get( + "type", "default" + ) + if parser_type == "qwen": + return chat_parser_lib.QwenChatTemplateParser(tokenizer) + return chat_parser_lib.DefaultChatTemplateParser(tokenizer) + + def _load_class_from_path(self, dotted_path: str) -> type[Any]: + """Load a Python class from a dotted module path. + + Args: + dotted_path: Dotted module path to the class. + + Returns: + The loaded Python class. + """ + module_path, class_name = dotted_path.rsplit(".", 1) + return getattr(importlib.import_module(module_path), class_name) + + def _load_raw_dataset(self, tokenizer): + """Load a raw grain.MapDataset from data_module. + + The module must expose ``create_dataset(**data_config) -> grain.MapDataset`` + and optionally a ``batch_fn`` used as ``custom_batch_fn``. + + Args: + tokenizer: Tokenizer to use. + + Returns: + A tuple (dataset, batch_fn) containing the loaded dataset and batch + function. + """ + data_module = ( + self._get_data_module() + if self.config.get("data_module", None) + else None + ) + dataset = self._get_dataset(tokenizer) + batch_fn = getattr(data_module, "batch_fn", None) if data_module else None + return dataset, batch_fn + + def _setup_kubernetes(self) -> None: + k8s_cfg = self._config_mapping("kubernetes_config") + if not k8s_cfg: + return + os.environ["KUBECONFIG"] = k8s_cfg.get("kubeconfig", "~/.kube/config") + os.environ["NODE_SELECTOR_KEY"] = k8s_cfg.get( + "node_selector_key", "cloud.google.com/gke-nodepool" + ) + os.environ["NODE_SELECTOR_VAL"] = k8s_cfg.get( + "node_selector_val", "deepswe-cpu-pool" + ) + try: + from kubernetes import client as k8s_client_lib # type: ignore[import-untyped] # pylint: disable=g-import-not-at-top + from kubernetes import config as k8s_config_lib # type: ignore[import-untyped] # pylint: disable=g-import-not-at-top + + k8s_config_lib.load_kube_config() + k8s_client_lib.CoreV1Api() + except Exception as e: # pylint: disable=broad-except + logging.warning("Kubernetes config loading failed: %s", e) + + # ------------------------------------------------------------------ + # Agentic training + # ------------------------------------------------------------------ + @abc.abstractmethod + def _run(self, mode: str): + """Execute agentic training (DeepScaleR, DeepSWE, etc.).""" + pass + + def run_trainer(self): + """Dispatch to standard or agentic trainer based on training_mode.""" + mode = self.config.get("training_mode", self._default_training_mode) + self._run(mode=mode) + + +def setup_jax_pathways(pathways_bns: str): + """Sets up Jax with Pathways.""" + flags.FLAGS.pathways_ifrt = True + jax.config.update("jax_xla_backend", "pathways") + jax.config.update("jax_backend_target", pathways_bns) + + +def setup_pathways_on_cloud(): + import pathwaysutils # type: ignore[import-not-found,import-untyped] # pytype: disable=import-error # pyright: ignore[reportMissingImports] # pylint: disable=g-import-not-at-top + + pathwaysutils.initialize() diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py index 3e9a50e01..6822c8c54 100644 --- a/tunix/cli/grpo_main.py +++ b/tunix/cli/grpo_main.py @@ -21,7 +21,7 @@ Usage:: # Standard GRPO - python -m tunix.cli.grpo_main examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml + bash ./examples/rl/grpo/gsm8k/run_gemma2_2b.sh # Agentic GRPO — DeepScaleR bash examples/deepscaler/run_deepscaler_disagg.sh @@ -29,38 +29,21 @@ # Agentic GRPO — DeepSWE python -m tunix.cli.grpo_main examples/deepswe/configs/qwen3_32b.yaml """ -from collections.abc import MutableMapping + import dataclasses -import importlib import os -import types from typing import Any from absl import app -from absl import flags from absl import logging -from flax import nnx -import jax -import jax.numpy as jnp -import numpy as np -from tunix.cli import config +from tunix.cli.base_rl_main import BasePipeline +from tunix.cli.base_rl_main import PATHWAYS_BNS +from tunix.cli.base_rl_main import setup_jax_pathways +from tunix.cli.base_rl_main import setup_pathways_on_cloud from tunix.cli.utils import data as data_lib -from tunix.cli.utils import model as model_lib -from tunix.examples.data import math_dataset as example_data -from tunix.perf import export as perf_export -from tunix.perf import metrics as perf_metrics -from tunix.perf.experimental import export as perf_export_v2 -from tunix.rl import rl_cluster as rl_cluster_lib -from tunix.rl.rollout import base_rollout -from tunix.utils import mesh as mesh_lib - -_PATHWAYS_BNS = flags.DEFINE_string( - "pathways_bns", None, "BNS address of the Pathways server." -) - -class GrpoPipeline(config.HyperParameters): +class GrpoPipeline(BasePipeline): """Runs standard GRPO or agentic GRPO depending on ``training_mode``. ``training_mode: "grpo"`` (default) — standard single-turn GRPO using @@ -86,594 +69,11 @@ class GrpoPipeline(config.HyperParameters): """ def __init__(self, argv: list[str], **kwargs): - self.data_module: types.ModuleType | None = None super().__init__(argv, **kwargs) - # ------------------------------------------------------------------ - # Mesh - # ------------------------------------------------------------------ - _ROLE_TO_MODEL_KEY = { - rl_cluster_lib.Role.ACTOR: "actor_model_config", - rl_cluster_lib.Role.CRITIC: "critic_model_config", - rl_cluster_lib.Role.REFERENCE: "reference_model_config", - rl_cluster_lib.Role.REWARD: "reward_model_config", - rl_cluster_lib.Role.ROLLOUT: "rollout_model_config", - } - _SPLIT_ROLE_ALIASES = { - "actor": rl_cluster_lib.Role.ACTOR, - "critic": rl_cluster_lib.Role.CRITIC, - "reference": rl_cluster_lib.Role.REFERENCE, - "reward": rl_cluster_lib.Role.REWARD, - "rollout": rl_cluster_lib.Role.ROLLOUT, - } - - def _resolve_split_role(self, role_name: str) -> rl_cluster_lib.Role: - normalized = role_name.strip().lower() - if normalized not in self._SPLIT_ROLE_ALIASES: - valid_roles = sorted(self._SPLIT_ROLE_ALIASES) - raise ValueError( - f"Unknown role name {role_name!r}. Expected one of {valid_roles}." - ) - return self._SPLIT_ROLE_ALIASES[normalized] - - def _get_same_mesh_as_map( - self, - ) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]: - same_mesh_as = {} - for role, model_key in self._ROLE_TO_MODEL_KEY.items(): - model_cfg = self.config.get(model_key, {}) or {} - target_name = model_cfg.get("same_mesh_as") - if target_name is None: - continue - target_role = self._resolve_split_role(str(target_name)) - if role == rl_cluster_lib.Role.ACTOR: - raise ValueError("Actor must own its mesh.") - same_mesh_as[role] = target_role - - return same_mesh_as - - def _is_role_active(self, role: rl_cluster_lib.Role) -> bool: - if role in ( - rl_cluster_lib.Role.ACTOR, - rl_cluster_lib.Role.REFERENCE, - rl_cluster_lib.Role.ROLLOUT, - ): - return True - model_key = self._ROLE_TO_MODEL_KEY[role] - return model_key in self.config - - def _resolve_mesh_owners( - self, - ) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]: - same_mesh_as = self._get_same_mesh_as_map() - base_owners = {} - for role, model_key in self._ROLE_TO_MODEL_KEY.items(): - if not self._is_role_active(role) and role not in same_mesh_as: - continue - has_mesh = bool(self.config.get(model_key, {}).get("mesh")) - base_owners[role] = ( - role - if role == rl_cluster_lib.Role.ACTOR or has_mesh - else rl_cluster_lib.Role.ACTOR - ) - - def resolve_owner( - role: rl_cluster_lib.Role, - seen: set[rl_cluster_lib.Role], - ) -> rl_cluster_lib.Role: - if role in seen: - raise ValueError("same_mesh_as contains a cycle.") - if role not in same_mesh_as: - return base_owners[role] - seen.add(role) - target_role = same_mesh_as[role] - if target_role not in base_owners: - raise ValueError( - f"Role {target_role.value!r} is not active in this config." - ) - return resolve_owner(target_role, seen) - - role_to_owner = {} - for role, model_key in self._ROLE_TO_MODEL_KEY.items(): - if role not in base_owners: - continue - has_mesh = bool(self.config.get(model_key, {}).get("mesh")) - if role in same_mesh_as: - if has_mesh: - raise ValueError( - f"{model_key}.mesh is specified, so it must own a separate mesh " - "and cannot also use same_mesh_as." - ) - else: - role_to_owner[role] = resolve_owner(role, set()) - continue - role_to_owner[role] = resolve_owner(role, set()) - return role_to_owner - - def create_role_to_mesh(self): - """Builds the role-to-mesh mapping for GRPO execution. - - Any role with an explicit ``*.mesh`` config gets a dedicated device slice. - Roles without a mesh share the actor mesh by default, or can point at - another role via ``same_mesh_as``. - - All mesh owners participating in the same allocation pass must agree on - one ``mesh.allocation_policy`` value. That policy is then passed to the - mesh allocator so users can choose between compact packing and - performance-oriented cubical packing from config. - - Returns: - A mapping from logical GRPO role to the concrete JAX mesh it should use. - - Raises: - ValueError: If mesh ownership resolution is invalid or if mesh owners - request conflicting allocation policies. - """ - devices = list(jax.devices()) - role_to_owner = self._resolve_mesh_owners() - owner_order = [] - for role in self._ROLE_TO_MODEL_KEY: - if role not in role_to_owner: - continue - owner = role_to_owner[role] - if owner not in owner_order: - owner_order.append(owner) - - mesh_requirements = [] - allocation_policy = None - for owner in owner_order: - model_key = self._ROLE_TO_MODEL_KEY[owner] - axis_shapes, _ = self.parse_mesh_config(model_key) - owner_policy = self._parse_mesh_allocation_policy(model_key) - if allocation_policy is None: - allocation_policy = owner_policy - elif owner_policy != allocation_policy: - raise ValueError( - "All owned meshes must use the same mesh.allocation_policy, got " - f"{allocation_policy!r} and {owner_policy!r}." - ) - mesh_requirements.append((model_key, int(np.prod(axis_shapes)))) - - allocated_devices = mesh_lib.allocate_named_mesh_device_slices( - mesh_requirements, - devices=devices, - allocation_policy=allocation_policy - or mesh_lib.normalize_allocation_policy(None), - ) - - owner_to_mesh = {} - for owner in owner_order: - model_key = self._ROLE_TO_MODEL_KEY[owner] - axis_shapes, axis_names = self.parse_mesh_config(model_key) - assigned_devices = allocated_devices[model_key] - owner_to_mesh[owner] = mesh_lib.create_mesh( - axis_shapes, axis_names, devices=assigned_devices - ) - return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()} - - # ------------------------------------------------------------------ - # Rollout config - # ------------------------------------------------------------------ - - def create_rollout_config( - self, - role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh] | None = None, - ) -> base_rollout.RolloutConfig: - """Build RolloutConfig from YAML. - - Standard mode: pass rollout_config fields through with kv_cache_size = - max_prompt_length + total_generation_steps + 256. - - Agentic mode: same base. Same kv_cache_size calculation. - - Engine-specific extras (sglang_jax_config, vllm_config) are also applied. - - Args: - role_to_mesh: Optional mapping from logical role to JAX mesh. - - Returns: - The constructed RolloutConfig. - """ - rollout_cfg = self._config_mapping("rollout_config") - mode = self._config_string("training_mode", "grpo") - engine = self._config_string("rollout_engine", "vanilla") - - valid_fields = { - f.name for f in dataclasses.fields(base_rollout.RolloutConfig) - } - - # Base pass-through (same as original create_rollout_config) - filtered = {k: v for k, v in rollout_cfg.items() if k in valid_fields} - if "total_generation_steps" in rollout_cfg: - filtered["max_tokens_to_generate"] = rollout_cfg["total_generation_steps"] - - max_prompt = rollout_cfg.get("max_prompt_length", 0) - max_response = rollout_cfg.get("total_generation_steps", 0) - - kv_cache_size = 0 - if mode == "agentic_grpo": - agentic_cfg = self._config_mapping("agentic_grpo_config") - kv_cache_size = max_prompt + max_response + 256 - filtered["kv_cache_size"] = kv_cache_size - logging.info("kv_cache_size: %d", kv_cache_size) - - max_running_requests = agentic_cfg.get("max_concurrency", 16) - else: - grpo_cfg = self._config_mapping("grpo_config") - # Standard: kv_cache_size = max_prompt + max_response + 256 - if max_prompt and max_response: - kv_cache_size = max_prompt + max_response + 256 - filtered["kv_cache_size"] = kv_cache_size - # Defaults to global batch size * num_generations to allow full - # concurrency. - max_running_requests = self.config.get("batch_size", 1) * grpo_cfg.get( - "num_generations", 1 - ) - - # Engine-specific extras - extra = self._rollout_engine_extra( - engine, - kv_cache_size, - max_running_requests, - role_to_mesh=role_to_mesh, - ) - filtered.update({k: v for k, v in extra.items() if k in valid_fields}) - return base_rollout.RolloutConfig(**filtered) - - def _rollout_engine_extra( - self, - engine: str, - kv_cache_size: int, - max_running_requests: int, - role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh] | None = None, - ) -> dict[str, Any]: - """Return engine-specific RolloutConfig fields for agentic mode.""" - model_id = self._config_mapping("actor_model_config").get("model_id", "") - - if engine == "sglang_jax": - sg = self._config_mapping("sglang_jax_config") - return dict( - rollout_sglang_jax_model_version=sg.get("model_version", model_id), - rollout_sglang_jax_mem_fraction_static=sg.get( - "mem_fraction_static", 0.8 - ), - rollout_sglang_jax_init_with_random_weights=sg.get( - "init_with_random_weights", True - ), - rollout_sglang_jax_disable_radix_cache=sg.get( - "disable_radix_cache", True - ), - rollout_sglang_jax_enable_deterministic_sampling=sg.get( - "enable_deterministic_sampling", False - ), - rollout_sglang_jax_chunked_prefill_size=sg.get( - "chunked_prefill_size", 2048 - ), - rollout_sglang_jax_max_running_requests=sg.get( - "max_running_requests", - max_running_requests, - ), - rollout_sglang_jax_page_size=sg.get("page_size", 128), - rollout_sglang_jax_use_sort_for_toppk_minp=sg.get( - "use_sort_for_toppk_minp", False - ), - ) - - if engine == "vllm": - vllm = self._config_mapping("vllm_config") - if role_to_mesh is None: - raise ValueError( - "role_to_mesh must be provided for vllm rollout config." - ) - rollout_shape = role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape - rollout_cfg = self._config_mapping("rollout_config") - max_num_seqs = rollout_cfg.get( - "rollout_vllm_max_num_seqs", - vllm.get("max_num_seqs", 768), - ) - max_batched_tokens = rollout_cfg.get( - "rollout_vllm_max_num_batched_tokens", - vllm.get( - "max_num_batched_tokens", - (max_num_seqs * kv_cache_size) // 4, - ), - ) - submission_threshold = rollout_cfg.get( - "rollout_vllm_server_mode_submission_threshold", - vllm.get("server_mode_submission_threshold", 0), - ) - submission_timeout_s = rollout_cfg.get( - "rollout_vllm_server_mode_submission_timeout_s", - vllm.get("server_mode_submission_timeout_s", 0.0), - ) - os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" - return dict( - rollout_vllm_model_version=vllm.get("model_version", model_id), - rollout_vllm_hbm_utilization=vllm.get("hbm_utilization", 0.4), - rollout_vllm_tpu_backend_type=vllm.get("tpu_backend_type", "jax"), - rollout_vllm_server_mode=vllm.get("server_mode", True), - rollout_vllm_server_mode_submission_threshold=submission_threshold, - rollout_vllm_server_mode_submission_timeout_s=submission_timeout_s, - rollout_vllm_async_scheduling=vllm.get("async_scheduling", True), - tensor_parallel_size=( - rollout_shape[1] if len(rollout_shape) > 1 else 1 - ), - data_parallel_size=rollout_shape[0], - rollout_vllm_max_num_seqs=max_num_seqs, - rollout_vllm_max_num_batched_tokens=max_batched_tokens, - rollout_vllm_kwargs=vllm.get( - "kwargs", - { - "kv_cache_metrics": True, - "disable_log_stats": False, - "enable_prefix_caching": True, - }, - ), - ) - - return {} - - # ------------------------------------------------------------------ - # Standard GRPO helpers (unchanged) - # ------------------------------------------------------------------ - - def create_cluster_config( - self, - *, - role_to_mesh: dict[rl_cluster_lib.Role, jax.sharding.Mesh], - rollout_config: base_rollout.RolloutConfig | None = None, - ): - if rollout_config is None: - rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh) - return rl_cluster_lib.ClusterConfig( - role_to_mesh=role_to_mesh, - rollout_engine=self._config_string("rollout_engine"), - offload_to_cpu=self._config_bool("offload_to_cpu"), - training_config=self.create_rl_training_config(), - rollout_config=rollout_config, - ) - - def create_rl_training_config(self): - base_key = "rl_training_config" - constructed_rl_training_config = self.obtain_training_config_dict(base_key) - - base_config = self._config_mapping(base_key) - if base_config.get("actor_optimizer_config"): - constructed_rl_training_config["actor_optimizer"] = self.create_optimizer( - base_key, "actor_optimizer_config" - ) - if base_config.get("critic_optimizer_config"): - constructed_rl_training_config["critic_optimizer"] = ( - self.create_optimizer(base_key, "critic_optimizer_config") - ) - - return rl_cluster_lib.RLTrainingConfig(**constructed_rl_training_config) - - def create_perf_config(self, cluster_config: rl_cluster_lib.ClusterConfig): - perf_metrics_options = cluster_config.training_config.perf_metrics_options - if not perf_metrics_options: - return None - - perf_config = perf_metrics.PerfMetricsConfig() - - if perf_metrics_options.enable_perf_v1: - custom_export_fn_path = perf_metrics_options.custom_export_fn_path - if custom_export_fn_path: - perf_config.custom_export_fn = self._get_function_from_path( - custom_export_fn_path - ) - if perf_config.custom_export_fn is None: - raise ValueError( - "Could not load custom export function from" - f" {custom_export_fn_path}" - ) - else: - perf_config.custom_export_fn = ( - perf_export.PerfMetricsExport.from_cluster_config(cluster_config) - ) - - if perf_metrics_options.enable_perf_v2: - custom_export_fn_path_v2 = perf_metrics_options.custom_export_fn_path_v2 - if custom_export_fn_path_v2: - perf_config.custom_export_fn_v2 = self._get_function_from_path( - custom_export_fn_path_v2 - ) - if perf_config.custom_export_fn_v2 is None: - raise ValueError( - "Could not load custom export function v2 from" - f" {custom_export_fn_path_v2}" - ) - else: - perf_config.custom_export_fn_v2 = ( - perf_export_v2.PerfMetricsExport.from_cluster_config( - cluster_config=cluster_config, - enable_trace_writer=perf_metrics_options.enable_trace_writer, - trace_dir=perf_metrics_options.trace_dir, - ).export_metrics - ) - return perf_config - - def create_rl_cluster(self, tokenizer): - role_to_mesh = self.create_role_to_mesh() - rollout_config = self.create_rollout_config(role_to_mesh=role_to_mesh) - reference_model_config = self._mutable_config_mapping( - "reference_model_config" - ) - actor_model_config = self._mutable_config_mapping("actor_model_config") - tokenizer_config = self._config_mapping("tokenizer_config") - # Should not use LoRA for reference model. - if reference_model_config.get("lora_config"): - logging.warning( - "LoRA config is not supported for the reference model. Disabling" - " LoRA." - ) - del reference_model_config["lora_config"] - reference_model, _ = model_lib.create_model( - dict(reference_model_config), - tokenizer_config, - role_to_mesh[rl_cluster_lib.Role.REFERENCE], - ) - if actor_model_config.get("lora_config", None): - actor_model = model_lib.apply_lora_to_model( - reference_model, - role_to_mesh[rl_cluster_lib.Role.ACTOR], - actor_model_config["lora_config"], - ) - else: - graph_def, params = nnx.split(reference_model) - actor_model = nnx.merge( - graph_def, - jax.tree.map(jnp.copy, params), - ) - - cluster_config = self.create_cluster_config( - role_to_mesh=role_to_mesh, - rollout_config=rollout_config, - ) - perf_config = self.create_perf_config(cluster_config) - return rl_cluster_lib.RLCluster( - actor=actor_model, - reference=reference_model, - tokenizer=tokenizer, - cluster_config=cluster_config, - perf_config=perf_config, - ) - - def compute_params(self, dataset): - rl_training_config = self._mutable_config_mapping("rl_training_config") - - # Return early if max_steps is already specified. - max_steps = None - if rl_training_config.get("max_steps"): - max_steps = rl_training_config.get("max_steps") - elif not hasattr(dataset, "__len__"): - raise ValueError( - "max_steps must be specified since the dataset length cannot be" - " determined." - ) - - dataset_length = len(dataset) - - batch_size = self.config.get("batch_size", 1) - num_batches = self.config.get("num_batches") - if not num_batches: - num_batches = dataset_length // batch_size - self.config["num_batches"] = num_batches - logging.info( - "Dynamically computed num_batches=%d with batch_size=%d", - num_batches, - batch_size, - ) - self.config["num_batches"] = num_batches - num_train_epochs = self.config.get("num_train_epochs") - if not num_train_epochs: - num_train_epochs = 1 - - train_fraction = self.config.get("train_fraction") - if not train_fraction: - train_fraction = 0.8 - elif train_fraction <= 0.0 and train_fraction > 1.0: - logging.warning( - "train_fraction %.2f out of expected range. Setting to 0.8", - train_fraction, - ) - train_fraction = 0.8 - - allowed_max_steps = int(num_batches * num_train_epochs * train_fraction) - if not max_steps: - max_steps = allowed_max_steps - elif max_steps > allowed_max_steps: - raise ValueError( - f"Maximum allowed value for max_steps is {allowed_max_steps}, but" - f" {max_steps} is specified." - ) - - rl_training_config["max_steps"] = max_steps - actor_opt: MutableMapping[str, Any] | None = None - actor_opt_value = rl_training_config.get("actor_optimizer_config") - if isinstance(actor_opt_value, MutableMapping): - actor_opt = actor_opt_value - elif actor_opt_value is not None: - raise ValueError( - "rl_training_config.actor_optimizer_config must be a dict." - ) - if actor_opt and not actor_opt.get("decay_steps"): - actor_opt["decay_steps"] = max_steps - if actor_opt and not actor_opt.get("warmup_steps"): - warmup_ratio = self.config.get("warmup_ratio", 0.1) - warmup_steps = self.config.get("warmup_steps", warmup_ratio * max_steps) - actor_opt["warmup_steps"] = warmup_steps - logging.info( - "Dynamically computed max_steps=%d based on dataset length %d", - max_steps, - dataset_length, - ) - - # ------------------------------------------------------------------ - # Standard GRPO training - # ------------------------------------------------------------------ - - def _get_tokenizer(self): - model_config = self.config.get("actor_model_config") or self.config.get("model_config") - return model_lib.create_tokenizer( - self.config["tokenizer_config"], - self.config["tokenizer_config"]["tokenizer_path"], - model_config=model_config, - ) - - def _get_data_module(self,): - if self.data_module is None: - self.data_module = importlib.import_module(self.config["data_module"]) - return self.data_module - - def _get_dataset(self, tokenizer): - apply_chat_template_to_dataset = self.config.get( - "apply_chat_template_to_dataset" - ) - if apply_chat_template_to_dataset is None: - raise ValueError( - "apply_chat_template_to_dataset must be set." - ) - - if self.config.get("data_module", None): - data_module = self._config_string("data_module") - dataset = data_lib.get_dataset_from_module( - data_module, - tokenizer, - apply_chat_template_to_dataset=apply_chat_template_to_dataset, - **(self.config.get("data_config") or {}), - ) - elif self.config["data_source"] == "local": - dataset = example_data.create_dataset( - data_source=self.config["data_source"], - dataset=self.config["data_directory"], - tokenizer=tokenizer, - apply_chat_template_to_dataset=apply_chat_template_to_dataset, - ) - elif self.config["data_source"] == "tfds": - dataset = example_data.create_dataset( - data_source=self.config["data_source"], - dataset=self.config["dataset_name"], - tfds_download=self.config["tfds_download"], - split=self.config.get( - "train_split", self.config.get("split", "train") - ), - apply_chat_template_to_dataset=apply_chat_template_to_dataset, - ) - elif self.config["data_source"] == "huggingface": - dataset = example_data.create_dataset( - data_source=self.config["data_source"], - dataset=self.config["dataset_name"], - tokenizer=tokenizer, - split=self.config.get( - "train_split", self.config.get("split", "train") - ), - apply_chat_template_to_dataset=apply_chat_template_to_dataset, - ) - else: - raise ValueError(f"Unsupported data_source {self.config['data_source']}") - - return dataset + @property + def _default_training_mode(self): + return "grpo" # ------------------------------------------------------------------ # Agentic GRPO helpers @@ -703,71 +103,6 @@ def _create_agentic_grpo_config(self): cfg.pop("max_turns", None) return GRPOConfig(**{k: v for k, v in cfg.items() if k in valid}) - def _create_chat_parser(self, tokenizer: Any) -> Any: - """Instantiate a chat parser based on chat_parser_config.type.""" - from tunix.rl.agentic.parser.chat_template_parser import parser as chat_parser_lib # pylint: disable=g-import-not-at-top - - parser_type = self._config_mapping("chat_parser_config").get( - "type", "default" - ) - if parser_type == "qwen": - return chat_parser_lib.QwenChatTemplateParser(tokenizer) - return chat_parser_lib.DefaultChatTemplateParser(tokenizer) - - def _load_class_from_path(self, dotted_path: str) -> type[Any]: - """Load a Python class from a dotted module path. - - Args: - dotted_path: Dotted module path to the class. - - Returns: - The loaded Python class. - """ - module_path, class_name = dotted_path.rsplit(".", 1) - return getattr(importlib.import_module(module_path), class_name) - - def _load_raw_dataset(self, tokenizer): - """Load a raw grain.MapDataset from data_module. - - The module must expose ``create_dataset(**data_config) -> grain.MapDataset`` - and optionally a ``batch_fn`` used as ``custom_batch_fn``. - - Args: - tokenizer: Tokenizer to use. - - Returns: - A tuple (dataset, batch_fn) containing the loaded dataset and batch - function. - """ - data_module = ( - self._get_data_module() - if self.config.get("data_module", None) - else None - ) - dataset = self._get_dataset(tokenizer) - batch_fn = getattr(data_module, "batch_fn", None) if data_module else None - return dataset, batch_fn - - def _setup_kubernetes(self) -> None: - k8s_cfg = self._config_mapping("kubernetes_config") - if not k8s_cfg: - return - os.environ["KUBECONFIG"] = k8s_cfg.get("kubeconfig", "~/.kube/config") - os.environ["NODE_SELECTOR_KEY"] = k8s_cfg.get( - "node_selector_key", "cloud.google.com/gke-nodepool" - ) - os.environ["NODE_SELECTOR_VAL"] = k8s_cfg.get( - "node_selector_val", "deepswe-cpu-pool" - ) - try: - from kubernetes import client as k8s_client_lib # type: ignore[import-untyped] # pylint: disable=g-import-not-at-top - from kubernetes import config as k8s_config_lib # type: ignore[import-untyped] # pylint: disable=g-import-not-at-top - - k8s_config_lib.load_kube_config() - k8s_client_lib.CoreV1Api() - except Exception as e: # pylint: disable=broad-except - logging.warning("Kubernetes config loading failed: %s", e) - # ------------------------------------------------------------------ # Agentic GRPO training # ------------------------------------------------------------------ @@ -818,6 +153,7 @@ def _run(self, mode: str = "grpo"): raise ValueError(f"Unsupported training_mode {mode!r}") from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner # pylint: disable=g-import-not-at-top + algo_config = self._create_agentic_grpo_config() reward_fns = ( @@ -848,35 +184,13 @@ def _run(self, mode: str = "grpo"): logging.info("Starting agentic GRPO training...") GRPOLearner(**learner_kwargs).train(dataset) - # ------------------------------------------------------------------ - # Dispatcher - # ------------------------------------------------------------------ - - def run_grpo_trainer(self): - """Dispatch to standard or agentic GRPO based on training_mode.""" - mode = self.config.get("training_mode", "grpo") - self._run(mode=mode) - - -def _setup_jax_pathways(pathways_bns: str): - """Sets up Jax with Pathways.""" - flags.FLAGS.pathways_ifrt = True - jax.config.update("jax_xla_backend", "pathways") - jax.config.update("jax_backend_target", pathways_bns) - - -def _setup_pathways_on_cloud(): - import pathwaysutils # type: ignore[import-not-found,import-untyped] # pytype: disable=import-error # pyright: ignore[reportMissingImports] # pylint: disable=g-import-not-at-top - - pathwaysutils.initialize() - def main(argv, **kwargs): - if _PATHWAYS_BNS.value: - _setup_jax_pathways(_PATHWAYS_BNS.value) + if PATHWAYS_BNS.value: + setup_jax_pathways(PATHWAYS_BNS.value) if os.getenv("JAX_PLATFORMS") == "proxy": - _setup_pathways_on_cloud() + setup_pathways_on_cloud() pipeline = GrpoPipeline(argv, **kwargs) logging.info( @@ -884,7 +198,7 @@ def main(argv, **kwargs): "%r\n--------------------------", pipeline.config, ) - pipeline.run_grpo_trainer() + pipeline.run_trainer() if __name__ == "__main__":