diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 77440e3294..87581854c9 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -2,6 +2,7 @@ # This file is auto-generated from the .j2 file via generate_github_workflows.py. Do not edit manually. ################################################################################ + name: PR Test on: @@ -29,89 +30,6 @@ concurrency: jobs: - e2e-test-short: - - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) - - - runs-on: self-hosted - - strategy: - fail-fast: false - matrix: - info: [{"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_ppo_critic_only_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fanout_short.py"}, {"num_gpus": 4, "test_file": "test_delta_weight_update.py"}] - defaults: - run: - working-directory: ${{ github.workspace }} - env: - GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} - WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - SLIME_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} - SLIME_TEST_USE_DEEPEP: ${{ matrix.info.use_deepep || '0' }} - SLIME_TEST_USE_FP8_ROLLOUT: ${{ matrix.info.use_fp8_rollout || '0' }} - SLIME_TEST_ENABLE_EVAL: ${{ matrix.info.enable_eval || '1' }} - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - - - name: Execute - shell: bash - run: | - - docker run --pull=always --rm \ - --privileged \ - --cap-add SYS_NICE \ - --security-opt seccomp=unconfined \ - --network host \ - --gpus all \ - --ipc=host \ - --shm-size=16g \ - --ulimit memlock=-1 \ - --ulimit stack=67108864 \ - --memory=0 \ - --memory-swap=0 \ - -e http_proxy \ - -e https_proxy \ - -e HTTP_PROXY \ - -e HTTPS_PROXY \ - -e GITHUB_COMMIT_NAME \ - -e WANDB_API_KEY \ - -e SLIME_TEST_ENABLE_INFINITE_RUN \ - -e SLIME_TEST_USE_DEEPEP \ - -e SLIME_TEST_USE_FP8_ROLLOUT \ - -e SLIME_TEST_ENABLE_EVAL \ - -e TEST_FILE="${{ matrix.info.test_file }}" \ - -e TEST_ARGS="${{ matrix.info.test_args || '' }}" \ - -e NUM_GPUS="${{ matrix.info.num_gpus }}" \ - -v "$GITHUB_WORKSPACE:$GITHUB_WORKSPACE" \ - -v /mnt/nvme0n1/slime_ci:/data/slime_ci \ - -v /mnt/nvme0n1/slime_ci/models:/root/models \ - -v /mnt/nvme0n1/slime_ci/datasets:/root/datasets \ - -w "$GITHUB_WORKSPACE" \ - slimerl/slime:latest \ - bash -lc ' - set -euo pipefail - pip install -e . --no-deps --break-system-packages - TEST_PATH="$TEST_FILE" - if [[ "$TEST_PATH" != tests/* ]]; then - TEST_PATH="tests/$TEST_PATH" - fi - if [[ -n "$TEST_ARGS" ]]; then - read -r -a TEST_ARGS_ARRAY < <(printf "%s\n" "$TEST_ARGS") - else - TEST_ARGS_ARRAY=() - fi - if [ "$NUM_GPUS" = "0" ]; then - python "$TEST_PATH" "${TEST_ARGS_ARRAY[@]}" - else - python tests/ci/gpu_lock_exec.py --count "$NUM_GPUS" -- python "$TEST_PATH" "${TEST_ARGS_ARRAY[@]}" - fi - ' - - e2e-test-sglang-config: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-sglang-config')) @@ -205,7 +123,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 4, "test_file": "test_full_disk_weight_update.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_disaggregate.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}, {"num_gpus": 6, "test_file": "test_qwen3_4B_external_pd.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fully_async_short.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_streaming_partial_rollout.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}] + info: [{"num_gpus": 4, "test_file": "test_full_disk_weight_update.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_disaggregate.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}, {"num_gpus": 6, "test_file": "test_qwen3_4B_external_pd.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fully_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fanout_short.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_streaming_partial_rollout.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -454,7 +372,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] + info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -588,7 +506,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}] + info: [{"num_gpus": 4, "test_file": "test_full_disk_weight_update.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_disaggregate.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}, {"num_gpus": 6, "test_file": "test_qwen3_4B_external_pd.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fully_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fanout_short.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_streaming_partial_rollout.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 1562e6b9ea..c081daa060 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -1,14 +1,32 @@ +<% set megatron_tests = [ + {'test_file': 'test_full_disk_weight_update.py', 'num_gpus': 4}, + {'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'test_glm4.7_30B_A3B_pd_mooncake.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'}, + {'test_file': 'test_qwen3.6_35B_A3B_pd_mooncake.py', 'num_gpus': 8, 'use_deepep': '1'}, + {'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'}, + {'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'test_qwen3_4B_ppo_disaggregate.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'test_qwen3_4B_ppo_train_critic_only.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'test_moonlight_16B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, + {'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, + {'test_file': 'test_qwen2.5_0.5B_debug_rollout_then_train.py', 'num_gpus': 8}, + {'test_file': 'test_qwen2.5_0.5B_opd_sglang.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_external_pd.py', 'num_gpus': 6}, + {'test_file': 'test_qwen2.5_0.5B_fully_async_short.py', 'num_gpus': 4}, + {'test_file': 'test_qwen2.5_0.5B_fanout_short.py', 'num_gpus': 4}, + {'test_file': 'test_qwen3_4B_streaming_partial_rollout.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3.5_0.8B_gsm8k_short.py', 'num_gpus': 4}, + {'test_file': 'test_qwen3.5_0.8B_gsm8k_async_short.py', 'num_gpus': 4}, + {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer gpu --load-optimizer gpu', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer gpu --load-optimizer cpu', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer cpu --load-optimizer cpu', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer cpu --load-optimizer gpu', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--async-save', 'num_gpus': 8}, +] %> <% set jobs = { - 'e2e-test-short': { - 'label': 'run-ci-short', - 'tests': [ - {'test_file': 'test_qwen3.5_0.8B_gsm8k_async_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3.5_0.8B_gsm8k_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen2.5_0.5B_ppo_critic_only_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen2.5_0.5B_fanout_short.py', 'num_gpus': 4}, - {'test_file': 'test_delta_weight_update.py', 'num_gpus': 4}, - ], - }, 'e2e-test-sglang-config': { 'label': 'run-ci-sglang-config', 'tests': [ @@ -20,33 +38,7 @@ }, 'e2e-test-megatron': { 'label': 'run-ci-megatron', - 'tests': [ - {'test_file': 'test_full_disk_weight_update.py', 'num_gpus': 4}, - {'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_glm4.7_30B_A3B_pd_mooncake.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'}, - {'test_file': 'test_qwen3.6_35B_A3B_pd_mooncake.py', 'num_gpus': 8, 'use_deepep': '1'}, - {'test_file': 'test_qwen3_30B_A3B_r3.py', 'num_gpus': 8, 'use_deepep': '1', 'use_fp8_rollout': '1', 'enable_eval': '0'}, - {'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_qwen3_4B_ppo_disaggregate.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_qwen3_4B_ppo_train_critic_only.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_moonlight_16B_A3B_r3.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, - {'test_file': 'test_qwen2.5_0.5B_debug_rollout_then_train.py', 'num_gpus': 8}, - {'test_file': 'test_qwen2.5_0.5B_opd_sglang.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_external_pd.py', 'num_gpus': 6}, - {'test_file': 'test_qwen2.5_0.5B_fully_async_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3_4B_streaming_partial_rollout.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3.5_0.8B_gsm8k_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3.5_0.8B_gsm8k_async_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer gpu --load-optimizer gpu', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer gpu --load-optimizer cpu', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer cpu --load-optimizer cpu', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer cpu --load-optimizer gpu', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--async-save', 'num_gpus': 8}, - ], + 'tests': megatron_tests, }, 'e2e-test-precision': { 'label': 'run-ci-precision', @@ -77,6 +69,7 @@ {'test_file': 'test_metric_report_dist.py', 'num_gpus': 0}, {'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0}, {'test_file': 'test_value_temperature.py', 'num_gpus': 0}, + {'test_file': 'test_cispo_loss.py', 'num_gpus': 0}, {'test_file': 'test_rm_f1.py', 'num_gpus': 0}, {'test_file': 'test_rm_gpqa.py', 'num_gpus': 0}, {'test_file': 'test_rm_math.py', 'num_gpus': 0}, @@ -109,25 +102,7 @@ 'e2e-test-image': { 'label': 'run-ci-image', 'image': 'slimerl/slime-test:latest', - 'tests': [ - {'test_file': 'test_qwen3.5_0.8B_gsm8k_async_short.py', 'num_gpus': 4}, - {'test_file': 'test_qwen3.5_0.8B_gsm8k_short.py', 'num_gpus': 4}, - {'test_file': 'test_quick_start_glm4_9B.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_glm4.7_30B_A3B_pd_mooncake.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_30B_A3B.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_qwen3.6_35B_A3B_pd_mooncake.py', 'num_gpus': 8, 'use_deepep': '1'}, - {'test_file': 'test_qwen3_4B_ppo.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8, 'enable_eval': '0'}, - {'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer gpu --load-optimizer gpu', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer gpu --load-optimizer cpu', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer cpu --load-optimizer cpu', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--save-optimizer cpu --load-optimizer gpu', 'num_gpus': 8}, - {'test_file': 'test_qwen3_4B_ckpt.py', 'test_args': '--async-save', 'num_gpus': 8}, - {'test_file': 'test_qwen2.5_0.5B_debug_rollout_then_train.py', 'num_gpus': 8}, - {'test_file': 'test_qwen2.5_0.5B_opd_sglang.py', 'num_gpus': 8}, - ], + 'tests': megatron_tests, }, } %> name: PR Test diff --git a/README.md b/README.md index 7bf0a21b1e..9337b0065f 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [中文版](./README_zh.md) [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://thudm.github.io/slime/) +[![CI](https://img.shields.io/github/actions/workflow/status/THUDM/slime/pr-test.yml?branch=zilin%2Fci-dont-merge&event=pull_request&label=CI&logo=github)](https://github.com/THUDM/slime/pull/2053/checks) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/THUDM/slime) **slime** is an LLM post-training framework for RL scaling, providing two core capabilities: @@ -112,9 +113,9 @@ See the [Customization Guide](docs/en/get_started/customization.md) for which in These are not just demos. They are independent systems that use slime as a reusable RL substrate for production-scale post-training, agentic RL, domain RL, and rollout-system research. -### ⛵ Miles: Production-Focused Reinforcement Learning Framework Built on slime +### ⛵ Miles: Enterprise-Grade Reinforcement Learning for Large-Scale Model Training -[**Miles**](https://github.com/radixark/miles) builds on the foundation of slime to provide a production-focused reinforcement learning framework for large-scale model post-training. It stays closely aligned with slime's upstream development while extending it with enterprise-oriented features: deeper [SGLang](https://github.com/sgl-project/sglang) integration, operational tooling, deployment support, and optimizations for new [models](https://www.radixark.com/miles/docs/models) and [hardware](https://www.radixark.com/miles/docs/platforms). Miles also adds production features such as LoRA, TITO, and low-precision training. +[Miles](https://github.com/radixark/miles) is an RL post-training framework for large-scale models, built on slime by [RadixArk](https://github.com/radixark). It stays closely aligned with slime's upstream development while extending it with enterprise-oriented features: deeper [SGLang](https://github.com/sgl-project/sglang) integration, operational tooling, deployment support, and optimizations for new [models](https://www.radixark.com/miles/docs/models) and [hardware](https://www.radixark.com/miles/docs/platforms). Miles also adds a growing set of production features, including LoRA, TITO, and low-precision training. ### 🔷 vime: vLLM-Native RL Post-Training Built on slime diff --git a/README_zh.md b/README_zh.md index 2244467271..0f62b1b8a6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -3,6 +3,7 @@ [English](./README.md) [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://thudm.github.io/slime/) +[![CI](https://img.shields.io/github/actions/workflow/status/THUDM/slime/pr-test.yml?branch=zilin%2Fci-dont-merge&event=pull_request&label=CI&logo=github)](https://github.com/THUDM/slime/pull/2053/checks) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/THUDM/slime) **slime** 是为 RL scaling 设计的 LLM post‑training 框架,提供两大核心能力: @@ -114,6 +115,10 @@ slime 被当作 RL 基础设施来开发,因为“脚本能跑起来”远远 这些项目不只是 demo。它们是把 slime 作为可复用 RL substrate 的独立系统,覆盖生产级 post-training、agentic RL、domain RL 和 rollout-system research。 +### ⛵ Miles:面向大规模模型训练的企业级强化学习框架 + +[Miles](https://github.com/radixark/miles) 是 [RadixArk](https://github.com/radixark) 基于 slime 构建的大模型 RL 后训练框架。它与 slime 上游开发保持紧密同步,同时在此基础上针对企业场景做了一系列扩展:更深度的 [SGLang](https://github.com/sgl-project/sglang) 集成、配套的运维与部署工具和服务,以及针对[新模型](https://www.radixark.com/miles/docs/models)和[新硬件](https://www.radixark.com/miles/docs/platforms)的优化。Miles 也在持续围绕真实生产环境需求迭代和进化,例如加入对 LoRA、TITO、低精度训练的支持。 + ### 🔷 vime: 基于 slime 的 vLLM-Native RL Post-Training 框架 [**vime**](https://github.com/vllm-project/vime) 是由 vLLM 项目维护的、基于 slime 的后训练框架。它保留 slime 的 Megatron 训练栈、Data Buffer 数据流与自定义 data generation 设计,主要特点是将 rollout 后端替换为 [**vLLM**](https://github.com/vllm-project/vllm)(配合 [vllm-router](https://github.com/vllm-project/router))。在现有 slime 启动脚本基础上仅调整 rollout 相关参数,即可快速适配 vime 进行训练。 diff --git a/docker/Dockerfile b/docker/Dockerfile index f30589ac68..af1f1a3d94 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,4 @@ -ARG SGLANG_IMAGE_TAG=v0.5.12.post1-cu129 +ARG SGLANG_IMAGE_TAG=v0.5.13-cu129 FROM slimerl/sglang:${SGLANG_IMAGE_TAG} AS sglang # ======================================== Arguments ============================================= diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index d5cdbc2423..191c20ad4a 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -1,40 +1,5 @@ -diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py -index 111145e..454ec96 100644 ---- a/python/sglang/srt/configs/model_config.py -+++ b/python/sglang/srt/configs/model_config.py -@@ -427,12 +427,18 @@ class ModelConfig: - self.hf_config.architectures[0] = "DeepseekV4ForCausalLMNextN" - self.hf_config.num_nextn_predict_layers = 1 - -- if is_draft_model and self.hf_config.architectures[0] in [ -- "Glm4MoeForCausalLM", -- "Glm4MoeLiteForCausalLM", -- ]: -+ if ( -+ is_draft_model -+ and self.hf_config.architectures[0] == "Glm4MoeForCausalLM" -+ ): - self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN" - -+ if ( -+ is_draft_model -+ and self.hf_config.architectures[0] == "Glm4MoeLiteForCausalLM" -+ ): -+ self.hf_config.architectures[0] = "Glm4MoeLiteForCausalLMNextN" -+ - if is_draft_model and self.hf_config.architectures[0] in [ - "GlmOcrForConditionalGeneration", - ]: -@@ -602,6 +608,7 @@ class ModelConfig: - or "DeepseekV3ForCausalLM" in self.hf_config.architectures - or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures - or "Glm4MoeLiteForCausalLM" in self.hf_config.architectures -+ or "Glm4MoeLiteForCausalLMNextN" in self.hf_config.architectures - or "GlmMoeDsaForCausalLM" in self.hf_config.architectures - or "LongcatFlashForCausalLM" in self.hf_config.architectures - or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py -index 097a841..eb56a55 100644 +index a7bf9904a20..b0cb56aaece 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -32,6 +32,7 @@ class KVArgs: @@ -46,7 +11,7 @@ index 097a841..eb56a55 100644 aux_data_lens: List[int] aux_item_lens: List[int] diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py -index 3cdf2af..2a2aacd 100644 +index e9efdcdd9ee..70265a424f5 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server @@ -57,18 +22,18 @@ index 3cdf2af..2a2aacd 100644 import time from collections import deque from dataclasses import dataclass -@@ -43,8 +44,10 @@ from sglang.srt.disaggregation.utils import ( +@@ -49,8 +50,10 @@ from sglang.srt.disaggregation.utils import ( MetadataBuffers, ReqToMetadataIdxAllocator, TransferBackend, + apply_prefill_timing_payload, + _is_fake_transfer, get_kv_class, - is_mla_backend, + is_slime_profiling_enabled, + is_mla_backend, poll_and_all_reduce, poll_and_all_reduce_with_staging, - prepare_abort, -@@ -410,6 +413,7 @@ class DecodePreallocQueue: +@@ -425,6 +428,7 @@ class DecodePreallocQueue(DecodeHiCachePreallocMixin): kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( self.metadata_buffers.get_buf_infos() ) @@ -76,28 +41,10 @@ index 3cdf2af..2a2aacd 100644 setup_state_kv_args( kv_args, -@@ -445,6 +449,16 @@ class DecodePreallocQueue: - ) - return kv_manager - -+ def release_memory_occupation(self): -+ self.queue.clear() -+ self.retracted_queue.clear() -+ if hasattr(self.kv_manager, "deregister_buffer_to_engine"): -+ self.kv_manager.deregister_buffer_to_engine() -+ -+ def resume_memory_occupation(self): -+ if hasattr(self.kv_manager, "register_buffer_to_engine"): -+ self.kv_manager.register_buffer_to_engine() -+ - def add(self, req: Req, is_retracted: bool = False) -> None: - """Add a request to the pending queue.""" - if self._check_if_req_exceed_kv_capacity(req): -@@ -616,12 +630,37 @@ class DecodePreallocQueue: +@@ -632,12 +636,33 @@ class DecodePreallocQueue(DecodeHiCachePreallocMixin): [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group ) -+ # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed. + bootstrap_timeout = float( + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") + ) @@ -109,17 +56,14 @@ index 3cdf2af..2a2aacd 100644 if poll == KVPoll.Bootstrapping: - pass -+ # Check for bootstrap timeout -+ entry_time = getattr( -+ decode_req.req.time_stats, -+ "decode_prealloc_queue_entry_time", -+ None, ++ entry_time = ( ++ decode_req.req.time_stats.decode_prealloc_queue_entry_time + ) -+ if entry_time is not None and (now - entry_time) > bootstrap_timeout: ++ if entry_time > 0 and now - entry_time > bootstrap_timeout: + error_message = ( -+ f"Decode bootstrap timed out after {now - entry_time:.1f}s " -+ f"for request rank={self.tp_rank} " -+ f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}" ++ f"Decode prealloc timeout for request rank={self.tp_rank} " ++ f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=} " ++ f"after {bootstrap_timeout}s" + ) + logger.error(error_message) + prepare_abort( @@ -127,20 +71,38 @@ index 3cdf2af..2a2aacd 100644 + error_message, + status_code=HTTPStatus.GATEWAY_TIMEOUT, + ) -+ if self.scheduler.enable_metrics: ++ if self.scheduler.metrics_reporter.enable_metrics: + self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() elif poll == KVPoll.WaitingForInput: decode_req.waiting_for_input = True decode_req.req.time_stats.set_bootstrap_done_time() -@@ -980,6 +1019,7 @@ class DecodePreallocQueue: +@@ -1027,6 +1052,7 @@ class DecodePreallocQueue(DecodeHiCachePreallocMixin): self.req_to_metadata_buffer_idx_allocator.alloc() ) assert decode_req.metadata_buffer_index is not None + self.metadata_buffers.clear_profiling_buf(decode_req.metadata_buffer_index) - page_indices = kv_to_page_indices(kv_indices, page_size) + page_indices = kv_to_page_indices(kv_indices, kv_transfer_page_size) decode_req.kv_receiver.send_metadata( page_indices, -@@ -1397,6 +1437,7 @@ class DecodeTransferQueue: +@@ -1403,6 +1429,17 @@ class DecodePreallocQueue(DecodeHiCachePreallocMixin): + return host_indices + return kv_loc + ++ def release_memory_occupation(self): ++ self.queue.clear() ++ self.retracted_queue.clear() ++ self.pending_reqs.clear() ++ if hasattr(self.kv_manager, "deregister_buffer_to_engine"): ++ self.kv_manager.deregister_buffer_to_engine() ++ ++ def resume_memory_occupation(self): ++ if hasattr(self.kv_manager, "register_buffer_to_engine"): ++ self.kv_manager.register_buffer_to_engine() ++ + + class DecodeTransferQueue(DecodeHiCacheTransferMixin): + """ +@@ -1455,6 +1492,7 @@ class DecodeTransferQueue(DecodeHiCacheTransferMixin): output_topk_index, output_hidden_states, output_bootstrap_room, @@ -148,27 +110,23 @@ index 3cdf2af..2a2aacd 100644 ) = self.metadata_buffers.get_buf(idx) # Validate bootstrap_room to detect context corruption -@@ -1458,6 +1499,14 @@ class DecodeTransferQueue: - output_top_logprobs_idx[: decode_req.req.top_logprobs_num].tolist() +@@ -1543,6 +1581,12 @@ class DecodeTransferQueue(DecodeHiCacheTransferMixin): + ].tolist() ) -+ # Inject prefill-side PD timing forwarded from the P instance. -+ # Layout: [bootstrap_queue, forward, transfer_queue, bootstrap, -+ # alloc_waiting, transfer_speed, transfer_mb, retry_count] + if is_slime_profiling_enabled(): + apply_prefill_timing_payload( -+ decode_req.req.time_stats, output_prefill_timing ++ decode_req.req.time_stats, ++ output_prefill_timing, + ) + decode_req.kv_receiver.clear() decode_req.kv_receiver = None decode_req.req.time_stats.set_wait_queue_entry_time() -@@ -1490,6 +1539,13 @@ class DecodeTransferQueue: - [dr.kv_receiver for dr in self.queue], self.gloo_group - ) +@@ -1600,6 +1644,11 @@ class DecodeTransferQueue(DecodeHiCacheTransferMixin): + else: + polls = self._poll_with_metadata_gate() -+ # Transfer timeout: if a request has been in the transfer queue for too long -+ # (e.g., stuck in Bootstrapping/WaitingForInput/Transferring), treat it as failed. + transfer_timeout = float( + os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600") + ) @@ -177,91 +135,90 @@ index 3cdf2af..2a2aacd 100644 transferred_reqs = [] indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): -@@ -1544,7 +1600,20 @@ class DecodeTransferQueue: +@@ -1674,7 +1723,17 @@ class DecodeTransferQueue(DecodeHiCacheTransferMixin): KVPoll.WaitingForInput, KVPoll.Transferring, ]: - pass -+ # Check for transfer timeout -+ entry_time = getattr( -+ decode_req.req.time_stats, -+ "decode_transfer_queue_entry_time", -+ None, -+ ) -+ if entry_time is not None and (now - entry_time) > transfer_timeout: -+ error_message = ( -+ f"Decode transfer timed out after {now - entry_time:.1f}s " -+ f"(state={poll}) for request rank={self.tp_rank} " -+ f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}" ++ entry_time = decode_req.req.time_stats.decode_transfer_queue_entry_time ++ if entry_time > 0 and now - entry_time > transfer_timeout: ++ logger.error( ++ "Decode transfer timeout for request rank=%s rid=%s room=%s " ++ "after %ss", ++ self.tp_rank, ++ decode_req.req.rid, ++ decode_req.req.bootstrap_room, ++ transfer_timeout, + ) -+ logger.error(error_message) + decode_req.kv_receiver.abort() else: raise ValueError(f"Unexpected poll case: {poll}") -@@ -1565,6 +1634,14 @@ class DecodeTransferQueue: +@@ -1698,6 +1757,15 @@ class DecodeTransferQueue(DecodeHiCacheTransferMixin): return transferred_reqs + def release_memory_occupation(self): -+ """Clean up all in-flight transfers before releasing GPU memory.""" ++ for decode_req in self.queue: ++ if decode_req.kv_receiver is not None: ++ decode_req.kv_receiver.abort() + self.queue.clear() + + def resume_memory_occupation(self): -+ """Resume after GPU memory re-allocation. Queue was already cleared on release.""" + pass + class SchedulerDisaggregationDecodeMixin: @torch.no_grad() -@@ -1758,7 +1835,15 @@ class SchedulerDisaggregationDecodeMixin: +@@ -1893,6 +1961,11 @@ class SchedulerDisaggregationDecodeMixin: resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs() self.waiting_queue.extend(resumed_reqs) if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0: -- # if there are still retracted requests, we do not allocate new requests -+ # Still have retracted requests that couldn't resume (not enough memory). -+ # Don't accept new requests (pop_preallocated) — they would consume memory -+ # that retracted requests need. -+ # But DO drain completed transfers: their KV is already committed, and -+ # moving them to waiting_queue frees the reserved-decode-token budget -+ # in _allocatable_tokens(), which may unblock resume on the next iteration. -+ # Without this, completed transfers hold memory indefinitely → deadlock. -+ alloc_reqs = self.disagg_decode_transfer_queue.pop_transferred() -+ self.waiting_queue.extend(alloc_reqs) ++ transferred_reqs = self.disagg_decode_transfer_queue.pop_transferred() ++ if self.enable_hisparse: ++ for req in transferred_reqs: ++ self.hisparse_coordinator.admit_request_direct(req) ++ self.waiting_queue.extend(transferred_reqs) + # if there are still retracted requests, we do not allocate new requests return - if not hasattr(self, "polling_count"): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py -index 634f2ea..afe1704 100644 +index b21aee9f7c2..87f0a6fa668 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py -@@ -39,6 +39,7 @@ from sglang.srt.disaggregation.mooncake.utils import ( - from sglang.srt.disaggregation.utils import ( - DisaggregationMode, - filter_kv_indices_for_cp_rank, -+ iter_aux_transfer_specs, +@@ -39,7 +39,10 @@ from sglang.srt.disaggregation.common.utils import ( + from sglang.srt.disaggregation.mooncake.utils import ( + check_mooncake_custom_mem_pool_enabled, ) +-from sglang.srt.disaggregation.utils import DisaggregationMode ++from sglang.srt.disaggregation.utils import ( ++ DisaggregationMode, ++ iter_aux_transfer_specs, ++) from sglang.srt.distributed.parallel_state import get_mooncake_transfer_engine from sglang.srt.environ import envs -@@ -273,6 +274,17 @@ class MooncakeKVManager(CommonKVManager): + from sglang.srt.observability.mooncake_trace import ( +@@ -257,6 +260,19 @@ class MooncakeKVManager(CommonKVManager): if ptrs and lens: self.engine.batch_register(ptrs, lens) + def deregister_buffer_to_engine(self): -+ if self.kv_args.kv_data_ptrs: ++ if self.kv_args.kv_data_ptrs and self.kv_args.kv_data_lens: + self.engine.batch_deregister(self.kv_args.kv_data_ptrs) + -+ if self.kv_args.aux_data_ptrs: ++ if self.kv_args.aux_data_ptrs and self.kv_args.aux_data_lens: + self.engine.batch_deregister(self.kv_args.aux_data_ptrs) + -+ for ptrs in self.kv_args.state_data_ptrs: -+ if ptrs: ++ for ptrs, lens in zip( ++ self.kv_args.state_data_ptrs, self.kv_args.state_data_lens ++ ): ++ if ptrs and lens: + self.engine.batch_deregister(ptrs) + # ------------------------------------------------------------------ # Staging buffer methods (all delegate to staging_handler.py) # ------------------------------------------------------------------ -@@ -887,10 +899,14 @@ class MooncakeKVManager(CommonKVManager): +@@ -827,10 +843,14 @@ class MooncakeKVManager(CommonKVManager): prefill_aux_ptrs = self.kv_args.aux_data_ptrs prefill_aux_item_lens = self.kv_args.aux_item_lens @@ -280,7 +237,7 @@ index 634f2ea..afe1704 100644 transfer_blocks.append((src_addr, dst_addr, length)) return self._transfer_data(req.mooncake_session_id, transfer_blocks) -@@ -904,9 +920,14 @@ class MooncakeKVManager(CommonKVManager): +@@ -844,9 +864,14 @@ class MooncakeKVManager(CommonKVManager): prefill_aux_ptrs = self.kv_args.aux_data_ptrs prefill_aux_item_lens = self.kv_args.aux_item_lens @@ -298,30 +255,31 @@ index 634f2ea..afe1704 100644 data = AuxDataCodec.serialize_data_from_buffer(src_addr, length) self.send_aux_data_to_endpoint( -@@ -1053,16 +1074,13 @@ class MooncakeKVManager(CommonKVManager): - ) +@@ -994,15 +1019,17 @@ class MooncakeKVManager(CommonKVManager): src_indices = list(indices) dst_indices_local = list(dst_indices) -- if len(src_indices) > len(dst_indices_local): + if len(src_indices) > len(dst_indices_local): - logger.warning( - f"len(prefill_state_indices) = {len(src_indices)}, len(dst_state_indices) = {len(dst_indices_local)}" -- ) ++ logger.error( ++ f"len(prefill_state_indices) = {len(src_indices)}, " ++ f"len(dst_state_indices) = {len(dst_indices_local)}" + ) - src_indices = src_indices[: len(dst_indices_local)] -- elif len(src_indices) < len(dst_indices_local): ++ return -1 + elif len(src_indices) < len(dst_indices_local): - logger.warning( - f"len(prefill_state_indices) = {len(src_indices)}, len(dst_state_indices) = {len(dst_indices_local)}" -+ if len(src_indices) != len(dst_indices_local): + logger.error( -+ "PD extra-state index mismatch, reject transfer to avoid corrupted outputs: " -+ f"len(prefill_state_indices)={len(src_indices)}, " -+ f"len(dst_state_indices)={len(dst_indices_local)}" ++ f"len(prefill_state_indices) = {len(src_indices)}, " ++ f"len(dst_state_indices) = {len(dst_indices_local)}" ) - dst_indices_local = dst_indices_local[: len(src_indices)] + return -1 rc = ( self._send_kvcache_generic( mooncake_session_id=req.mooncake_session_id, -@@ -1319,12 +1337,6 @@ class MooncakeKVManager(CommonKVManager): +@@ -1279,12 +1306,6 @@ class MooncakeKVManager(CommonKVManager): if ret != 0: with self.session_lock: self.session_failures[req.mooncake_session_id] += 1 @@ -334,7 +292,7 @@ index 634f2ea..afe1704 100644 self.record_failure( kv_chunk.room, f"Failed to send kv chunk of {kv_chunk.room} to " -@@ -1342,12 +1354,30 @@ class MooncakeKVManager(CommonKVManager): +@@ -1302,12 +1323,27 @@ class MooncakeKVManager(CommonKVManager): if kv_chunk.is_last_chunk: if kv_chunk.state_indices: @@ -346,13 +304,10 @@ index 634f2ea..afe1704 100644 target_rank_registration_info, ) + if ret != 0: -+ remote_addr = NetworkAddress( -+ req.endpoint, req.dst_port -+ ).to_host_port_str() + self.record_failure( + kv_chunk.room, -+ f"Failed to send extra state chunk of {kv_chunk.room} to " -+ f"{remote_addr}", ++ f"Failed to send extra state of {kv_chunk.room} to " ++ f"{NetworkAddress(req.endpoint, req.dst_port).to_host_port_str()}", + ) + self.update_status(kv_chunk.room, KVPoll.Failed) + self.sync_status_to_decode_endpoint( @@ -367,27 +322,27 @@ index 634f2ea..afe1704 100644 # Only the last chunk we need to send the aux data ret = self.send_aux( diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py -index 0e2ed6a..78658f1 100644 +index ce1afdac3ad..de8fd054f70 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py -@@ -20,6 +20,8 @@ Life cycle of a request in the prefill server - from __future__ import annotations +@@ -21,6 +21,8 @@ from __future__ import annotations + import hashlib import logging +import os +import time + from array import array from collections import deque from http import HTTPStatus - from typing import TYPE_CHECKING, List, Optional -@@ -176,6 +178,7 @@ class PrefillBootstrapQueue: +@@ -180,6 +182,7 @@ class PrefillBootstrapQueue: kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( self.metadata_buffers.get_buf_infos() ) + kv_args.aux_buffer_names = self.metadata_buffers.get_aux_buffer_names() kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device - kv_args.gpu_id = self.scheduler.gpu_id + kv_args.gpu_id = self.scheduler.ps.gpu_id -@@ -290,6 +293,11 @@ class PrefillBootstrapQueue: +@@ -332,6 +335,11 @@ class PrefillBootstrapQueue: self.scheduler.attn_tp_cpu_group, ) @@ -397,18 +352,14 @@ index 0e2ed6a..78658f1 100644 + now = time.perf_counter() + for i, (req, poll) in enumerate(zip(self.queue, polls)): - if rids_to_check is not None: - # if req not in reqs_info_to_check, skip -@@ -297,6 +305,26 @@ class PrefillBootstrapQueue: - continue - - if poll == KVPoll.Bootstrapping: -+ entry_time = getattr( -+ req.time_stats, -+ "prefill_bootstrap_queue_entry_time", -+ None, -+ ) -+ if entry_time is not None and (now - entry_time) > bootstrap_timeout: + if ( + rids_to_check is not None +@@ -348,6 +356,27 @@ class PrefillBootstrapQueue: + indices_to_remove.add(i) + failed_reqs.append(req) + elif poll == KVPoll.Bootstrapping: ++ entry_time = req.time_stats.prefill_bootstrap_queue_entry_time ++ if entry_time > 0 and now - entry_time > bootstrap_timeout: + error_message = ( + f"Prefill bootstrap timed out after {now - entry_time:.1f}s " + f"for request rank={self.tp_rank} " @@ -416,17 +367,22 @@ index 0e2ed6a..78658f1 100644 + ) + logger.error(error_message) + prepare_abort( -+ req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT ++ req, ++ error_message, ++ status_code=HTTPStatus.GATEWAY_TIMEOUT, ++ ) ++ self.scheduler.output_streamer.stream_output( ++ [req], req.return_logprob + ) -+ self.scheduler.stream_output([req], req.return_logprob) + indices_to_remove.add(i) + failed_reqs.append(req) -+ if self.scheduler.enable_metrics: ++ if self.scheduler.metrics_reporter.enable_metrics: + self.scheduler.metrics_collector.increment_bootstrap_failed_reqs() - continue - elif poll == KVPoll.Failed: - error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}" -@@ -354,6 +382,15 @@ class PrefillBootstrapQueue: ++ continue + if ( + req.time_stats.prefill_retry_count + < self.scheduler.server_args.optimistic_prefill_retries +@@ -378,6 +407,15 @@ class PrefillBootstrapQueue: else: return bootstrapped_reqs, failed_reqs @@ -442,68 +398,7 @@ index 0e2ed6a..78658f1 100644 class SchedulerDisaggregationPrefillMixin: """ -@@ -514,6 +551,34 @@ class SchedulerDisaggregationPrefillMixin: - for i, (req, next_token_id) in enumerate( - zip(batch.reqs, next_token_ids, strict=True) - ): -+ # An AbortReq may arrive while this prefill batch is already running. -+ # Honor it before exposing KV to the decode side. -+ req.check_finished() -+ if req.finished(): -+ if req.is_chunked <= 0: -+ req.time_stats.set_prefill_finished_time() -+ else: -+ req.time_stats.set_last_chunked_prefill_finish_time() -+ -+ if req.return_logprob: -+ assert extend_logprob_start_len_per_req is not None -+ assert extend_input_len_per_req is not None -+ extend_logprob_start_len = extend_logprob_start_len_per_req[i] -+ extend_input_len = extend_input_len_per_req[i] -+ logprob_pt += extend_input_len - extend_logprob_start_len -+ -+ release_kv_cache(req, self.tree_cache) -+ req.time_stats.set_completion_time() -+ if req.grammar is not None: -+ req.grammar.finished = True -+ self.stream_output([req], req.return_logprob, None) -+ release_req_to_metadata_buffer( -+ req, self.req_to_metadata_buffer_idx_allocator -+ ) -+ if hasattr(req.disagg_kv_sender, "clear"): -+ req.disagg_kv_sender.clear() -+ continue -+ - if req.is_chunked <= 0: - req.time_stats.set_prefill_finished_time() - -@@ -586,13 +651,18 @@ class SchedulerDisaggregationPrefillMixin: - self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx) - req.time_stats.set_last_chunked_prefill_finish_time() - -- can_run_cuda_graph = getattr(result, "can_run_cuda_graph", False) -- self.report_prefill_stats( -- batch=batch, -- prefill_stats=batch.prefill_stats, -- can_run_cuda_graph=can_run_cuda_graph, -- dp_cooperation_info=batch.dp_cooperation_info, -- ) -+ if ( -+ self.current_scheduler_metrics_enabled -+ and hasattr(batch, "prefill_stats") -+ and batch.prefill_stats is not None -+ ): -+ can_run_cuda_graph = getattr(result, "can_run_cuda_graph", False) -+ self.report_prefill_stats( -+ batch=batch, -+ prefill_stats=batch.prefill_stats, -+ can_run_cuda_graph=can_run_cuda_graph, -+ dp_cooperation_info=getattr(batch, "dp_cooperation_info", None), -+ ) - - def process_disagg_prefill_inflight_queue( - self: Scheduler, rids_to_check: Optional[List[str]] = None -@@ -612,6 +682,11 @@ class SchedulerDisaggregationPrefillMixin: +@@ -685,6 +723,11 @@ class SchedulerDisaggregationPrefillMixin: self.attn_tp_cpu_group, ) @@ -515,31 +410,29 @@ index 0e2ed6a..78658f1 100644 undone_reqs: List[Req] = [] # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue for req, poll in zip(self.disagg_prefill_inflight_queue, polls): -@@ -637,7 +712,29 @@ class SchedulerDisaggregationPrefillMixin: +@@ -710,7 +753,27 @@ class SchedulerDisaggregationPrefillMixin: continue if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: - undone_reqs.append(req) -+ entry_time = getattr( -+ req.time_stats, -+ "prefill_transfer_queue_entry_time", -+ None, -+ ) -+ if entry_time is not None and (now - entry_time) > transfer_timeout: ++ entry_time = req.time_stats.prefill_transfer_queue_entry_time ++ if entry_time > 0 and now - entry_time > transfer_timeout: + error_message = ( + f"Prefill transfer timed out after {now - entry_time:.1f}s " -+ f"(state={poll}) for request rank={self.tp_rank} " ++ f"(state={poll}) for request rank={self.ps.tp_rank} " + f"{req.rid=} {req.bootstrap_room=}" + ) + logger.error(error_message) + release_kv_cache(req, self.tree_cache) + prepare_abort( -+ req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT ++ req, ++ error_message, ++ status_code=HTTPStatus.GATEWAY_TIMEOUT, + ) + if hasattr(req.disagg_kv_sender, "clear"): + req.disagg_kv_sender.clear() + done_reqs.append(req) -+ if self.enable_metrics: ++ if self.metrics_reporter.enable_metrics: + self.metrics_collector.increment_transfer_failed_reqs() + else: + undone_reqs.append(req) @@ -547,10 +440,18 @@ index 0e2ed6a..78658f1 100644 release_kv_cache(req, self.tree_cache) # unlock the tree req.finished_reason = FINISH_LENGTH(length=0) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py -index 951fa5b..abf1817 100644 +index e1d7d9c8db3..a44685777ea 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py -@@ -28,6 +28,17 @@ if TYPE_CHECKING: +@@ -2,6 +2,7 @@ from __future__ import annotations + + import os + import random ++import time + from collections import deque + from contextlib import nullcontext + from enum import Enum +@@ -30,6 +31,17 @@ if TYPE_CHECKING: # Constants & Enums ######################### FAKE_BOOTSTRAP_HOST = "2.2.2.2" @@ -568,13 +469,10 @@ index 951fa5b..abf1817 100644 class DisaggregationMode(Enum): -@@ -201,46 +212,35 @@ class MetadataBuffers: +@@ -255,46 +267,32 @@ class MetadataBuffers: self.bootstrap_room = torch.zeros( (size, 8), dtype=bootstrap_room_dtype, device=device ) -+ # Prefill-side PD timing (8 floats, padded to 16 for RDMA alignment). -+ # Layout: [bootstrap_queue, forward, transfer_queue, bootstrap, -+ # alloc_waiting, transfer_speed, transfer_mb, retry_count] + self.prefill_timing = torch.zeros( + (size, 16), dtype=torch.float32, device=device + ) @@ -640,7 +538,7 @@ index 951fa5b..abf1817 100644 def get_buf(self, idx: int): return ( self.output_ids[idx].clone(), -@@ -253,8 +253,12 @@ class MetadataBuffers: +@@ -307,8 +305,12 @@ class MetadataBuffers: self.output_topk_index[idx].clone(), self.output_hidden_states[idx].clone(), self.bootstrap_room[idx].clone(), @@ -653,16 +551,10 @@ index 951fa5b..abf1817 100644 def set_buf(self, req: Req): self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0] -@@ -302,6 +306,99 @@ class MetadataBuffers: +@@ -360,6 +362,93 @@ class MetadataBuffers: self.bootstrap_room[req.metadata_buffer_index, 0] = ( req.bootstrap_room if req.bootstrap_room is not None else 0 ) -+ # Pack prefill-side PD timing durations for transfer to decode instance. -+ # Note: set_buf is called at the START of the last KV chunk send, so -+ # completion_time and prefill_transfer_queue_entry_time are not yet set. -+ # We use time.perf_counter() as the "forward just completed" timestamp. -+ import time -+ + ts = req.time_stats + timing = self.prefill_timing[req.metadata_buffer_index] + self.clear_profiling_buf(req.metadata_buffer_index) @@ -754,7 +646,7 @@ index 951fa5b..abf1817 100644 ######################### diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py -index 1dde8be..729a5e6 100644 +index 88bf1947684..4ede8eb9078 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -71,6 +71,7 @@ from sglang.srt.managers.io_struct import ( @@ -765,7 +657,7 @@ index 1dde8be..729a5e6 100644 ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, RpcReqInput, -@@ -1090,6 +1091,24 @@ class Engine(EngineScoreMixin, EngineBase): +@@ -1110,6 +1111,20 @@ class Engine(EngineScoreMixin, EngineBase): self.tokenizer_manager.update_weights_from_ipc(obj, None) ) @@ -774,15 +666,11 @@ index 1dde8be..729a5e6 100644 + restore_weights_before_load: bool = False, + post_process_quantization: bool = False, + ): -+ """ -+ Optional post-processing for updated weights (e.g., Marlin conversion). -+ Should be called after weight update is finished. -+ """ ++ """Optional post-processing for updated weights, e.g. quantization packing.""" + obj = PostProcessWeightsReqInput( + restore_weights_before_load=restore_weights_before_load, + post_process_quantization=post_process_quantization, + ) -+ + return self.loop.run_until_complete( + self.tokenizer_manager.post_process_weights(obj, None) + ) @@ -791,7 +679,7 @@ index 1dde8be..729a5e6 100644 """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index 80081fc..cec2c57 100644 +index d7368383d89..2c881d95bd5 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -127,6 +127,7 @@ from sglang.srt.managers.io_struct import ( @@ -802,7 +690,7 @@ index 80081fc..cec2c57 100644 ProfileReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, -@@ -611,10 +612,8 @@ async def model_info(): +@@ -618,10 +619,8 @@ async def model_info(): @app.get("/weight_version") async def weight_version(): """Get the current weight version.""" @@ -815,14 +703,13 @@ index 80081fc..cec2c57 100644 @app.get("/get_server_info") -@@ -631,9 +630,19 @@ async def get_server_info(): +@@ -638,9 +637,18 @@ async def get_server_info(): async def server_info(): """Get the server information.""" # Returns internal states per DP. - internal_states: List[Dict[Any, Any]] = ( - await _global_state.tokenizer_manager.get_internal_state() - ) -+ # In large/disaggregated deployments this can occasionally block; keep endpoint responsive. + server_info_timeout = float(os.environ.get("SGLANG_SERVER_INFO_TIMEOUT", "2")) + try: + internal_states: List[Dict[Any, Any]] = await asyncio.wait_for( @@ -836,19 +723,16 @@ index 80081fc..cec2c57 100644 + ) + internal_states = [] - # server_args.model_config is not serializable but should be excluded by asdict. - return { -@@ -1222,6 +1231,23 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re + server_args = _global_state.tokenizer_manager.server_args + +@@ -1249,6 +1257,20 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) +@app.post("/post_process_weights") +@auth_level(AuthLevel.ADMIN_OPTIONAL) +async def post_process_weights(req: PostProcessWeightsReqInput, request: Request): -+ """ -+ Optional post-processing for updated weights (e.g., Marlin conversion). -+ This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`. -+ """ ++ """Optional post-processing for updated weights, e.g. quantization packing.""" + success, message = await _global_state.tokenizer_manager.post_process_weights( + req, request + ) @@ -863,10 +747,10 @@ index 80081fc..cec2c57 100644 @auth_level(AuthLevel.ADMIN_OPTIONAL) async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py -index c78ad65..4470bf9 100644 +index 435c30a5cfd..864a0f567a6 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py -@@ -242,6 +242,7 @@ class Envs: +@@ -299,6 +299,7 @@ class Envs: SGLANG_DISAGGREGATION_WAITING_TIMEOUT = EnvInt(300) SGLANG_DISAGGREGATION_NIXL_BACKEND = EnvStr("UCX") SGLANG_DISAGGREGATION_NIXL_BACKEND_PARAMS = EnvStr("{}") @@ -874,10 +758,10 @@ index c78ad65..4470bf9 100644 SGLANG_DISAGGREGATION_ALL_CP_RANKS_TRANSFER = EnvBool(False) SGLANG_DISAGGREGATION_FORCE_QUERY_PREFILL_DP_RANK = EnvBool(False) # Extra slots in req_to_token_pool for decode workers (only effective when -diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py -index f3c2c29..b99f7fc 100644 ---- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py -+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +diff --git a/python/sglang/srt/layers/attention/dsa/dsa_indexer.py b/python/sglang/srt/layers/attention/dsa/dsa_indexer.py +index 85fcd4b9ec7..a49161f6154 100644 +--- a/python/sglang/srt/layers/attention/dsa/dsa_indexer.py ++++ b/python/sglang/srt/layers/attention/dsa/dsa_indexer.py @@ -2,6 +2,7 @@ from __future__ import annotations import contextlib @@ -886,149 +770,154 @@ index f3c2c29..b99f7fc 100644 from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -@@ -244,14 +245,31 @@ class Indexer(MultiPlatformOp): +@@ -100,6 +101,15 @@ if TYPE_CHECKING: + DUAL_STREAM_TOKEN_THRESHOLD = 1024 if _is_cuda else 0 + + ++def _match_head_gate_q_scale( ++ weights: torch.Tensor, q_scale: torch.Tensor ++) -> torch.Tensor: ++ if weights.shape[1] < q_scale.shape[1]: ++ assert q_scale.shape[1] % weights.shape[1] == 0 ++ weights = weights.repeat_interleave(q_scale.shape[1] // weights.shape[1], dim=1) ++ return weights ++ ++ + if _is_cuda: + from sglang.srt.compilation.compilation_config import register_split_op + from sglang.srt.utils.custom_op import register_custom_op +@@ -165,6 +175,7 @@ if _is_cuda: + ) -> torch.Tensor: + out = torch.mm(x, weight.t(), out_dtype=torch.float32) + weights = out * n_heads_inv_sqrt ++ weights = _match_head_gate_q_scale(weights, q_scale) + weights = weights.unsqueeze(-1) * q_scale * softmax_scale + return weights + +@@ -368,6 +379,15 @@ class Indexer(MultiPlatformOp): self.k_norm = LayerNorm( self.head_dim, dtype=torch.bfloat16 if _use_aiter else torch.float32 ) + server_args = get_global_server_args() -+ disable_flag = server_args.disable_indexer_rope_neox_style -+ env_raw = os.environ.get("INDEXER_ROPE_NEOX_STYLE", None) -+ if env_raw is not None: -+ env_value = env_raw == "1" -+ if disable_flag and env_value: ++ env_neox_style = os.environ.get("INDEXER_ROPE_NEOX_STYLE") ++ if env_neox_style is not None: ++ if env_neox_style not in ("0", "1"): + raise ValueError( -+ "Conflict: --disable-indexer-rope-neox-style is set but " -+ "INDEXER_ROPE_NEOX_STYLE='1'. " -+ "Please remove one or make them consistent." ++ "INDEXER_ROPE_NEOX_STYLE must be either '0' or '1' when set." + ) -+ resolved_neox_style = env_value -+ elif disable_flag: -+ resolved_neox_style = False -+ else: -+ resolved_neox_style = is_neox_style ++ is_neox_style = env_neox_style == "1" + self.rotary_emb = get_rope_wrapper( rope_head_dim, rotary_dim=rope_head_dim, - max_position=max_position_embeddings, +@@ -375,7 +395,7 @@ class Indexer(MultiPlatformOp): base=rope_theta, # type: ignore rope_scaling=rope_scaling, -- is_neox_style=is_neox_style, + is_neox_style=is_neox_style, - device=get_global_server_args().device, -+ is_neox_style=resolved_neox_style, + device=server_args.device, ) self.block_size = block_size self.scale_fmt = scale_fmt -@@ -302,6 +320,11 @@ class Indexer(MultiPlatformOp): - self, x: Union[torch.Tensor, Tuple[torch.Tensor, ...]], q_scale: torch.Tensor +@@ -427,6 +447,7 @@ class Indexer(MultiPlatformOp): ): weights = self._weights_proj_bf16_in_fp32_out(x) -+ if weights.shape[1] < q_scale.shape[1]: -+ assert q_scale.shape[1] % weights.shape[1] == 0 -+ weights = weights.repeat_interleave( -+ q_scale.shape[1] // weights.shape[1], dim=1 -+ ) weights = weights * self.n_heads**-0.5 ++ weights = _match_head_gate_q_scale(weights, q_scale) weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale return weights -@@ -1183,6 +1206,9 @@ class Indexer(MultiPlatformOp): + +@@ -434,8 +455,15 @@ class Indexer(MultiPlatformOp): + def _apply_q_scale_and_softmax_scale( + self, weights: torch.Tensor, q_scale: torch.Tensor + ): ++ weights = _match_head_gate_q_scale(weights, q_scale) + return weights.unsqueeze(-1) * q_scale * self.softmax_scale + ++ def _maybe_repeat_query_heads(self, query: torch.Tensor) -> torch.Tensor: ++ if query.shape[1] < 32: ++ assert 32 % query.shape[1] == 0 ++ query = query.repeat_interleave(32 // query.shape[1], dim=1) ++ return query ++ + def _get_q_k_bf16( + self, + q_lora: torch.Tensor, +@@ -1389,6 +1417,7 @@ class Indexer(MultiPlatformOp): query, key = self._get_q_k_bf16( q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch ) -+ if query.shape[1] < 32: -+ assert 32 % query.shape[1] == 0 -+ query = query.repeat_interleave(32 // query.shape[1], dim=1) ++ query = self._maybe_repeat_query_heads(query) q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt) with torch.cuda.stream(self.alt_stream): self._store_index_k_cache( -@@ -1197,6 +1223,9 @@ class Indexer(MultiPlatformOp): +@@ -1403,6 +1432,7 @@ class Indexer(MultiPlatformOp): query, key = self._get_q_k_bf16( q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch ) -+ if query.shape[1] < 32: -+ assert 32 % query.shape[1] == 0 -+ query = query.repeat_interleave(32 // query.shape[1], dim=1) ++ query = self._maybe_repeat_query_heads(query) if enable_dual_stream: current_stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py -index 97a5bfc..4cca7a0 100644 +index 59ca3f9cce6..9c2d00fcd7c 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py -@@ -771,6 +771,7 @@ class FusedMoE(torch.nn.Module): +@@ -793,6 +793,7 @@ class FusedMoE(torch.nn.Module): + "CompressedTensorsWNA16MoE", "CompressedTensorsWNA16TritonMoE", ] ++ and "zero" not in weight_name ) -+ and "zero" not in weight_name else loaded_weight ) - -@@ -990,6 +991,7 @@ class FusedMoE(torch.nn.Module): +@@ -1012,6 +1013,7 @@ class FusedMoE(torch.nn.Module): + "CompressedTensorsWNA16MoE", "CompressedTensorsWNA16TritonMoE", ] ++ and "zero" not in weight_name ) -+ and "zero" not in weight_name else loaded_weight ) - -diff --git a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py -index 61af553..661735e 100644 ---- a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py -+++ b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py -@@ -863,8 +863,10 @@ def _varlen_deep_gemm_silu_mul_quant( - dtype=torch.float8_e4m3fn, - ) - -- if envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get(): -- assert N % 4 == 0 and G % 4 == 0 -+ use_jit_ep_activation = ( -+ envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get() and N % 4 == 0 and G % 4 == 0 -+ ) -+ if use_jit_ep_activation: - packed_ue8m0 = deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - down_input_scale = torch.empty( - (E, G // 4, N) if packed_ue8m0 else (E, N, G), diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py -index 81056a1..b8cfe41 100644 +index 28a9d567a5e..e60a0bcfde0 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py -@@ -502,7 +502,7 @@ class CompressedTensorsConfig(QuantizationConfig): - ) - is_static = not weight_quant.dynamic - -- return is_channel_group and input_quant_none and is_symmetric and is_static -+ return is_channel_group and input_quant_none and is_static +@@ -927,6 +927,10 @@ class CompressedTensorsLinearMethod(LinearMethodBase): + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) - def _is_mxint4a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: - input_quant_none = input_quant is None -@@ -978,6 +978,9 @@ class CompressedTensorsFusedMoEMethod(FusedMoEMethodBase): ++ def restore_weights_before_loading(self, layer: torch.nn.Module) -> None: ++ if hasattr(layer.scheme, "restore_weights_before_loading"): ++ layer.scheme.restore_weights_before_loading(layer) ++ + def create_weights( + self, + layer: torch.nn.Module, +@@ -981,6 +985,10 @@ class CompressedTensorsFusedMoEMethod(FusedMoEMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) + def restore_weights_before_loading(self, layer: torch.nn.Module) -> None: -+ layer.scheme.restore_weights_before_loading(layer) ++ if hasattr(layer.scheme, "restore_weights_before_loading"): ++ layer.scheme.restore_weights_before_loading(layer) + def create_weights( self, layer: torch.nn.Module, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py -index 0ac1878..bbc94f7 100644 +index 58562bb23db..c3dc1ceb0d2 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py -@@ -17,7 +17,10 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsMoEScheme, - ) - from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack --from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales -+from sglang.srt.layers.quantization.marlin_utils import ( -+ marlin_moe_permute_scales, +@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + from sglang.srt.layers.quantization.marlin_utils import ( + marlin_make_workspace, + marlin_moe_permute_scales, + moe_awq_to_marlin_zero_points, -+) + ) from sglang.srt.layers.quantization.utils import replace_parameter from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, set_weight_attrs - -@@ -64,7 +67,7 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): +@@ -69,7 +70,7 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): self.strategy = config.strategy self.group_size = config.group_size self.actorder = config.actorder @@ -1037,7 +926,7 @@ index 0ac1878..bbc94f7 100644 if not ( self.quant_config.quant_format == CompressionFormat.pack_quantized.value -@@ -124,7 +127,7 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): +@@ -129,7 +130,7 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): # In the case where we have actorder/g_idx, # we do not partition the w2 scales @@ -1046,11 +935,10 @@ index 0ac1878..bbc94f7 100644 if load_full_w2: w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size -@@ -172,6 +175,32 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): +@@ -177,6 +178,31 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): layer.register_parameter("w13_weight_shape", w13_weight_shape) set_weight_attrs(w13_weight_shape, extra_weight_attrs) -+ # add zero param + if not self.sym: + w13_qzeros = torch.nn.Parameter( + torch.empty( @@ -1079,22 +967,20 @@ index 0ac1878..bbc94f7 100644 w13_g_idx = torch.nn.Parameter( torch.empty( num_experts, -@@ -231,6 +260,10 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): +@@ -235,6 +261,9 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): + # Also record the shapes of the scales. layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape) layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) - + if not self.sym: -+ layer._original_shapes["w13_weight_zero_point"] = w13_qzeros.shape ++ layer._original_shapes["w13_weight_zero_point"] = tuple(w13_qzeros.shape) + layer._original_shapes["w2_weight_zero_point"] = tuple(w2_qzeros.shape) -+ + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Skip if the layer is already converted to Marlin format to prevent double-packing. -@@ -334,6 +367,24 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): +@@ -339,6 +368,23 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): ) replace_tensor("w2_weight_scale", marlin_w2_scales) -+ # Repack zero + if not self.sym: + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_weight_zero_point, @@ -1112,10 +998,19 @@ index 0ac1878..bbc94f7 100644 + ) + replace_tensor("w2_weight_zero_point", marlin_w2_zp) + + layer.workspace = marlin_make_workspace(layer.w13_weight_packed.device, 4) layer.is_marlin_converted = True - def restore_weights_before_loading(self, layer: torch.nn.Module): -@@ -416,6 +467,8 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): +@@ -376,6 +422,8 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): + w13_g_idx=getattr(layer, "w13_weight_g_idx", None), + w2_g_idx=getattr(layer, "w2_weight_g_idx", None), + is_k_full=self.is_k_full, ++ w13_qzeros=layer.w13_weight_zero_point if not self.sym else None, ++ w2_qzeros=layer.w2_weight_zero_point if not self.sym else None, + ) + + def apply_weights( +@@ -422,6 +470,8 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): g_idx2=layer.w2_weight_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, @@ -1125,10 +1020,10 @@ index 0ac1878..bbc94f7 100644 is_k_full=self.is_k_full, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py -index 293335f..51e5669 100644 +index 987ec512122..e098565729b 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py -@@ -1407,6 +1407,8 @@ class PauseContinueBroadcast: +@@ -1442,6 +1442,8 @@ class PauseContinueBroadcast: class UpdateWeightFromDiskReqInput(BaseReq): # The model path with the new weights model_path: str @@ -1137,7 +1032,7 @@ index 293335f..51e5669 100644 # The format to load the weights load_format: Optional[str] = None # Whether to abort all requests before updating weights -@@ -1437,6 +1439,41 @@ class UpdateWeightFromDiskReqOutput(BaseReq): +@@ -1472,6 +1474,40 @@ class UpdateWeightFromDiskReqOutput(BaseReq): num_paused_requests: Optional[int] = 0 @@ -1154,15 +1049,14 @@ index 293335f..51e5669 100644 + +@dataclass +class DeltaParam: -+ """Per-param slice into the shared (positions, values) bucket. ``pos_*`` index -+ into the uint8 byte blob; ``val_*`` index into the param-dtype value tensor.""" ++ """Per-param slice into the shared (positions, values) bucket.""" + + name: str + dtype: str + shape: List[int] + pos_start: int + pos_end: int -+ pos_width: int # 2 or 4 ++ pos_width: int + val_start: int + val_end: int + @@ -1179,7 +1073,7 @@ index 293335f..51e5669 100644 @dataclass class UpdateWeightsFromDistributedReqInput(BaseReq): names: List[str] -@@ -1452,6 +1489,8 @@ class UpdateWeightsFromDistributedReqInput(BaseReq): +@@ -1487,6 +1523,8 @@ class UpdateWeightsFromDistributedReqInput(BaseReq): weight_version: Optional[str] = None # Optional format specification for loading load_format: Optional[str] = None @@ -1188,7 +1082,7 @@ index 293335f..51e5669 100644 # Whether to call torch.cuda.empty_cache() during flush torch_empty_cache: bool = False -@@ -1638,6 +1677,18 @@ class ResumeMemoryOccupationReqOutput(BaseReq): +@@ -1673,6 +1711,18 @@ class ResumeMemoryOccupationReqOutput(BaseReq): pass @@ -1206,37 +1100,66 @@ index 293335f..51e5669 100644 + @dataclass class CheckWeightsReqInput(BaseReq): - action: str + action: str = "checksum" +@@ -2058,7 +2108,7 @@ class GetLoadsReqInput(BaseReq): + """Request for /v1/loads endpoint.""" + + VALID_SECTIONS = frozenset( +- {"core", "memory", "spec", "lora", "disagg", "queues", "all"} ++ {"core", "memory", "spec", "lora", "disagg", "queues", "inflight", "all"} + ) + + include: List[str] = field(default_factory=lambda: ["all"]) +@@ -2128,6 +2178,9 @@ class GetLoadsReqOutput(BaseReq): + lora: Optional[LoRAMetrics] = None + disaggregation: Optional[DisaggregationMetrics] = None + queues: Optional[QueueMetrics] = None ++ # Per-request breakdown of every queue, only populated when "inflight" or ++ # "all" is requested. ++ inflight: Optional[List[Dict[str, Any]]] = None + + + @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py -index feecc54..6fce256 100755 +index 42ea8431091..c369b070b57 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py -@@ -874,6 +874,11 @@ class Req(ReqDllmMixin): +@@ -943,6 +943,7 @@ class Req(ReqDllmMixin): self.metrics_collector = metrics_collector if time_stats is not None: self.time_stats = SchedulerReqTimeStats.new_from_obj(time_stats) -+ # Force the scheduler-side disagg_mode: new_from_obj copies -+ # disagg_mode from the source object (APIServerReqTimeStats from -+ # the tokenizer manager has disagg_mode=NULL), which breaks the -+ # PD timing getters in SchedulerReqTimeStats. + self.time_stats.disagg_mode = disagg_mode else: self.time_stats = SchedulerReqTimeStats(disagg_mode=disagg_mode) self.time_stats.set_metrics_collector(metrics_collector) -@@ -2193,7 +2198,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -2383,11 +2384,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + + retracted_reqs = [] + first_iter = True ++ num_minimum_reqs = ( ++ 0 if server_args.disaggregation_mode == "decode" else 1 ++ ) while first_iter or ( not self.check_decode_mem(selected_indices=sorted_indices) ): - if len(sorted_indices) == 1: -+ # We should allow all requests to be retracted in decode disaggregation mode -+ # because there call be prealloc prefill requests. -+ num_minimum_reqs = 0 if server_args.disaggregation_mode == "decode" else 1 -+ if len(sorted_indices) == num_minimum_reqs: - # Always keep at least one request +- # Always keep at least one request ++ if len(sorted_indices) <= num_minimum_reqs: ++ # Unified mode keeps one request; decode disaggregation may retract all. break + first_iter = False +@@ -2398,7 +2402,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.release_req(idx, len(sorted_indices), server_args) + + reqs_to_abort: List[Req] = [] +- if len(sorted_indices) <= 1 and not self.check_decode_mem( ++ if len(sorted_indices) <= num_minimum_reqs and not self.check_decode_mem( + selected_indices=sorted_indices + ): + # Even the last remaining request cannot fit in memory. diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index 143054c..40e05e2 100644 +index 8e32640fc6a..98966842506 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -124,6 +124,7 @@ from sglang.srt.managers.io_struct import ( @@ -1247,63 +1170,171 @@ index 143054c..40e05e2 100644 ProfileReq, ReleaseMemoryOccupationReqInput, RemoveExternalCorpusReqInput, -@@ -1452,6 +1453,7 @@ class Scheduler( +@@ -1318,6 +1319,10 @@ class Scheduler( + UpdateWeightsFromIPCReqInput, + self.weight_updater.update_weights_from_ipc, ), - (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), - (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc), -+ (PostProcessWeightsReqInput, self.post_process_weights), - (GetWeightsByNameReqInput, self.get_weights_by_name), - (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), - (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), -@@ -3649,9 +3651,16 @@ class Scheduler( - recv_req.abort_all or req.rid.startswith(recv_req.rid) - ): - # Abort method 3: set `to_finish` -- # The request will still run one decode forward pass. -+ # Decode requests may still run one forward pass. PD prefill -+ # consumes this before sending KV to the decode side. ++ ( ++ PostProcessWeightsReqInput, ++ self.weight_updater.post_process_weights, ++ ), + ( + GetWeightsByNameReqInput, + self.weight_updater.get_weights_by_name, +@@ -1577,6 +1582,10 @@ class Scheduler( + flush_cache=self.flush_cache, + is_fully_idle=self.is_fully_idle, + metrics_collector=self.metrics_collector, ++ disaggregation_mode=self.disaggregation_mode, ++ get_disagg_decode_transfer_queue=lambda: self.disagg_decode_transfer_queue, ++ get_disagg_decode_prealloc_queue=lambda: self.disagg_decode_prealloc_queue, ++ get_disagg_prefill_bootstrap_queue=lambda: self.disagg_prefill_bootstrap_queue, + ) + + def init_lora_drainer(self) -> None: +@@ -3721,6 +3730,12 @@ class Scheduler( + # The request will still run one decode forward pass. # Then we reuse all existing code to clean up the KV cache allocation. logger.debug(f"Abort running request. {req.rid=}") -+ if ( -+ self.disaggregation_mode == DisaggregationMode.PREFILL -+ and hasattr(req, "disagg_kv_sender") -+ and hasattr(req.disagg_kv_sender, "abort") ++ if self.disaggregation_mode == DisaggregationMode.PREFILL and hasattr( ++ req, "disagg_kv_sender" + ): -+ req.disagg_kv_sender.abort() ++ sender = getattr(req, "disagg_kv_sender", None) ++ if sender is not None and hasattr(sender, "abort"): ++ sender.abort() req.to_finish = FINISH_ABORT() def _pause_engine(self) -> Tuple[List[Req], int]: -diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -index ae6f732..496e7bc 100644 ---- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py -+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -@@ -1261,7 +1261,7 @@ class SchedulerOutputProcessorMixin: - dp_ranks = [self.dp_rank] * len(rids) if rids else None - - # Send to detokenizer -- if reqs or is_idle_batch: -+ if rids or is_idle_batch: - self.send_to_detokenizer.send_output( - BatchTokenIDOutput( - rids=rids, -diff --git a/python/sglang/srt/managers/scheduler_profiler_mixin.py b/python/sglang/srt/managers/scheduler_profiler_mixin.py -index c02ed79..61733c4 100644 ---- a/python/sglang/srt/managers/scheduler_profiler_mixin.py -+++ b/python/sglang/srt/managers/scheduler_profiler_mixin.py -@@ -349,7 +349,7 @@ class SchedulerProfilerMixin: +diff --git a/python/sglang/srt/managers/scheduler_components/load_inquirer.py b/python/sglang/srt/managers/scheduler_components/load_inquirer.py +index 3f10d7edaff..712322a95af 100644 +--- a/python/sglang/srt/managers/scheduler_components/load_inquirer.py ++++ b/python/sglang/srt/managers/scheduler_components/load_inquirer.py +@@ -202,6 +202,88 @@ class SchedulerLoadInquirer: + retracted=self.get_stats().num_retracted_reqs, + ) + ++ inflight = None ++ if include_all or "inflight" in include: ++ now_perf = time.perf_counter() ++ inflight_queues = [("running", self.get_running_batch().reqs, None)] ++ if self.disaggregation_mode == DisaggregationMode.PREFILL: ++ inflight_queues += [ ++ ("waiting", self.get_waiting_queue(), "wait_queue_entry_time"), ++ ( ++ "bootstrap", ++ self.get_disagg_prefill_bootstrap_queue().queue, ++ "prefill_bootstrap_queue_entry_time", ++ ), ++ ( ++ "prefill_inflight", ++ self.get_disagg_prefill_inflight_queue(), ++ "prefill_transfer_queue_entry_time", ++ ), ++ ] ++ elif self.disaggregation_mode == DisaggregationMode.DECODE: ++ inflight_queues += [ ++ ("waiting", self.get_waiting_queue(), "wait_queue_entry_time"), ++ ( ++ "prealloc", ++ self.get_disagg_decode_prealloc_queue().queue, ++ "decode_prealloc_queue_entry_time", ++ ), ++ ( ++ "transfer", ++ self.get_disagg_decode_transfer_queue().queue, ++ "decode_transfer_queue_entry_time", ++ ), ++ ( ++ "retracted", ++ self.get_disagg_decode_prealloc_queue().retracted_queue, ++ "decode_prealloc_queue_entry_time", ++ ), ++ ] ++ else: ++ inflight_queues.append( ++ ("waiting", self.get_waiting_queue(), "wait_queue_entry_time") ++ ) ++ ++ def describe_req(entry, stage, entry_time_field): ++ req = getattr(entry, "req", entry) ++ info = { ++ "rid": getattr(req, "rid", None), ++ "bootstrap_room": getattr(req, "bootstrap_room", None), ++ "seqlen": getattr(entry, "seqlen", None), ++ "stage": stage, ++ } ++ if entry_time_field is not None: ++ time_stats = getattr(req, "time_stats", None) ++ entry_time = ( ++ getattr(time_stats, entry_time_field, 0.0) ++ if time_stats ++ else 0.0 ++ ) ++ info["age_s"] = ( ++ round(now_perf - entry_time, 3) if entry_time else None ++ ) ++ if entry is not req: ++ info["waiting_for_input"] = getattr( ++ entry, "waiting_for_input", None ++ ) ++ info["timeout_cancel_issued"] = getattr( ++ entry, "timeout_cancel_issued", None ++ ) ++ return info ++ ++ inflight = [] ++ for name, queue, entry_time_field in inflight_queues: ++ inflight.append( ++ { ++ "name": name, ++ "num_reqs": len(queue), ++ "reqs": [ ++ describe_req(entry, name, entry_time_field) ++ for entry in queue ++ ], ++ } ++ ) ++ + return GetLoadsReqOutput( + dp_rank=self.ps.dp_rank, + timestamp=time.time(), +@@ -221,4 +303,5 @@ class SchedulerLoadInquirer: + lora=lora, + disaggregation=disaggregation, + queues=queues, ++ inflight=inflight, + ) +diff --git a/python/sglang/srt/managers/scheduler_components/output_streamer.py b/python/sglang/srt/managers/scheduler_components/output_streamer.py +index cac80715856..2574fcfb55c 100644 +--- a/python/sglang/srt/managers/scheduler_components/output_streamer.py ++++ b/python/sglang/srt/managers/scheduler_components/output_streamer.py +@@ -481,7 +481,7 @@ class _GenerationStreamAccumulator: + def to_payload( + self, *, load, dp_rank: int, is_idle_batch: bool, has_reqs: bool + ) -> Optional[BatchTokenIDOutput]: +- if not (has_reqs or is_idle_batch): ++ if not (self.rids or is_idle_batch): + return None + dp_ranks = [dp_rank] * len(self.rids) if self.rids else None + return BatchTokenIDOutput( +diff --git a/python/sglang/srt/managers/scheduler_components/profiler_manager.py b/python/sglang/srt/managers/scheduler_components/profiler_manager.py +index 31df519f9e8..cdcf41cd8bc 100644 +--- a/python/sglang/srt/managers/scheduler_components/profiler_manager.py ++++ b/python/sglang/srt/managers/scheduler_components/profiler_manager.py +@@ -377,7 +377,7 @@ class SchedulerProfilerManager: if self.profiler_prefill_ct > self.profiler_target_prefill_ct: if self.profile_in_progress: - self.stop_profile(stage=ForwardMode.EXTEND) + self._stop_profile(stage=ForwardMode.EXTEND) - elif batch.forward_mode.is_decode(): + elif batch.forward_mode.is_decode() or batch.forward_mode.is_prebuilt(): if self.profiler_decode_ct == 0: if self.profile_in_progress: # force trace flush -diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py -index f2daf64..534d1b2 100644 ---- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py -+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py -@@ -12,6 +12,7 @@ from sglang.srt.constants import ( +diff --git a/python/sglang/srt/managers/scheduler_components/weight_updater.py b/python/sglang/srt/managers/scheduler_components/weight_updater.py +index 77bf823b081..9ab3abe5618 100644 +--- a/python/sglang/srt/managers/scheduler_components/weight_updater.py ++++ b/python/sglang/srt/managers/scheduler_components/weight_updater.py +@@ -16,6 +16,7 @@ from sglang.srt.constants import ( GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, ) @@ -1311,7 +1342,7 @@ index f2daf64..534d1b2 100644 from sglang.srt.managers.io_struct import ( CheckWeightsReqInput, CheckWeightsReqOutput, -@@ -21,6 +22,8 @@ from sglang.srt.managers.io_struct import ( +@@ -25,6 +26,8 @@ from sglang.srt.managers.io_struct import ( GetWeightsByNameReqOutput, InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqOutput, @@ -1320,64 +1351,89 @@ index f2daf64..534d1b2 100644 ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, -@@ -120,6 +123,11 @@ class SchedulerUpdateWeightsMixin: - torch.distributed.barrier(group=self.tp_cpu_group) - return UpdateWeightsFromIPCReqOutput(success, message) +@@ -78,6 +81,10 @@ class SchedulerWeightUpdaterManager: + flush_cache: Callable[..., bool] + is_fully_idle: Callable[..., bool] + metrics_collector: Optional[Any] = None ++ disaggregation_mode: DisaggregationMode = DisaggregationMode.NULL ++ get_disagg_decode_transfer_queue: Optional[Callable[..., Any]] = None ++ get_disagg_decode_prealloc_queue: Optional[Callable[..., Any]] = None ++ get_disagg_prefill_bootstrap_queue: Optional[Callable[..., Any]] = None + offload_tags: set = field(default_factory=set) + stashed_model_static_state: Any = None + +@@ -175,6 +182,19 @@ class SchedulerWeightUpdaterManager: + parameter = self.tp_worker.get_weights_by_name(recv_req) + return GetWeightsByNameReqOutput(parameter) + def post_process_weights(self, recv_req: PostProcessWeightsReqInput): -+ """Optional post-processing for updated weights (e.g., Marlin conversion).""" + success, message = self.tp_worker.post_process_weights(recv_req) ++ if ( ++ success ++ and self.draft_worker is not None ++ and hasattr(self.draft_worker, "post_process_weights") ++ ): ++ success, message = self.draft_worker.post_process_weights(recv_req) ++ if not success: ++ logger.error(message) ++ torch.distributed.barrier(group=self.tp_cpu_group) + return PostProcessWeightsReqOutput(success, message) + - def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput): - parameter = self.tp_worker.get_weights_by_name(recv_req) - return GetWeightsByNameReqOutput(parameter) -@@ -143,6 +151,15 @@ class SchedulerUpdateWeightsMixin: + def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput): + assert ( + self.is_fully_idle() +@@ -191,6 +211,18 @@ class SchedulerWeightUpdaterManager: + if GPU_MEMORY_TYPE_KV_CACHE in tags: self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) self.flush_cache() ++ if ( ++ self.disaggregation_mode == DisaggregationMode.DECODE ++ and self.get_disagg_decode_transfer_queue is not None ++ and self.get_disagg_decode_prealloc_queue is not None ++ ): ++ self.get_disagg_decode_transfer_queue().release_memory_occupation() ++ self.get_disagg_decode_prealloc_queue().release_memory_occupation() ++ elif ( ++ self.disaggregation_mode == DisaggregationMode.PREFILL ++ and self.get_disagg_prefill_bootstrap_queue is not None ++ ): ++ self.get_disagg_prefill_bootstrap_queue().release_memory_occupation() -+ if self.disaggregation_mode == DisaggregationMode.DECODE: -+ if hasattr(self, "disagg_decode_transfer_queue"): -+ self.disagg_decode_transfer_queue.release_memory_occupation() -+ if hasattr(self, "disagg_decode_prealloc_queue"): -+ self.disagg_decode_prealloc_queue.release_memory_occupation() -+ elif self.disaggregation_mode == DisaggregationMode.PREFILL: -+ if hasattr(self, "disagg_prefill_bootstrap_queue"): -+ self.disagg_prefill_bootstrap_queue.release_memory_occupation() -+ if GPU_MEMORY_TYPE_WEIGHTS in tags: self.stashed_model_static_state = _export_static_state( - self.tp_worker.model_runner.model -@@ -183,6 +200,15 @@ class SchedulerUpdateWeightsMixin: +@@ -229,6 +261,18 @@ class SchedulerWeightUpdaterManager: + if GPU_MEMORY_TYPE_KV_CACHE in tags: self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) ++ if ( ++ self.disaggregation_mode == DisaggregationMode.DECODE ++ and self.get_disagg_decode_transfer_queue is not None ++ and self.get_disagg_decode_prealloc_queue is not None ++ ): ++ self.get_disagg_decode_prealloc_queue().resume_memory_occupation() ++ self.get_disagg_decode_transfer_queue().resume_memory_occupation() ++ elif ( ++ self.disaggregation_mode == DisaggregationMode.PREFILL ++ and self.get_disagg_prefill_bootstrap_queue is not None ++ ): ++ self.get_disagg_prefill_bootstrap_queue().resume_memory_occupation() -+ if self.disaggregation_mode == DisaggregationMode.DECODE: -+ if hasattr(self, "disagg_decode_transfer_queue"): -+ self.disagg_decode_transfer_queue.resume_memory_occupation() -+ if hasattr(self, "disagg_decode_prealloc_queue"): -+ self.disagg_decode_prealloc_queue.resume_memory_occupation() -+ elif self.disaggregation_mode == DisaggregationMode.PREFILL: -+ if hasattr(self, "disagg_prefill_bootstrap_queue"): -+ self.disagg_prefill_bootstrap_queue.resume_memory_occupation() -+ return ResumeMemoryOccupationReqOutput() - def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): diff --git a/python/sglang/srt/managers/tokenizer_control_mixin.py b/python/sglang/srt/managers/tokenizer_control_mixin.py -index 05382e0..7eea220 100644 +index c9939a1fc93..ee25e5e70e0 100644 --- a/python/sglang/srt/managers/tokenizer_control_mixin.py +++ b/python/sglang/srt/managers/tokenizer_control_mixin.py -@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import ( - UpdateWeightsFromDistributedReqOutput, - UpdateWeightsFromIPCReqInput, - UpdateWeightsFromIPCReqOutput, +@@ -48,6 +48,8 @@ from sglang.srt.managers.io_struct import ( + LoadLoRAAdapterReqOutput, + LoRAUpdateOutput, + OpenSessionReqInput, + PostProcessWeightsReqInput, + PostProcessWeightsReqOutput, - UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, - ) -@@ -102,6 +104,7 @@ _COMMUNICATOR_SPECS = [ + ProfileReq, + ProfileReqOutput, + ProfileReqType, +@@ -96,6 +98,7 @@ _COMMUNICATOR_SPECS = [ ("send_weights_to_remote_instance", SendWeightsToRemoteInstanceReqOutput), ("update_weights_from_tensor", UpdateWeightsFromTensorReqOutput), ("update_weights_from_ipc", UpdateWeightsFromIPCReqOutput), @@ -1385,29 +1441,28 @@ index 05382e0..7eea220 100644 ("get_weights_by_name", GetWeightsByNameReqOutput), ("release_memory_occupation", ReleaseMemoryOccupationReqOutput), ("resume_memory_occupation", ResumeMemoryOccupationReqOutput), -@@ -531,6 +534,17 @@ class TokenizerControlMixin: - - return success, message +@@ -754,6 +757,16 @@ class TokenizerControlMixin: + self.auto_create_handle_loop() + await self.resume_memory_occupation_communicator(obj) + async def post_process_weights( + self: TokenizerManager, + obj: PostProcessWeightsReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: -+ """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion).""" + self.auto_create_handle_loop() + async with self.model_update_lock.writer_lock: + results = await self.post_process_weights_communicator(obj) -+ return FanOutCommunicator.merge_results(results) ++ return FanOutCommunicator.merge_results(results) + - async def _unload_lora_adapter_locked( + async def check_weights( self: TokenizerManager, - obj: UnloadLoRAAdapterReqInput, + obj: CheckWeightsReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 6375a1d..3fc175e 100644 +index 357e3c4675a..1f6dc90e471 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py -@@ -1501,7 +1501,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): +@@ -1641,7 +1641,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): async with self.is_pause_cond: self.is_pause = True if obj.mode != "abort": @@ -1416,7 +1471,7 @@ index 6375a1d..3fc175e 100644 else: # we are using the model_update_lock to check if there is still on-going requests. while True: -@@ -1515,7 +1515,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): +@@ -1655,7 +1655,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): async def continue_generation(self, obj: ContinueGenerationReqInput): async with self.is_pause_cond: self.is_pause = False @@ -1425,7 +1480,7 @@ index 6375a1d..3fc175e 100644 self.is_pause_cond.notify_all() async def update_weights_from_disk( -@@ -1564,7 +1564,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): +@@ -1704,7 +1704,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): self.model_update_result = asyncio.Future() if self.server_args.dp_size == 1: result = await self.model_update_result @@ -1434,7 +1489,7 @@ index 6375a1d..3fc175e 100644 self._update_model_path_info(obj.model_path, obj.load_format) return result.success, result.message, result.num_paused_requests else: # self.server_args.dp_size > 1 -@@ -1572,7 +1572,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): +@@ -1712,7 +1712,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): result = await self.model_update_result all_success = all([r.success for r in result]) @@ -1443,7 +1498,7 @@ index 6375a1d..3fc175e 100644 self._update_model_path_info(obj.model_path, obj.load_format) all_message = [r.message for r in result] all_message = " | ".join(all_message) -@@ -2177,25 +2177,23 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): +@@ -2343,25 +2343,23 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): priority = getattr(state.obj, "priority", None) if priority is not None: labels["priority"] = str(priority) @@ -1477,7 +1532,7 @@ index 6375a1d..3fc175e 100644 if state.finished: # Get detailed cache breakdown if available diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py -index 60e105d..6a37a82 100644 +index bd9184408ed..71bbe8f400f 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,6 +29,7 @@ from sglang.srt.managers.io_struct import ( @@ -1488,51 +1543,45 @@ index 60e105d..6a37a82 100644 SendWeightsToRemoteInstanceReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, -@@ -97,6 +98,7 @@ class BaseTpWorker(ABC): - success, message = self.model_runner.update_weights_from_disk( +@@ -98,6 +99,7 @@ class BaseTpWorker(ABC): recv_req.model_path, recv_req.load_format, -+ files=recv_req.files, recapture_cuda_graph=recv_req.recapture_cuda_graph, ++ files=recv_req.files, ) return success, message -@@ -152,6 +154,7 @@ class BaseTpWorker(ABC): + +@@ -152,6 +154,14 @@ class BaseTpWorker(ABC): recv_req.shapes, recv_req.group_name, recv_req.load_format, + recv_req.delta, ++ ) ++ return success, message ++ ++ def post_process_weights(self, recv_req: PostProcessWeightsReqInput): ++ success, message = self.model_runner.post_process_weights( ++ restore_weights_before_load=recv_req.restore_weights_before_load, ++ post_process_quantization=recv_req.post_process_quantization, ) return success, message -@@ -171,6 +174,11 @@ class BaseTpWorker(ABC): - success, message = self.model_runner.update_weights_from_ipc(recv_req) - return success, message - -+ def post_process_weights(self, recv_req: PostProcessWeightsReqInput): -+ """Perform optional post-processing on the updated model weights (e.g., Marlin conversion).""" -+ success, message = self.model_runner.post_process_weights(recv_req) -+ return success, message -+ - def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): - parameter = self.model_runner.get_weights_by_name( - recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py -index 3416d2a..df3866c 100644 +index 353a02ee0be..7e3e3f58cb9 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py -@@ -842,9 +842,8 @@ class HiRadixCache(RadixCache): +@@ -1009,9 +1009,7 @@ class HiRadixCache(RadixCache): self._update_leaf_status(node) self._update_host_leaf_status(node) if node.parent is None: - assert ( - node is self.root_node - ), f"This request holds the node from another tree" -+ # Node belongs to a stale (flushed) tree — stop traversal gracefully. + break node = node.parent return DecLockRefResult(delta=delta) -@@ -924,6 +923,7 @@ class HiRadixCache(RadixCache): +@@ -1091,6 +1089,7 @@ class HiRadixCache(RadixCache): self._update_host_leaf_status(node) # update leaf status for the parent because the node is evicted self._update_leaf_status(node.parent) @@ -1540,27 +1589,11 @@ index 3416d2a..df3866c 100644 return num_evicted def _evict_regular(self, node: TreeNode): -@@ -1447,6 +1447,7 @@ class HiRadixCache(RadixCache): - self._update_host_leaf_status(node) - # update parent status as a new leaf is added into device - self._update_leaf_status(node.parent) -+ self._update_host_leaf_status(node.parent) - else: - self._inc_hit_count(node, chunked) - total_prefix_length += prefix_len -@@ -1462,6 +1463,7 @@ class HiRadixCache(RadixCache): - self._update_host_leaf_status(new_node) - # update parent status as a new leaf is added into device - self._update_leaf_status(new_node.parent) -+ self._update_host_leaf_status(new_node.parent) - else: - self._inc_hit_count(new_node, chunked) - total_prefix_length += prefix_len diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py -index 9bf36c3..962b9f7 100644 +index 8efe9aae94e..79e9885c92f 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py -@@ -2038,9 +2038,12 @@ class NSATokenToKVPool(MLATokenToKVPool): +@@ -2244,9 +2244,12 @@ class DSATokenToKVPool(MLATokenToKVPool): else: assert self.page_size == 64 with ( @@ -1576,106 +1609,33 @@ index 9bf36c3..962b9f7 100644 ): self.index_k_with_scale_buffer = [ torch.zeros( -@@ -2062,6 +2065,11 @@ class NSATokenToKVPool(MLATokenToKVPool): - ) - for _ in range(layer_num) - ] -+ self.index_k_with_scale_buffer_ptrs = torch.tensor( -+ [x.data_ptr() for x in self.index_k_with_scale_buffer], -+ dtype=torch.uint64, -+ device=self.device, -+ ) - self._finalize_allocation_log(size) - - def _clear_buffers(self): -@@ -2198,6 +2206,50 @@ class NSATokenToKVPool(MLATokenToKVPool): - ] - return data_ptrs, data_lens, item_lens - -+ def get_cpu_copy(self, indices): -+ # First, save the kv_buffer (inherited from MLATokenToKVPool) -+ kv_cache_cpu = super().get_cpu_copy(indices) -+ -+ # Additionally, save the index_k_with_scale_buffer (page-indexed) -+ page_indices = indices[:: self.page_size] // self.page_size -+ torch.cuda.synchronize() -+ index_k_cpu = [] -+ chunk_size = self.cpu_offloading_chunk_size -+ # Convert chunk_size from token-level to page-level -+ page_chunk_size = max(1, chunk_size // self.page_size) -+ for layer_id in range(self.layer_num): -+ index_k_cpu.append([]) -+ for i in range(0, len(page_indices), page_chunk_size): -+ chunk_page_indices = page_indices[i : i + page_chunk_size] -+ idx_cpu = self.index_k_with_scale_buffer[layer_id][ -+ chunk_page_indices -+ ].to("cpu", non_blocking=True) -+ index_k_cpu[-1].append(idx_cpu) -+ torch.cuda.synchronize() -+ -+ return {"kv": kv_cache_cpu, "index_k": index_k_cpu} -+ -+ def load_cpu_copy(self, kv_cache_cpu_dict, indices): -+ # Restore the kv_buffer (inherited from MLATokenToKVPool) -+ super().load_cpu_copy(kv_cache_cpu_dict["kv"], indices) -+ -+ # Restore the index_k_with_scale_buffer (page-indexed) -+ page_indices = indices[:: self.page_size] // self.page_size -+ index_k_cpu = kv_cache_cpu_dict["index_k"] -+ torch.cuda.synchronize() -+ chunk_size = self.cpu_offloading_chunk_size -+ page_chunk_size = max(1, chunk_size // self.page_size) -+ for layer_id in range(self.layer_num): -+ for i in range(0, len(page_indices), page_chunk_size): -+ chunk_page_indices = page_indices[i : i + page_chunk_size] -+ idx_cpu = index_k_cpu[layer_id][i // page_chunk_size] -+ assert idx_cpu.shape[0] == len(chunk_page_indices) -+ idx_chunk = idx_cpu.to( -+ self.index_k_with_scale_buffer[0].device, non_blocking=True -+ ) -+ self.index_k_with_scale_buffer[layer_id][chunk_page_indices] = idx_chunk -+ torch.cuda.synchronize() -+ - def get_kv_size_bytes(self): - kv_size_bytes = super().get_kv_size_bytes() - for index_k_cache in self.index_k_with_scale_buffer: diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py -index 5f8a256..5eff30e 100644 +index bd6adb6e398..5ea935f76e9 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py -@@ -489,7 +489,17 @@ class RadixCache(KVCacheEventMixin, BasePrefixCache): - if self.disable: +@@ -467,6 +467,9 @@ class RadixCache(KVCacheEventMixin, BasePrefixCache): return -- token_ids = req.fill_ids -+ # Limit to kv_committed_len to avoid including tokens (e.g., the just-generated -+ # token in disagg prefill) that don't have computed KV yet. If fill_ids is longer -+ # than kv_committed_len, the extra tokens would produce stale values (0 from -+ # req_to_token_pool initialization), leading to spurious tree nodes and memory -+ # leak when page-aligned token counts happen to cross a page boundary. -+ kv_committed_len = req.kv_committed_len -+ token_ids = ( -+ req.fill_ids[:kv_committed_len] -+ if kv_committed_len < len(req.fill_ids) -+ else req.fill_ids -+ ) + token_ids = req.get_fill_ids() ++ kv_committed_len = getattr(req, "kv_committed_len", None) ++ if kv_committed_len is not None and len(token_ids) > kv_committed_len: ++ token_ids = token_ids[:kv_committed_len] kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(token_ids) ] -@@ -616,9 +626,8 @@ class RadixCache(KVCacheEventMixin, BasePrefixCache): +@@ -593,9 +596,7 @@ class RadixCache(KVCacheEventMixin, BasePrefixCache): node.lock_ref -= 1 self._update_leaf_status(node) if node.parent is None: - assert ( - node is self.root_node - ), f"This request holds the node from another tree" -+ # Node belongs to a stale (flushed) tree — stop traversal gracefully. + break node = node.parent return DecLockRefResult(delta=delta) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index eff3a56..5b1eec9 100644 +index 3b30eb0e1f7..d715bc6893d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -20,7 +20,9 @@ import datetime @@ -1688,30 +1648,28 @@ index eff3a56..5b1eec9 100644 import os import socket import threading -@@ -29,7 +31,7 @@ import uuid +@@ -28,7 +30,7 @@ import time from collections import defaultdict - from dataclasses import dataclass + from dataclasses import dataclass, replace from pathlib import Path --from typing import Callable, List, Optional, Tuple, Union -+from typing import Callable, Dict, List, Optional, Tuple, Union +-from typing import Any, Callable, List, Optional, Tuple, Union ++from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist -@@ -129,6 +131,7 @@ from sglang.srt.layers.sampler import create_sampler - from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model +@@ -137,6 +139,7 @@ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model + from sglang.srt.layers.utils.cp_utils import is_mla_prefill_cp_enabled from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.lora.lora_registry import LoRARef +from sglang.srt.managers.io_struct import DeltaEncoding, DeltaParam, DeltaSpec from sglang.srt.managers.schedule_batch import sanity_check_mm_pad_shift_value from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool -@@ -514,7 +517,12 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -548,7 +551,10 @@ class ModelRunner(ModelRunnerKVCacheMixin): self.forward_stream = torch.get_device_module(self.device).Stream() # CPU offload - set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank)) -+ # For draft worker (e.g., MTP), do not set offloader to avoid overriding -+ # the main model's offloader. Draft worker uses NoopOffloader instead. + if not is_draft_worker: + set_offloader( + create_offloader_from_server_args(server_args, dp_rank=dp_rank) @@ -1719,7 +1677,7 @@ index eff3a56..5b1eec9 100644 self._weight_checker = WeightChecker(model_runner=self) -@@ -752,7 +760,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -796,7 +802,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): self.maybe_init_ngram_embedding() # Init routed experts capturer @@ -1729,25 +1687,23 @@ index eff3a56..5b1eec9 100644 self.init_indexer_capturer() -@@ -1753,8 +1762,16 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -1657,8 +1664,14 @@ class ModelRunner(ModelRunnerKVCacheMixin): load_format: str, weight_name_filter: Optional[Callable[[str], bool]] = None, recapture_cuda_graph: bool = False, + files: Optional[List[str]] = None, ) -> tuple[bool, str]: - """Update engine weights in-place from the disk.""" -+ """Update weights in-place from disk. For ``load_format="delta"``, read + -+ apply each basename in ``files`` under ``model_path``; otherwise reload the -+ HF checkpoint at ``model_path``.""" ++ """Update engine weights in-place from disk.""" + if load_format == "delta": + if not files: -+ return False, "load_format='delta' requires non-empty ``files``" ++ return False, "load_format='delta' requires non-empty `files`" + return self._apply_delta([os.path.join(model_path, f) for f in files]) + logger.info( f"Update engine weights online from disk begin. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id, empty_cache=False):.2f} GB" -@@ -1984,6 +2001,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -1888,6 +1901,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): shapes, group_name, load_format: Optional[str] = None, @@ -1755,7 +1711,7 @@ index eff3a56..5b1eec9 100644 ): """ Update specific parameter in the model weights online -@@ -2004,6 +2022,18 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -1908,6 +1922,18 @@ class ModelRunner(ModelRunnerKVCacheMixin): return self._update_bucketed_weights_from_distributed( names, dtypes, shapes, group_name ) @@ -1774,7 +1730,7 @@ index eff3a56..5b1eec9 100644 try: weights = [] handles = [] -@@ -2067,6 +2097,156 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -1971,6 +1997,151 @@ class ModelRunner(ModelRunnerKVCacheMixin): logger.error(error_msg) return False, error_msg @@ -1785,10 +1741,9 @@ index eff3a56..5b1eec9 100644 + values: torch.Tensor, + p: DeltaParam, + ) -> torch.Tensor: -+ """Decode one param's (positions, values) into a full-shape NaN-masked tensor. -+ NaN at unchanged positions triggers the patched-copy on apply.""" ++ """Decode one param's sparse delta into a NaN-masked full tensor.""" + numel = math.prod(p.shape) -+ param_dtype = p.dtype if isinstance(p.dtype, torch.dtype) else getattr(torch, p.dtype) ++ param_dtype = _resolve_torch_dtype(p.dtype) + flat = torch.full((numel,), float("nan"), dtype=param_dtype, device=self.device) + val_slice = values[p.val_start : p.val_end] + if val_slice.numel() == 0: @@ -1796,9 +1751,9 @@ index eff3a56..5b1eec9 100644 + + pos_bytes = positions[p.pos_start : p.pos_end] + if encoding is DeltaEncoding.INDICES: -+ width = 4 # int32 absolute indices ++ width = 4 + elif encoding in (DeltaEncoding.DELTAS, DeltaEncoding.DELTAS_ZSTD): -+ width = p.pos_width # uint16 or uint32 gap-deltas ++ width = p.pos_width + else: + raise ValueError(f"unsupported delta encoding: {encoding!r}") + @@ -1806,19 +1761,14 @@ index eff3a56..5b1eec9 100644 + b = pos_bytes.view(n_elems, width).to(torch.int64) + if width == 2: + unpacked = b[:, 0] | (b[:, 1] << 8) -+ else: # 4 ++ else: + unpacked = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16) | (b[:, 3] << 24) + + if encoding is DeltaEncoding.INDICES: + idx = unpacked + else: -+ # Sender encodes ``delta[k] = idx[k] - idx[k-1] - 1`` with idx[-1] := -1; -+ # receiver inverts with ``idx = cumsum(delta + 1) - 1``. + idx = (unpacked + 1).cumsum(dim=0) - 1 -+ # Sender may concat values across params of mixed dtypes (bf16 weights -+ # + fp32 norms in one bucket); torch.cat promotes to the widest dtype, -+ # so re-cast each slice back to the param's own dtype. The promoted -+ # round-trip is exact (bf16 ⊂ fp32), no precision loss. ++ + flat.index_copy_(0, idx, val_slice.to(param_dtype)) + return flat.view(tuple(p.shape)) + @@ -1830,14 +1780,12 @@ index eff3a56..5b1eec9 100644 + values: torch.Tensor, + expected_checksum: int, + ) -> None: -+ """Verify checksum, decode each param, apply via the patched-copy context. -+ ``load_weights`` is called per ``update_weight_delta_chunk_bytes`` budget.""" + actual_checksum = _delta_checksum(positions, values) + if actual_checksum != expected_checksum: + raise RuntimeError( -+ f"delta checksum mismatch: expected={expected_checksum} got={actual_checksum}; " -+ "indicates corruption between sender encode and receiver apply" ++ f"delta checksum mismatch: expected={expected_checksum} got={actual_checksum}" + ) ++ + chunk_byte_cap = self.server_args.update_weight_delta_chunk_bytes + with _delta_apply_context(self.model): + chunk: List[Tuple[str, torch.Tensor]] = [] @@ -1855,12 +1803,10 @@ index eff3a56..5b1eec9 100644 + self.model.load_weights(chunk) + + def _decode_and_apply_blob(self, blob: bytes) -> None: -+ """Decode + apply one decompressed safetensors blob from the delta sender.""" + from safetensors.torch import load as st_load + -+ # st_load only returns tensors, so parse the header for metadata. + hdr_len = int.from_bytes(blob[:8], "little") -+ meta = json.loads(blob[8:8 + hdr_len]).get("__metadata__", {}) ++ meta = json.loads(blob[8 : 8 + hdr_len]).get("__metadata__", {}) + encoding = DeltaEncoding(meta["encoding"]) + params = [DeltaParam(**p) for p in json.loads(meta["params"])] + expected_checksum = int(meta["checksum"]) @@ -1868,7 +1814,9 @@ index eff3a56..5b1eec9 100644 + tensors = st_load(blob) + positions = tensors["__positions__"].to(self.device, non_blocking=True) + values = tensors["__values__"].to(self.device, non_blocking=True) -+ self._apply_delta_payload(encoding, params, positions, values, expected_checksum) ++ self._apply_delta_payload( ++ encoding, params, positions, values, expected_checksum ++ ) + + def _apply_delta_from_distributed( + self, @@ -1878,24 +1826,30 @@ index eff3a56..5b1eec9 100644 + group_name: str, + delta: DeltaSpec, + ) -> tuple[bool, str]: -+ """NCCL receive: broadcast (positions, values) from sender, then apply.""" + try: + recv: Dict[str, torch.Tensor] = {} + handles = [] + for name, dtype, shape in zip(names, dtypes, shapes): -+ target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) ++ target_dtype = _resolve_torch_dtype(dtype) + t = torch.empty(shape, dtype=target_dtype, device=self.device) + handles.append( + torch.distributed.broadcast( -+ t, src=0, group=self._model_update_group[group_name], async_op=True ++ t, ++ src=0, ++ group=self._model_update_group[group_name], ++ async_op=True, + ) + ) + recv[name] = t -+ for h in handles: -+ h.wait() ++ for handle in handles: ++ handle.wait() + + self._apply_delta_payload( -+ delta.encoding, delta.params, recv["__positions__"], recv["__values__"], delta.checksum ++ delta.encoding, ++ delta.params, ++ recv["__positions__"], ++ recv["__values__"], ++ delta.checksum, + ) + return True, "ok" + except Exception as e: @@ -1904,7 +1858,6 @@ index eff3a56..5b1eec9 100644 + return False, error_msg + + def _apply_delta(self, paths: List[str]) -> tuple[bool, str]: -+ """Read + decompress delta safetensors files in parallel, decode + apply each.""" + import concurrent.futures + + n_files = len(paths) @@ -1915,8 +1868,6 @@ index eff3a56..5b1eec9 100644 + return _maybe_zstd_decompress(fh.read()) + + try: -+ # Cap peak memory at workers × file_size by applying each batch before -+ # prefetching the next. + for i in range(0, n_files, workers): + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool: + batch = list(pool.map(_read_and_decompress, paths[i : i + workers])) @@ -1931,16 +1882,15 @@ index eff3a56..5b1eec9 100644 def update_weights_from_tensor( self, named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]], -@@ -3346,11 +3526,18 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3468,11 +3639,17 @@ class ModelRunner(ModelRunnerKVCacheMixin): output.expert_distribution_metrics = recorder_outputs.get("metrics") no_copy_to_cpu = not self.server_args.disable_overlap_schedule -+ # In speculative decoding, num_tokens_per_bs > 1, so we need to pass -+ # the actual number of tokens per dp rank in cuda graph, not batch size. + cuda_graph_num_tokens = None -+ if getattr(self.graph_runner, "bs", None): -+ cuda_graph_num_tokens = self.graph_runner.bs * getattr( -+ self.graph_runner, "num_tokens_per_bs", 1 ++ decode_graph_runner = getattr(self, "decode_cuda_graph_runner", None) ++ if getattr(decode_graph_runner, "bs", None): ++ cuda_graph_num_tokens = decode_graph_runner.bs * getattr( ++ decode_graph_runner, "num_tokens_per_bs", 1 + ) if (experts_capturer := get_global_experts_capturer()) is not None: output.routed_experts_output = experts_capturer.on_forward_end( @@ -1951,7 +1901,7 @@ index eff3a56..5b1eec9 100644 no_copy_to_cpu=no_copy_to_cpu, ) -@@ -3358,7 +3545,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3480,7 +3657,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): output.indexer_topk_output = indexer_capturer.on_forward_end( forward_batch=forward_batch, can_run_graph=output.can_run_graph, @@ -1960,50 +1910,51 @@ index eff3a56..5b1eec9 100644 no_copy_to_cpu=no_copy_to_cpu, ) -@@ -3641,6 +3828,161 @@ class ModelRunner(ModelRunnerKVCacheMixin): - ) - return output +@@ -3718,6 +3895,39 @@ class ModelRunner(ModelRunnerKVCacheMixin): + logger.error(f"IPC weight update failed: {e}") + return False, str(e) -+ def post_process_weights(self, recv_req): -+ """ -+ Execute post-processing logic for model weights, such as Marlin quantization format conversion. -+ """ ++ def post_process_weights( ++ self, ++ restore_weights_before_load: bool = False, ++ post_process_quantization: bool = False, ++ ): ++ """Run optional post-loading hooks, such as quantization repacking.""" + from sglang.srt.model_loader.loader import device_loading_context + -+ target_device = torch.device("cuda", torch.cuda.current_device()) ++ if self.device == "cuda": ++ target_device = torch.device("cuda", torch.cuda.current_device()) ++ else: ++ target_device = torch.device(self.device) + -+ if recv_req.restore_weights_before_load: ++ if restore_weights_before_load: + for _, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) -+ -+ # Check if the module supports restoring weights + if quant_method is not None and hasattr( + quant_method, "restore_weights_before_loading" + ): -+ + with device_loading_context(module, target_device): + quant_method.restore_weights_before_loading(module) + -+ if recv_req.post_process_quantization: -+ # Iterate through all modules to apply specific post-loading processing ++ if post_process_quantization: + for _, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) -+ -+ # Check if the module supports quantization post-processing + if quant_method is not None and hasattr( + quant_method, "process_weights_after_loading" + ): -+ -+ # Apply the post-processing (e.g., repacking weights for Marlin kernel) + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + return True, "Success" + -+ + def prealloc_symmetric_memory_pool(self): + # PyTorch mempools never de-fragment memory in OOM scenarios, so we need to pre-allocate a large chunk of memory to limit fragmentation. + if ( +@@ -3767,6 +3977,123 @@ class ModelRunner(ModelRunnerKVCacheMixin): + return output + + +def _param_storage_index(model): -+ """Build ``find_parent(dst)``: looks up the param/buffer owning ``dst``'s storage, -+ or None. Used by ``_delta_apply_context`` to scope its patched copy_/fill_.""" + import bisect + + starts: List[int] = [] @@ -2025,6 +1976,7 @@ index eff3a56..5b1eec9 100644 + starts.append(ptr) + ends.append(ptr + sz) + owners.append(t) ++ + order = sorted(range(len(starts)), key=lambda i: starts[i]) + starts = [starts[i] for i in order] + ends = [ends[i] for i in order] @@ -2045,10 +1997,6 @@ index eff3a56..5b1eec9 100644 + +@contextlib.contextmanager +def _delta_apply_context(model): -+ """Patch ``copy_`` / ``fill_`` so writes into ``model``'s param storage skip -+ positions whose source is NaN. Non-param writes go through unmodified. -+ ``post_load_weights`` runs in the original env so derived tensors (fp8 scales, -+ MoE biases, w_kc/w_vc) overwrite as usual.""" + is_param_target = _param_storage_index(model) + original_copy_ = torch.Tensor.copy_ + original_fill_ = torch.Tensor.fill_ @@ -2067,8 +2015,6 @@ index eff3a56..5b1eec9 100644 + + def patched_fill_(self, value): + if is_param_target(self) is not None: -+ # NaN scalar means "don't change the param" (per-element analog of -+ # patched_copy_). Non-NaN scalars write through. + try: + if math.isnan(value): + return self @@ -2079,6 +2025,7 @@ index eff3a56..5b1eec9 100644 + + original_post_load = getattr(model, "post_load_weights", None) + if original_post_load is not None: ++ + def wrapped_post_load(*args, **kwargs): + current_copy = torch.Tensor.copy_ + current_fill = torch.Tensor.fill_ @@ -2104,182 +2051,43 @@ index eff3a56..5b1eec9 100644 + + +def _delta_checksum(positions: torch.Tensor, values: torch.Tensor) -> int: -+ """Wire-corruption check, must match the sender's computation.""" + p = int(torch.hash_tensor(positions).item()) if positions.numel() else 0 + v = int(torch.hash_tensor(values).item()) if values.numel() else 0 + return p ^ (v << 1) + + +def _maybe_zstd_decompress(blob: bytes) -> bytes: -+ """Decompress if zstd-framed (sender uses zstd when encoding=deltas_zstd).""" -+ # Zstandard frame magic: 0xFD2FB528 little-endian (RFC 8478 §3.1.1). + if blob.startswith(b"\x28\xb5\x2f\xfd"): + import zstandard + + return zstandard.ZstdDecompressor().decompress(blob) + return blob + - - def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): - params_dict = dict(model.named_parameters()) -diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py -index 84c57c1..3373f25 100644 ---- a/python/sglang/srt/model_loader/weight_utils.py -+++ b/python/sglang/srt/model_loader/weight_utils.py -@@ -685,7 +685,13 @@ def maybe_add_mtp_safetensors( - getattr(hf_config, "num_nextn_predict_layers", 0), - ) - if not ( -- arch in ["Glm4MoeForCausalLM", "Glm4MoeForCausalLMNextN"] -+ arch -+ in [ -+ "Glm4MoeForCausalLM", -+ "Glm4MoeForCausalLMNextN", -+ "Glm4MoeLiteForCausalLM", -+ "Glm4MoeLiteForCausalLMNextN", -+ ] - and num_nextn_layers > 0 - ): - return hf_weights_files -diff --git a/python/sglang/srt/models/glm4_moe_lite.py b/python/sglang/srt/models/glm4_moe_lite.py -index 80a0351..1f7cdd1 100644 ---- a/python/sglang/srt/models/glm4_moe_lite.py -+++ b/python/sglang/srt/models/glm4_moe_lite.py -@@ -273,10 +273,16 @@ class Glm4MoeLiteSparseMoeBlock(DeepseekV2MoE): - - self.shared_experts_is_int8 = False - self.shared_experts_is_fp8 = False -- # self.shared_experts_weight_block_size = None -+ self.shared_experts_weight_block_size = None -+ self._shared_expert_tp1 = False - if config.n_shared_experts is not None and self.num_fused_shared_experts == 0: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - # disable tp for shared experts when enable deepep moe, or with fp4 allgather -+ shared_expert_use_tp1 = ( -+ get_moe_a2a_backend().is_deepep() -+ or get_moe_a2a_backend().is_mooncake() -+ or should_use_flashinfer_cutlass_moe_fp4_allgather() -+ ) - self.shared_experts = Glm4MoeLiteMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, -@@ -284,14 +290,9 @@ class Glm4MoeLiteSparseMoeBlock(DeepseekV2MoE): - quant_config=quant_config, - reduce_results=False, - prefix=add_prefix("shared_experts", prefix), -- **( -- dict(tp_rank=0, tp_size=1) -- if get_moe_a2a_backend().is_deepep() -- or get_moe_a2a_backend().is_mooncake() -- or should_use_flashinfer_cutlass_moe_fp4_allgather() -- else {} -- ), -+ **(dict(tp_rank=0, tp_size=1) if shared_expert_use_tp1 else {}), - ) -+ self._shared_expert_tp1 = shared_expert_use_tp1 - is_packed_weight = hasattr( - self.shared_experts.gate_up_proj.quant_method, "quant_config" - ) -diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py -index dfbd458..93b3e1a 100644 ---- a/python/sglang/srt/models/glm4_moe_nextn.py -+++ b/python/sglang/srt/models/glm4_moe_nextn.py -@@ -36,12 +36,15 @@ from sglang.srt.layers.vocab_parallel_embedding import ( - from sglang.srt.model_executor.forward_batch_info import ForwardBatch - from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM - from sglang.srt.server_args import get_global_server_args --from sglang.srt.utils import add_prefix -+from sglang.srt.utils import BumpAllocator, add_prefix - - logger = logging.getLogger(__name__) - - - class Glm4MoeModelNextN(nn.Module): -+ decoder_layer_cls = Glm4MoeDecoderLayer -+ decoder_needs_zero_allocator = False -+ - def __init__( - self, - config: PretrainedConfig, -@@ -69,7 +72,7 @@ class Glm4MoeModelNextN(nn.Module): - - self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) - -- self.decoder = Glm4MoeDecoderLayer( -+ self.decoder = self.decoder_layer_cls( - config, - 0, - quant_config=quant_config, -@@ -105,9 +108,19 @@ class Glm4MoeModelNextN(nn.Module): - - residual = None - with get_global_expert_distribution_recorder().disable_this_region(): -- hidden_states, residual = self.decoder( -- positions, hidden_states, forward_batch, residual -- ) -+ if self.decoder_needs_zero_allocator: -+ zero_allocator = BumpAllocator( -+ buffer_size=2 * (2 if forward_batch.can_run_tbo else 1), -+ dtype=torch.float32, -+ device=hidden_states.device, -+ ) -+ hidden_states, residual, _ = self.decoder( -+ positions, hidden_states, forward_batch, residual, zero_allocator -+ ) -+ else: -+ hidden_states, residual = self.decoder( -+ positions, hidden_states, forward_batch, residual -+ ) - - if not forward_batch.forward_mode.is_idle(): - if residual is not None: -@@ -119,6 +132,8 @@ class Glm4MoeModelNextN(nn.Module): - - - class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM): -+ model_cls = Glm4MoeModelNextN -+ - def __init__( - self, - config: PretrainedConfig, -@@ -133,7 +148,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM): - or quant_config is not None - ) - quant_config = quant_config if self.needs_quant_draft else None -- self.model = Glm4MoeModelNextN( -+ self.model = self.model_cls( - config, quant_config, prefix=add_prefix("model", prefix) - ) - self.lm_head = ParallelLMHead( -@@ -177,4 +192,22 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM): - super().load_weights(weights, is_nextn=True) - - --EntryClass = [Glm4MoeForCausalLMNextN] -+from sglang.srt.models.glm4_moe_lite import ( -+ Glm4MoeLiteDecoderLayer, -+ Glm4MoeLiteForCausalLM, -+) -+ -+ -+class Glm4MoeLiteModelNextN(Glm4MoeModelNextN): -+ decoder_layer_cls = Glm4MoeLiteDecoderLayer -+ decoder_needs_zero_allocator = True -+ -+ -+class Glm4MoeLiteForCausalLMNextN(Glm4MoeForCausalLMNextN, Glm4MoeLiteForCausalLM): -+ model_cls = Glm4MoeLiteModelNextN + -+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): -+ Glm4MoeLiteForCausalLM.load_weights(self, weights, is_nextn=True) ++def _resolve_torch_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: ++ if isinstance(dtype, torch.dtype): ++ return dtype ++ dtype_name = dtype.split(".", 1)[1] if dtype.startswith("torch.") else dtype ++ return getattr(torch, dtype_name) + + -+EntryClass = [Glm4MoeForCausalLMNextN, Glm4MoeLiteForCausalLMNextN] + def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): + params_dict = dict(model.named_parameters()) + for name, tensor in named_tensors: diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py -index 2f00749..1f99193 100644 +index 2f0074924db..8d62df83c74 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py -@@ -52,11 +52,31 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): +@@ -45,6 +45,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): + + self.pp_group = get_pp_group() + self.config = config ++ self.config.encoder_only = getattr(config, "encoder_only", False) ++ self.config.language_only = getattr(config, "language_only", False) + self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder + vision_utils.update_vit_attn_dummy_heads_config(self.config) + self.tp_size = get_tensor_model_parallel_world_size() +@@ -52,11 +54,30 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): self.num_fused_shared_experts = 0 self.determine_num_fused_shared_experts() @@ -2310,13 +2118,12 @@ index 2f00749..1f99193 100644 + # ranks other than the last rank will have a placeholder layer + self.lm_head = PPMissingLayer() + else: -+ # encoder_only mode: no language model, so no lm_head needed + self.lm_head = None + self.visual = Glm4vVisionModel( config.vision_config, quant_config=quant_config, -@@ -64,24 +84,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): +@@ -64,24 +85,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): use_data_parallel=self.use_data_parallel, ) @@ -2338,20 +2145,19 @@ index 2f00749..1f99193 100644 self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling -+ _rope_cfg = ( ++ rope_config = ( + getattr(self.config, "rope_scaling", None) + or getattr(self.config, "rope_parameters", None) + or {} + ) -+ self.is_mrope_enabled = "mrope_section" in _rope_cfg ++ self.is_mrope_enabled = "mrope_section" in rope_config # For EAGLE3 support self.capture_aux_hidden_states = False -@@ -219,6 +229,11 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): +@@ -219,6 +230,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue -+ # Skip loading visual/language model weights + if ( + self.config.encoder_only or self.config.language_only + ) and name not in params_dict: @@ -2368,11 +2174,10 @@ index 2f00749..1f99193 100644 # Mark as expert weight regardless of whether we can process it is_expert_weight = True -@@ -265,6 +282,11 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): +@@ -265,6 +282,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue -+ # Skip loading mm/language parameters + if ( + self.config.encoder_only or self.config.language_only + ) and name not in params_dict: @@ -2381,13 +2186,16 @@ index 2f00749..1f99193 100644 continue diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py -index 1b6c185..58ad9c7 100644 +index 3ffe4dde7fd..9869f11623e 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py -@@ -1024,14 +1024,19 @@ class Qwen3LLMModel(Qwen3Model): - hidden_states + residual if residual is not None else hidden_states - ) - +@@ -1034,9 +1034,14 @@ class Qwen3LLMModel(Qwen3Model): + # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack + # The order matters because addition with different tensors is not associative in practice. + # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition. +- deepstack_embeds = self.get_deepstack_embeds( +- layer_idx - 1, input_deepstack_embeds +- ) + deepstack_embeds = None + if input_deepstack_embeds is not None: + prev_layer_idx = layer_idx - 1 @@ -2396,20 +2204,11 @@ index 1b6c185..58ad9c7 100644 + deepstack_embeds = input_deepstack_embeds[ + :, sep : sep + self.hidden_size + ] -+ - # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. - # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 - # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack - # The order matters because addition with different tensors is not associative in practice. -- # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition. -- deepstack_embeds = self.get_deepstack_embeds( -- layer_idx - 1, input_deepstack_embeds -- ) hidden_states, residual = layer( positions, hidden_states, diff --git a/python/sglang/srt/multimodal/processors/glm4v.py b/python/sglang/srt/multimodal/processors/glm4v.py -index a44f14b..6d6c65e 100644 +index db684259d2f..17d2cb6958a 100644 --- a/python/sglang/srt/multimodal/processors/glm4v.py +++ b/python/sglang/srt/multimodal/processors/glm4v.py @@ -1,7 +1,13 @@ @@ -2477,25 +2276,25 @@ index a44f14b..6d6c65e 100644 image_grid_thw = None video_grid_thw = None diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py -index fb9fd85..b1cbc9b 100644 +index b8774ebade5..fa01537b201 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py -@@ -505,7 +505,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor): +@@ -678,7 +678,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor): **kwargs, ): entry_time = time.perf_counter() -- base_output = self.load_mm_data( -+ base_output = self.legacy_load_mm_data( +- base_output = await self.load_mm_data( ++ base_output = await self.legacy_load_mm_data( prompt=input_text, image_data=image_data, video_data=request_obj.video_data, diff --git a/python/sglang/srt/observability/req_time_stats.py b/python/sglang/srt/observability/req_time_stats.py -index 326aace..b62804c 100644 +index 2de10730c94..d3ce2c62d21 100644 --- a/python/sglang/srt/observability/req_time_stats.py +++ b/python/sglang/srt/observability/req_time_stats.py -@@ -21,7 +21,10 @@ import uuid - from dataclasses import dataclass, field - from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +@@ -23,7 +23,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + + from typing_extensions import Self -from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.disaggregation.utils import ( @@ -2505,7 +2304,7 @@ index 326aace..b62804c 100644 from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.observability.metrics_collector import ( SchedulerMetricsCollector, -@@ -575,6 +578,14 @@ class SchedulerReqTimeStats(ReqTimeStatsBase): +@@ -577,6 +580,14 @@ class SchedulerReqTimeStats(ReqTimeStatsBase): # Number of prefill retries for this request prefill_retry_count: int = 0 @@ -2520,7 +2319,7 @@ index 326aace..b62804c 100644 def __getstate__(self) -> object: # send to detokenizer/tokenizer -@@ -582,10 +593,38 @@ class SchedulerReqTimeStats(ReqTimeStatsBase): +@@ -584,10 +595,35 @@ class SchedulerReqTimeStats(ReqTimeStatsBase): return {} state = { @@ -2543,39 +2342,23 @@ index 326aace..b62804c 100644 + self.fwd_prefill_bootstrap_queue_duration + ), + "fwd_prefill_forward_duration": self.fwd_prefill_forward_duration, -+ "fwd_prefill_transfer_queue_duration": self.fwd_prefill_transfer_queue_duration, ++ "fwd_prefill_transfer_queue_duration": ( ++ self.fwd_prefill_transfer_queue_duration ++ ), + "fwd_prefill_bootstrap_duration": self.fwd_prefill_bootstrap_duration, + "fwd_prefill_alloc_wait_duration": self.fwd_prefill_alloc_wait_duration, + "fwd_transfer_speed_gb_s": self.fwd_transfer_speed_gb_s, + "fwd_transfer_total_mb": self.fwd_transfer_total_mb, + "fwd_prefill_retry_count": self.fwd_prefill_retry_count, "diff_realtime_monotonic": global_diff_realtime_monotonic, -+ # SLIME PATCH: preserve enable_metrics across IPC hops. -+ # Otherwise detokenizer-side enable_metrics defaults to False (from -+ # the dataclass default, not from the scheduler value), and when -+ # the same object is forwarded to tokenizer_manager the early -+ # `if not self.enable_metrics: return {}` strips PD timing fields. + "enable_metrics": self.enable_metrics, } return state -@@ -974,6 +1013,172 @@ class SchedulerReqTimeStats(ReqTimeStatsBase): +@@ -987,6 +1023,159 @@ class SchedulerReqTimeStats(ReqTimeStatsBase): def get_queueing_time(self) -> float: return self.forward_entry_time - self.wait_queue_entry_time -+ def get_prefill_waiting_latency(self) -> Optional[float]: -+ if self.prefill_run_batch_start_time > 0.0: -+ return self.prefill_run_batch_start_time - self.forward_entry_time -+ return None -+ -+ def get_prefill_launch_latency(self) -> Optional[float]: -+ if ( -+ self.prefill_run_batch_start_time > 0.0 -+ and self.prefill_run_batch_end_time > 0.0 -+ ): -+ return self.prefill_run_batch_end_time - self.prefill_run_batch_start_time -+ return None -+ + def get_pd_prefill_bootstrap_queue_duration(self) -> Optional[float]: + if not is_slime_profiling_enabled(): + return None @@ -2732,7 +2515,7 @@ index 326aace..b62804c 100644 def convert_to_duration(self) -> str: if self.disagg_mode == DisaggregationMode.NULL: queue_duration = self.duration_between( -@@ -1107,6 +1312,26 @@ class SchedulerReqTimeStats(ReqTimeStatsBase): +@@ -1120,6 +1309,36 @@ class SchedulerReqTimeStats(ReqTimeStatsBase): "queue_time": self.get_queueing_time(), } ) @@ -2742,14 +2525,24 @@ index 326aace..b62804c 100644 + self.get_pd_prefill_bootstrap_queue_duration() + ), + "pd_prefill_forward_duration": self.get_pd_prefill_forward_duration(), -+ "pd_prefill_transfer_queue_duration": self.get_pd_prefill_transfer_queue_duration(), -+ "pd_prefill_bootstrap_duration": self.get_pd_prefill_bootstrap_duration(), -+ "pd_prefill_alloc_wait_duration": self.get_pd_prefill_alloc_wait_duration(), ++ "pd_prefill_transfer_queue_duration": ( ++ self.get_pd_prefill_transfer_queue_duration() ++ ), ++ "pd_prefill_bootstrap_duration": ( ++ self.get_pd_prefill_bootstrap_duration() ++ ), ++ "pd_prefill_alloc_wait_duration": ( ++ self.get_pd_prefill_alloc_wait_duration() ++ ), + "pd_decode_prealloc_duration": self.get_pd_decode_prealloc_duration(), + "pd_decode_transfer_duration": self.get_pd_decode_transfer_duration(), + "pd_decode_forward_duration": self.get_pd_decode_forward_duration(), -+ "pd_decode_bootstrap_duration": self.get_pd_decode_bootstrap_duration(), -+ "pd_decode_alloc_wait_duration": self.get_pd_decode_alloc_wait_duration(), ++ "pd_decode_bootstrap_duration": ( ++ self.get_pd_decode_bootstrap_duration() ++ ), ++ "pd_decode_alloc_wait_duration": ( ++ self.get_pd_decode_alloc_wait_duration() ++ ), + "pd_transfer_speed_gb_s": self.get_pd_transfer_speed_gb_s(), + "pd_transfer_total_mb": self.get_pd_transfer_total_mb(), + "pd_prefill_retry_count": self.get_pd_prefill_retry_count(), @@ -2760,18 +2553,10 @@ index 326aace..b62804c 100644 def format_duration(self, duration: float) -> str: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py -index 173bdbb..339791f 100644 +index 6c77ff64f92..706161ab928 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py -@@ -772,6 +772,7 @@ class ServerArgs: - # Context parallelism used in the long sequence prefill phase of DeepSeek v3.2 - enable_nsa_prefill_context_parallel: bool = False - nsa_prefill_cp_mode: str = "round-robin-split" -+ disable_indexer_rope_neox_style: bool = False - enable_fused_qk_norm_rope: bool = False - enable_precise_embedding_interpolation: bool = False - enable_fused_moe_sum_all_reduce: bool = False -@@ -817,6 +818,8 @@ class ServerArgs: +@@ -854,6 +854,8 @@ class ServerArgs: weight_loader_prefetch_checkpoints: bool = False weight_loader_prefetch_num_threads: int = 4 weight_loader_drop_cache_after_load: bool = False @@ -2780,49 +2565,36 @@ index 173bdbb..339791f 100644 remote_instance_weight_loader_seed_instance_ip: Optional[str] = None remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None -@@ -6782,6 +6785,12 @@ class ServerArgs: - help="Token splitting mode for the prefill phase of DeepSeek v3.2 under context parallelism. Optional values: 'round-robin-split'(default), 'in-seq-split' " - "'round-robin-split' distributes tokens across ranks based on token_idx %% cp_size. It supports multi-batch prefill, fused MoE, and FP8 KV cache.", - ) -+ parser.add_argument( -+ "--disable-indexer-rope-neox-style", -+ action="store_true", -+ help="Disable NSA indexer RoPE neox style (equivalent to INDEXER_ROPE_NEOX_STYLE=0). " -+ "If the environment variable INDEXER_ROPE_NEOX_STYLE is also set and conflicts, an error is raised.", -+ ) - parser.add_argument( - "--enable-prefill-context-parallel", +@@ -7119,6 +7121,18 @@ class ServerArgs: action="store_true", -@@ -6962,6 +6971,18 @@ class ServerArgs: - action="store_true", - help="Disable mmap while loading weight using safetensors.", + help="Call posix_fadvise(DONTNEED) on each safetensors shard after loading it.", ) + parser.add_argument( + "--update-weight-delta-chunk-bytes", + type=int, + default=ServerArgs.update_weight_delta_chunk_bytes, -+ help="Byte cap per load_weights call when applying a delta update.", ++ help="Maximum bytes per delta weight chunk when applying delta updates.", + ) + parser.add_argument( + "--update-weight-delta-read-workers", + type=int, + default=ServerArgs.update_weight_delta_read_workers, -+ help="Max parallel I/O threads for reading delta files from disk.", ++ help="Number of worker threads used to read delta weight files.", + ) parser.add_argument( - "--weight-loader-prefetch-checkpoints", - action="store_true", + "--remote-instance-weight-loader-seed-instance-ip", + type=str, diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py -index 1dd4bb0..4be3a8c 100644 +index 96c7286af76..9e3e2bd7142 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py -@@ -442,8 +442,12 @@ class EAGLEDraftCudaGraphRunner: +@@ -458,8 +458,12 @@ class EAGLEDraftCudaGraphRunner: "EagleDraftCudaGraphRunner.replay: topk_index vs vocab_size=" f"{self.model_runner.model_config.vocab_size}", ) - buffers.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) - buffers.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) -+ buffers.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1)) ++ buffers.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0.0, 1.0)) + buffers.topk_index[:raw_bs].copy_( + forward_batch.spec_info.topk_index.clamp( + 0, self.model_runner.model_config.vocab_size - 1 @@ -2832,52 +2604,71 @@ index 1dd4bb0..4be3a8c 100644 buffers.hidden_states is not None and forward_batch.spec_info.hidden_states is not None diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py -index e845fde..ead4d72 100644 +index 6bf5d6182af..70de75f20be 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py -@@ -1232,6 +1232,7 @@ class EAGLEWorkerV2(BaseSpecWorker): - success, message = self._draft_worker.draft_runner.update_weights_from_disk( +@@ -530,6 +530,21 @@ class EagleDraftWorker(BaseDraftWorker): + + # Forward multiple steps + scores = None ++ # Reuse NSA/DSA topk_indices from the first draft forward step for ++ # subsequent steps, analogous to skip_topk in deepseek_v2.py layers. ++ # Only safe with topk == 1: select_top_k_tokens reorders candidate rows ++ # each step, which would desync the cached indices from their rows. ++ index_share_for_mtp_iteration = ( ++ getattr( ++ self.draft_runner.model_config.hf_config, ++ "index_share_for_mtp_iteration", ++ False, ++ ) ++ and self.topk == 1 ++ ) ++ if index_share_for_mtp_iteration: ++ forward_batch.reuse_mtp_topk_indices = True ++ forward_batch.topk_indices = None + for i in range(self.speculative_num_steps): + input_ids, hidden_states, scores, tree_info = select_top_k_tokens( + i, topk_p, topk_index, hidden_states, scores, self.topk +@@ -597,6 +612,10 @@ class EagleDraftWorker(BaseDraftWorker): + hidden_states = logits_output.hidden_states + forward_batch.positions.add_(1) + ++ if index_share_for_mtp_iteration: ++ forward_batch.topk_indices = None ++ forward_batch.reuse_mtp_topk_indices = False ++ + # Organize the results + if ( + self.topk == 1 +@@ -1480,6 +1499,7 @@ class EAGLEWorkerV2(BaseSpecWorker): recv_req.model_path, recv_req.load_format, -+ files=recv_req.files, recapture_cuda_graph=recv_req.recapture_cuda_graph, ++ files=recv_req.files, ) if not success: + return success, message diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py -index 4650619..3722bd2 100644 +index 04b3841a23d..9aaf6b30673 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py -@@ -819,6 +819,7 @@ class MultiLayerEagleWorkerV2(BaseSpecWorker): - ].update_weights_from_disk( +@@ -856,6 +856,7 @@ class MultiLayerEagleWorkerV2(BaseSpecWorker): recv_req.model_path, recv_req.load_format, -+ files=recv_req.files, recapture_cuda_graph=recv_req.recapture_cuda_graph, ++ files=recv_req.files, ) if not success: + return success, message diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py -index 48e1086..2fd60b3 100644 +index 4556d06b16f..9c28114f85d 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py -@@ -2311,6 +2311,7 @@ class SafeUnpickler(pickle.Unpickler): - "sglang.srt.layers.", +@@ -2399,6 +2399,7 @@ class SafeUnpickler(pickle.Unpickler): "sglang.srt.utils.", - "torch_npu.", + "sglang.srt.disaggregation.", + "sglang.srt.managers.", + "slime.", + "torch_npu.", } - DENY_CLASSES = { -diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py -index 5556697..46c1771 100644 ---- a/python/sglang/srt/utils/weight_checker.py -+++ b/python/sglang/srt/utils/weight_checker.py -@@ -168,6 +168,9 @@ def _check_tensors( - actual_should_compare, - actual, - ) in zip(expect_tensors, actual_tensors, strict=True): -+ if ".cos_sin_cache" in expect_name: -+ # skip cos/sin cache which is deterministic from shape and dtype and may have different shapes due to different implementations. -+ continue - assert expect_name == actual_name, f"{expect_name=} {actual_name=}" - assert ( - expect_should_compare == actual_should_compare diff --git a/docker/version.txt b/docker/version.txt index 01b202dfda..6adf74a6e0 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -1 +1 @@ -nightly-dev-20260608b +nightly-dev-20260614a diff --git a/docs/en/advanced/delta-weight-sync.md b/docs/en/advanced/delta-weight-sync.md index a1e670f81d..9dd4fa2c9f 100644 --- a/docs/en/advanced/delta-weight-sync.md +++ b/docs/en/advanced/delta-weight-sync.md @@ -4,6 +4,7 @@ - [Quick Start](#quick-start) - [Mode vs Transport](#mode-vs-transport) - [How It Works](#how-it-works) +- [Publish-Only Disk Delta](#publish-only-disk-delta) - [Encoding Choice](#encoding-choice) - [Why Not Colocated](#why-not-colocated) @@ -92,6 +93,35 @@ For both transports, the receiver ends up calling the same `_apply_delta_payload Selective overwrite has no arithmetic — the receiver writes the trainer's exact bytes at changed positions — so it's lossless by construction and there's no notion of drift to fight with periodic base re-syncs. +## Publish-Only Disk Delta + +The disk path above pushes each version to known engines: rank 0 calls every engine's `update_weights_from_disk(load_format="delta")` and the sync ends when all engines acknowledge. That requires stable engine handles. When the serving side is an elastic fleet that consumes published versions on its own schedule — e.g. behind an [opaque HTTP rollout endpoint](external-rollout-engines.md#opaque-http-rollout-endpoint) — invert the direction with publish-only mode: + +```bash +--update-weight-mode delta +--update-weight-transport disk +--update-weight-delta-publish-only +--custom-delta-publish-path my_pkg.publish.publish_delta +--update-weight-delta-keep-files +``` + +Instead of firing per-engine RPCs, rank 0 invokes your publish hook once per sync, after every delta file has been written and the optional `--custom-delta-pre-push-path` hook has committed: + +```python +def publish_delta(args, version_dir: str, files: list[str], weight_version: str, rollout_engines) -> list | None: + ... # e.g. upload version_dir to object storage, then announce weight_version +``` + +Returned Ray ObjectRefs are awaited before the version counts as settled. Behavior differences from the direct disk path: + +- **One complete version per sync.** Direct disk transport publishes at each pass boundary so receivers can overlap apply with later encoding; publish-only defers everything to finalize, so external consumers never observe a partially published version. +- **Publish wait is configurable.** By default, `--update-weight-delta-publish-wait=next-sync` leaves the dispatched publish in flight across the next training step and settles it at the start of the next sync (or on disconnect). A failed publish therefore surfaces one sync late, on rank 0. Set `--update-weight-delta-publish-wait=sync` when the publish hook should block `update_weights`, for example because it polls an external rollout fleet until enough replicas report the new version ready. +- **Engines are left alone.** Generation is not paused, caches are not flushed, and no update RPCs are issued; consumers decide when to pick up a version. If the rollout endpoint supports request-level weight constraints, attach them from a `--custom-rollout-request-hook-path` hook so requests routed to lagging replicas fail/retry before doing unusable rollout compute. +- **No cleanup.** slime cannot know when consumers finish reading a version, so `--update-weight-delta-keep-files` is required and version-directory lifecycle belongs to you (e.g. the publish hook can prune old versions once uploaded). +- **No-op versions still publish.** If a sync produces no changed bytes, the hook is still called with an empty file list so consumers' version counters can advance. + +`--update-weight-delta-root` optionally names a root directory for publish-side metadata; it defaults to the parent of `--update-weight-disk-dir` and is passed through to hooks via `args`. + ## Encoding Choice `--update-weight-encoding` picks how positions are packed. All three share the same on-wire layout (`__positions__` uint8 blob + `__values__` tensor + per-param manifest); decoder dispatches on the metadata. diff --git a/docs/en/advanced/external-rollout-engines.md b/docs/en/advanced/external-rollout-engines.md index 498afa0c5f..a48b3a0eb9 100644 --- a/docs/en/advanced/external-rollout-engines.md +++ b/docs/en/advanced/external-rollout-engines.md @@ -2,13 +2,15 @@ An external rollout engine is an SGLang engine that is not launched by the slime training job. Another system deploys and owns the engine lifecycle; slime connects to those engines during training, registers a router, and syncs updated actor weights when needed. -This page is a roadmap. Use it to decide when to use `--rollout-external-engine-addrs`, when to stay with `--sglang-config`, and which weight-update path to pick for external deployments. +This page is a roadmap. Use it to decide when to use `--rollout-external-engine-addrs`, when to use `--rollout-http-endpoint-url`, when to stay with `--sglang-config`, and which weight-update path to pick for external deployments. ## Where To Start | Goal | Recommended entry point | | :--- | :--- | | Engines are already launched externally and slime should only connect for rollout | `--rollout-external-engine-addrs` | +| Rollout serving is an elastic fleet behind a single HTTP URL, with no stable per-engine handles | `--rollout-http-endpoint-url` | +| The serving side pulls published weight versions instead of receiving direct update RPCs | `--update-weight-delta-publish-only`, see [Publish-Only Disk Delta](delta-weight-sync.md#publish-only-disk-delta) | | slime should still launch engines, but you need PD disaggregation, multi-model serving, heterogeneous server groups, or per-group overrides | [SGLang Config](sglang-config.md) | | Trainer and external engines can form an NCCL group | Default `--update-weight-mode full --update-weight-transport nccl` | | Trainer and external engines cannot form an NCCL group, but can see the same filesystem path | `--update-weight-mode full --update-weight-transport disk` | @@ -38,6 +40,27 @@ slime queries each engine's `/server_info` or `/get_server_info` endpoint and in This path fits deployments where serving is owned outside the training job: a separate inference cluster, a separate Ray cluster, manually warmed SGLang engines, or a rollout service managed by another orchestrator. +## Opaque HTTP Rollout Endpoint + +`--rollout-external-engine-addrs` still assumes SGLang engines with stable addresses: slime queries `/server_info` per engine, registers each one with a router, and pushes weight updates to known engine handles. Some deployments cannot offer that contract — for example a serverless or autoscaled inference fleet behind one URL, where workers come and go and no worker-management API is exposed. For those, point slime at the endpoint directly: + +```bash +python train.py \ + --rollout-http-endpoint-url https://rollout.example.com \ + ... +``` + +In this mode slime launches no engines and no router, and assumes nothing about the endpoint beyond the generation route: rollout requests POST to `{url}/generate`, and `get_model_url(args, ...)` in custom rollout functions resolves to the endpoint as well. No rollout GPUs are allocated in the placement group, `/server_info` is never queried, and slime fault tolerance does not manage the fleet — recovery is the endpoint operator's job. `--rollout-http-endpoint-url` and `--rollout-external-engine-addrs` are mutually exclusive. + +Two companion flags adapt the default SGLang rollout to an endpoint that lacks router APIs: + +- `--rollout-http-endpoint-abort-strategy {cancel-only,router-workers}`: how `abort` behaves between rollouts. `cancel-only` (the default when an endpoint URL is set) cancels slime's local pending generation tasks without calling the router's worker-list or per-worker abort APIs. `router-workers` keeps the existing router-based abort and remains the default otherwise. Note that `cancel-only` does not collect partial samples, so it does not compose with `--partial-rollout`. +- `--custom-rollout-request-hook-path`: optional hook called before each default SGLang `/generate` request. Signature: `def hook(args, sample, request) -> None | dict`. The `request` dict contains `url`, `payload`, `headers`, `max_retries`, `retry_sleep`, `rollout_id`, and `evaluation`; mutate it in place or return a dict of updates. + +Use the request hook for rollout-endpoint admission control. For example, a hook may attach `"weight_version": {"exact_version": }` or `"weight_version": {"min_required_version": }` and increase `max_retries`/`retry_sleep`. Those request fields avoid wasted rollout compute when an opaque router sends the request to a replica that has not loaded a usable version yet. They do not define SLIME's off-policy or staleness semantics; the trainer schedule and loss/correction path still decide which versions are valid. + +For weight sync, an elastic fleet usually cannot receive per-engine `update_weights_from_disk` RPCs either. Combine the endpoint with publish-only delta sync, where the trainer publishes each complete weight version through a custom hook and the serving side consumes it on its own schedule — see [Publish-Only Disk Delta](delta-weight-sync.md#publish-only-disk-delta). If request-level minimum-version retry is enough, leave publish-only in its default pipelined mode. If the publish hook polls rollout-fleet status and you want the next rollout dispatch to wait for that readiness threshold, set `--update-weight-delta-publish-wait=sync`. + ## Relationship With `--sglang-config` `--rollout-external-engine-addrs` and `--sglang-config` are mutually exclusive because they own different boundaries: @@ -108,8 +131,9 @@ For encoding choices, wire layout, receiver-side selective overwrite, and tuning - External engines can use an independent SGLang environment; they do not need the slime or Megatron training environment. - Disk transport supports different GPU models or vendors between training and rollout, as long as SGLang supports the target hardware and model format. - Disk transport requires trainer and SGLang engines to see the same `--update-weight-disk-dir` path; a path visible only to the trainer is not enough. -- External engines are not recovered by slime fault tolerance; their lifecycle belongs to the external deployment system. -- `--sglang-config` and `--rollout-external-engine-addrs` are mutually exclusive. +- External engines are not recovered by slime fault tolerance; their lifecycle belongs to the external deployment system. The same applies to fleets behind `--rollout-http-endpoint-url`. +- `--sglang-config` and `--rollout-external-engine-addrs` are mutually exclusive, as are `--rollout-external-engine-addrs` and `--rollout-http-endpoint-url`. +- An opaque HTTP endpoint only needs to serve the generation route; worker-management APIs are never called. If the fleet cannot accept direct weight-update RPCs, use publish-only delta sync. - Delta mode does not support `--colocate`, because colocated sync uses CUDA IPC handles and delta encoding does not reduce the actual transfer. ## Related Work diff --git a/docs/en/developer_guide/ci.md b/docs/en/developer_guide/ci.md index ef06c8f4ea..8a44cff00d 100644 --- a/docs/en/developer_guide/ci.md +++ b/docs/en/developer_guide/ci.md @@ -49,12 +49,11 @@ The changed-test job itself runs through the self-hosted Docker path. When `NUM_ |---|---|---|---| | Automatic | `cpu-unittest` | CPU | Always-on unit and contract tests for argument validation, schedules, rewards, samples, rollout validation, checkpoint utilities, and plugin contracts. | | Automatic | `agent-adapter-test` | CPU | Always-on agent adapter tests with optional provider SDK dependencies. | -| `run-ci-short` | `e2e-test-short` | GPU | Lightweight smoke tests with small Qwen models. Fast GPU feedback loop. | | `run-ci-sglang-config` | `e2e-test-sglang-config` | GPU | SGLang config tests for advanced rollout engine deployment and mixed/offload scenarios. | | `run-ci-megatron` | `e2e-test-megatron` | GPU | Core Megatron training tests covering dense, MoE, PPO, MTP, OPD, async rollout, PD/Mooncake, and debug replay paths. | | `run-ci-precision` | `e2e-test-precision` | GPU | Numerical precision validation and parallel consistency checks. | | `run-ci-ckpt` | `e2e-test-ckpt` | GPU | Checkpoint save/load correctness, including CPU/GPU optimizer states and async save. | -| `run-ci-image` | `e2e-test-image` | GPU | Broad image validation suite on `slimerl/slime-test:latest`. | +| `run-ci-image` | `e2e-test-image` | GPU | Runs the `run-ci-megatron` matrix on `slimerl/slime-test:latest`. | | `run-ci-changed` | `e2e-test-changed` | Mixed | Runs only changed tests, using each file's `NUM_GPUS` value. | `workflow_dispatch` can be used from the Actions page for manual validation. It runs the registered jobs according to the workflow conditions. @@ -86,12 +85,11 @@ python -m pytest tests/test_megatron_argument_validation.py tests/plugin_contrac GPU e2e tests validate the integrated training/rollout behavior that CPU tests cannot cover: -- `run-ci-short`: small-model smoke coverage for quick GPU feedback. - `run-ci-sglang-config`: advanced SGLang deployment paths, including config-based engine layouts. - `run-ci-megatron`: main Megatron backend coverage for dense/MoE recipes, async rollout, OPD, PPO-style paths, PD/Mooncake, and debug rollout-then-train replay. - `run-ci-precision`: numerical consistency across parallel settings. - `run-ci-ckpt`: checkpoint save/load combinations and async save. -- `run-ci-image`: broad validation of the release/test image. +- `run-ci-image`: the same matrix as `run-ci-megatron`, but on the release/test image. Use targeted labels for routine PRs. Use `run-ci-image` sparingly because it consumes significantly more GPU time. diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index 77f5cd5e34..389cda02b6 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -28,6 +28,7 @@ Below is a summary of all available customization interfaces and their purposes. | [`--custom-megatron-init-path`](#17-megatron-hooks) | Custom initialization after Megatron setup. | | [`--custom-megatron-before-log-prob-hook-path`](#17-megatron-hooks) | Custom logic before log probability computation. | | [`--custom-megatron-before-train-step-hook-path`](#17-megatron-hooks) | Custom logic before each training step. | +| [`--custom-rollout-request-hook-path`](#19-rollout-request-hook---custom-rollout-request-hook-path) | Customize each default SGLang `/generate` request before dispatch. | ## Agentic workflows through customization interfaces @@ -457,6 +458,25 @@ Stabilize MoE RL training by recording and replaying expert routing decisions to | `--use-routing-replay` | Forward-backward routing consistency in training. ([arXiv:2507.18071](https://arxiv.org/abs/2507.18071)) | | `--use-rollout-routing-replay` | R3: Replay routing from rollout during training. Supported by slime's default `sglang_router` path. ([arXiv:2510.11370](https://arxiv.org/abs/2510.11370)) | +--- + +### 19. Rollout Request Hook (`--custom-rollout-request-hook-path`) + +**Signature**: +```python +def hook(args, sample, request) -> None | dict +``` + +**Purpose**: Customize each default SGLang rollout `/generate` request before it +is sent. `request` contains `url`, `payload`, `headers`, `max_retries`, +`retry_sleep`, `rollout_id`, and `evaluation`. Mutate it in place or return a +dict of updates. + +This hook is useful for external rollout providers that need request-level +admission control, for example adding `payload["weight_version"]` so a request +routed to a lagging replica fails and retries before doing unusable rollout +compute. + ## Testing Custom Function Paths slime also provides CPU-only contract tests for customization interfaces. These tests resolve components through import-path strings, so they can validate both built-in hooks and user-defined implementations passed through the same CLI arguments used by training. @@ -470,7 +490,7 @@ The tests live under `tests/plugin_contracts/` and are grouped by hook shape: - `tests/plugin_contracts/test_plugin_path_loading_contracts.py` Covers `--eval-function-path`, `--custom-rm-path`, `--dynamic-sampling-filter-path`, `--buffer-filter-path`, `--data-source-path`, `--rollout-sample-filter-path`, and `--rollout-all-samples-process-path` - `tests/plugin_contracts/test_plugin_runtime_hook_contracts.py` - Covers `--custom-rollout-log-function-path`, `--custom-eval-rollout-log-function-path`, `--custom-reward-post-process-path`, `--custom-convert-samples-to-train-data-path`, and `--rollout-data-postprocess-path` + Covers `--custom-rollout-log-function-path`, `--custom-eval-rollout-log-function-path`, `--custom-reward-post-process-path`, `--custom-convert-samples-to-train-data-path`, `--rollout-data-postprocess-path`, and `--custom-rollout-request-hook-path` Run all customization contract tests locally: diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index 7f3a312e66..70132bbdb1 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -306,7 +306,7 @@ SGLANG_ARGS=( ### Colocated Actor and Rollout -Under the default configuration, training (Actor) and inference (Rollout) resources are specified separately. Ray allocates `actor_num_nodes * actor_num_gpus_per_node` GPUs to the training part and `rollout_num_gpus` GPUs to inference, that is, training and inference are separated. +Under the default configuration, training (Actor) and inference (Rollout) resources are specified separately. Ray allocates `actor_num_nodes * actor_num_gpus_per_node` GPUs to the training part and `rollout_num_gpus` GPUs to inference, that is, training and inference are separated. When `--rollout-num-gpus` is explicitly set to `0`, slime still parses SGLang arguments and launches the router, but does not launch local SGLang servers. **Standard (Disaggregated) Configuration**: ```bash @@ -320,7 +320,7 @@ ray job submit ... \ In the above configuration, Actor uses 4 cards, and Rollout also uses 4 cards, running in parallel. **Training-Inference Integration (Colocated) Configuration**: -To deploy training and inference on the same group of GPUs, please add the `--colocate` parameter. After enabling, `--rollout-num-gpus` will be ignored to make the number of cards for training and inference equal. +To deploy training and inference on the same group of GPUs, please add the `--colocate` parameter. By default, this makes the number of cards for training and inference equal. You can explicitly set a different positive `--rollout-num-gpus`, for example to use more rollout GPUs than actor GPUs; the extra GPUs are used as rollout-only resources. If `--rollout-num-gpus 0` is set explicitly, slime launches only the router and no local SGLang servers. ```bash ray job submit ... \ diff --git a/docs/en/get_started/usage.md b/docs/en/get_started/usage.md index 37b43280c6..5deeda7980 100644 --- a/docs/en/get_started/usage.md +++ b/docs/en/get_started/usage.md @@ -18,7 +18,7 @@ There are four main parameters for cluster resource allocation: - `--actor-num-nodes`: The number of nodes required for RL actor training. - `--actor-num-gpus-per-node`: The number of GPUs per node for RL actor training. - - `--rollout-num-gpus`: The total number of GPUs required for rollout (inference). + - `--rollout-num-gpus`: The total number of GPUs required for rollout (inference). Set it to `0` to still parse SGLang arguments and launch the router without launching local SGLang servers. - `--rollout-num-gpus-per-engine`: The number of GPUs per inference engine. This parameter is similar to SGLang's `tp_size`. When performing multi-node serving, this value should be the total number of GPUs. For example, if serving one model with 2 nodes and 16 GPUs, this value should be 16. The reason for not using a parameter like `--sglang-tp-size` is that we might consider supporting SGLang's `dp_size` parameter in the future, which means an engine could contain multiple SGLang servers (currently, only `--sglang-dp-size` under the `--sglang-enable-dp-attention` condition is supported). @@ -26,7 +26,7 @@ With the default configuration, we use these parameters to allocate `actor_num_n For co-located training and inference, you also need to configure: - - `--colocate`: Enables co-located training and inference. When enabled, it ignores `--rollout-num-gpus` and makes the number of GPUs for training and inference equal. + - `--colocate`: Enables co-located training and inference. By default, this makes the number of GPUs for training and inference equal. You can explicitly set a different positive `--rollout-num-gpus`, for example to use more rollout GPUs than actor GPUs; the extra GPUs are used as rollout-only resources. If `--rollout-num-gpus 0` is set explicitly, slime launches only the router and no local SGLang servers. Additionally, slime supports Prefill and Decode disaggregation (PD Disaggregation). You can set the number of servers used for Prefill by setting the `--prefill-num-servers` argument. @@ -184,6 +184,7 @@ Additionally, we provide a `metadata_key`, which defaults to `"metadata"`. When - `--advantage-estimator`: Specifies the RL algorithm for the training process. Currently supported algorithms include: - `grpo` ([https://arxiv.org/abs/2402.03300](https://arxiv.org/abs/2402.03300)) - `gspo` ([https://arxiv.org/abs/2507.18071](https://arxiv.org/abs/2507.18071)) + - `cispo` ([https://arxiv.org/abs/2506.13585](https://arxiv.org/abs/2506.13585)) - `reinforce_plus_plus` and `reinforce_plus_plus_baseline` ([https://arxiv.org/abs/2501.03262](https://arxiv.org/abs/2501.03262)) - `ppo` ([https://arxiv.org/abs/1707.06347](https://arxiv.org/abs/1707.06347)) diff --git a/docs/zh/advanced/delta-weight-sync.md b/docs/zh/advanced/delta-weight-sync.md index f009dc954a..ad17ac23f9 100644 --- a/docs/zh/advanced/delta-weight-sync.md +++ b/docs/zh/advanced/delta-weight-sync.md @@ -4,6 +4,7 @@ - [快速开始](#快速开始) - [同步模式与传输方式](#同步模式与传输方式) - [工作原理](#工作原理) +- [Publish-Only 磁盘 Delta](#publish-only-磁盘-delta) - [编码选择](#编码选择) - [为何不支持 colocated](#为何不支持-colocated) @@ -88,6 +89,35 @@ Delta NCCL 和 delta 磁盘共用同一条发送管线、同一种 wire 布局 选择性覆写没有任何算术运算 —— 接收端在变化位置直接写入训练端的精确字节 —— 因此天然无损,也不存在数值漂移问题,无需周期性 base 同步。 +## Publish-Only 磁盘 Delta + +上面的磁盘路径把每个版本推送给已知 engine:rank 0 调用每个 engine 的 `update_weights_from_disk(load_format="delta")`,所有 engine 确认后同步才结束。这要求 engine 句柄稳定。当 serving 侧是一个按自己节奏消费已发布版本的弹性集群——例如位于 [opaque HTTP rollout endpoint](external-rollout-engines.md#opaque-http-rollout-endpoint) 之后——可以用 publish-only 模式反转方向: + +```bash +--update-weight-mode delta +--update-weight-transport disk +--update-weight-delta-publish-only +--custom-delta-publish-path my_pkg.publish.publish_delta +--update-weight-delta-keep-files +``` + +rank 0 不再发出 per-engine RPC,而是在每次同步中调用一次你的 publish hook——此时所有 delta 文件已经写完,可选的 `--custom-delta-pre-push-path` hook 也已提交: + +```python +def publish_delta(args, version_dir: str, files: list[str], weight_version: str, rollout_engines) -> list | None: + ... # 例如把 version_dir 上传到对象存储,然后公告 weight_version +``` + +返回的 Ray ObjectRef 会在该版本视为完成之前被等待。与直接磁盘路径的行为差异: + +- **每次同步发布一个完整版本。** 直接磁盘传输在每个 pass 边界发布,让接收端的 apply 与后续编码重叠;publish-only 把所有发布推迟到 finalize,外部消费者永远不会看到只发布了一半的版本。 +- **发布等待可配置。** 默认 `--update-weight-delta-publish-wait=next-sync` 会让已派发的 publish 在下一个训练 step 期间保持 in flight,并在下一次同步开始时(或 disconnect 时)结算。因此 publish 失败会晚一个同步周期才在 rank 0 上暴露。如果 publish hook 会轮询外部 rollout 集群、并且希望下一次 rollout dispatch 等到足够副本就绪后再开始,可以设置 `--update-weight-delta-publish-wait=sync`。 +- **不打扰 engine。** 不暂停生成、不清空 cache、不发出任何 update RPC;消费者自己决定何时拉取新版本。如果 rollout endpoint 支持请求级权重约束,可以在 `--custom-rollout-request-hook-path` hook 中附加这些约束,让路由到落后副本的请求尽早失败并重试,避免生成不可用样本。 +- **不做清理。** slime 无法知道消费者何时读完一个版本,所以必须加 `--update-weight-delta-keep-files`,版本目录的生命周期由你负责(例如 publish hook 可以在上传完成后清理旧版本)。 +- **空 delta 也会发布。** 如果某次同步没有任何字节变化,hook 仍会以空文件列表被调用,让消费者的版本计数得以推进。 + +`--update-weight-delta-root` 可选地指定发布侧元数据的根目录;缺省为 `--update-weight-disk-dir` 的父目录,并通过 `args` 透传给 hook。 + ## 编码选择 `--update-weight-encoding` 决定位置如何打包。三种编码共用同一种 wire 布局(`__positions__` uint8 块 + `__values__` 张量 + per-param manifest),解码端根据 metadata 分派。 diff --git a/docs/zh/advanced/external-rollout-engines.md b/docs/zh/advanced/external-rollout-engines.md index 9aae0ef5ec..46dd999225 100644 --- a/docs/zh/advanced/external-rollout-engines.md +++ b/docs/zh/advanced/external-rollout-engines.md @@ -2,13 +2,15 @@ External rollout engine 指的是:SGLang engine 不由 slime 训练任务启动,而是由外部系统预先部署和管理;slime 只在训练时连接这些 engine,注册 router,并在需要时同步训练后的 actor 权重。 -这篇文档是一个导航页。它帮助你判断什么时候该用 `--rollout-external-engine-addrs`,什么时候该继续使用 `--sglang-config`,以及 external 场景下该选择 full checkpoint update from disk 还是 delta update。 +这篇文档是一个导航页。它帮助你判断什么时候该用 `--rollout-external-engine-addrs`,什么时候该用 `--rollout-http-endpoint-url`,什么时候该继续使用 `--sglang-config`,以及 external 场景下该选择 full checkpoint update from disk 还是 delta update。 ## 从哪里开始 | 目标 | 推荐入口 | | :--- | :--- | | engine 已经由外部系统启动,只想让 slime 连接并做 rollout | `--rollout-external-engine-addrs` | +| rollout serving 是单一 HTTP URL 背后的弹性集群,没有稳定的 per-engine 句柄 | `--rollout-http-endpoint-url` | +| serving 侧主动拉取发布的权重版本,而不是接收直接的 update RPC | `--update-weight-delta-publish-only`,见 [Publish-Only 磁盘 Delta](delta-weight-sync.md#publish-only-磁盘-delta) | | engine 仍由 slime 启动,但需要 PD 分离、多模型、异构 server group 或 per-group overrides | [SGLang Config](sglang-config.md) | | 训练器和 external engine 可以建立 NCCL group | 默认的 `--update-weight-mode full --update-weight-transport nccl` | | 训练器和 external engine 不能建立 NCCL group,但能共享同一路径的文件系统 | `--update-weight-mode full --update-weight-transport disk` | @@ -38,6 +40,27 @@ slime 会请求每个 engine 的 `/server_info` 或 `/get_server_info`,推断 这条路径适合 serving 生命周期由训练任务外部管理的部署:例如独立的推理集群、跨 Ray 集群部署、手工预热的 SGLang engine,或由其他编排系统管理的 rollout service。 +## Opaque HTTP Rollout Endpoint + +`--rollout-external-engine-addrs` 仍然假设 SGLang engine 有稳定地址:slime 会逐个查询 `/server_info`,把每个 engine 注册到 router,并向已知 engine 句柄推送权重更新。有些部署无法提供这种契约——例如单一 URL 背后的 serverless 或自动扩缩容推理集群,worker 随时增减,也不暴露任何 worker 管理 API。这种情况下让 slime 直接指向 endpoint: + +```bash +python train.py \ + --rollout-http-endpoint-url https://rollout.example.com \ + ... +``` + +在这个模式下,slime 不启动任何 engine 和 router,对 endpoint 的假设只有生成路由:rollout 请求 POST 到 `{url}/generate`,自定义 rollout function 里的 `get_model_url(args, ...)` 也解析到该 endpoint。placement group 中不会分配 rollout GPU,`/server_info` 永远不会被查询,slime 的 fault tolerance 也不管理这个集群——故障恢复由 endpoint 运营方负责。`--rollout-http-endpoint-url` 与 `--rollout-external-engine-addrs` 互斥。 + +两个配套参数让默认 SGLang rollout 适配没有 router API 的 endpoint: + +- `--rollout-http-endpoint-abort-strategy {cancel-only,router-workers}`:控制两次 rollout 之间 `abort` 的行为。`cancel-only`(设置 endpoint URL 时的默认值)只取消 slime 本地待完成的生成任务,不调用 router 的 worker 列表或 per-worker abort API。`router-workers` 保留原有基于 router 的 abort,在其他情况下仍是默认值。注意 `cancel-only` 不收集 partial sample,因此与 `--partial-rollout` 不兼容。 +- `--custom-rollout-request-hook-path`:可选 hook,在默认 SGLang `/generate` 请求发出前调用。签名为 `def hook(args, sample, request) -> None | dict`。`request` dict 包含 `url`、`payload`、`headers`、`max_retries`、`retry_sleep`、`rollout_id` 和 `evaluation`;可以原地修改,也可以返回一个 dict 覆盖字段。 + +请求级权重约束应通过这个 hook 添加。例如 hook 可以加入 `"weight_version": {"exact_version": }` 或 `"weight_version": {"min_required_version": }`,并调整 `max_retries`/`retry_sleep`。这些字段用于 opaque router 把请求路由到落后副本时尽早失败并重试,避免浪费 rollout compute;它们不定义 SLIME 的 off-policy 或 staleness 语义,真正的有效版本仍由训练调度和 loss/correction 路径决定。 + +至于权重同步,弹性集群通常也无法接收 per-engine 的 `update_weights_from_disk` RPC。可以把 endpoint 与 publish-only delta 同步组合使用:训练端通过自定义 hook 发布每个完整的权重版本,serving 侧按自己的节奏消费——见 [Publish-Only 磁盘 Delta](delta-weight-sync.md#publish-only-磁盘-delta)。如果请求级最低版本重试已经足够,保留 publish-only 的默认流水线模式即可;如果 publish hook 会轮询 rollout 集群状态、并且你希望下一次 rollout dispatch 等待该就绪阈值,可以设置 `--update-weight-delta-publish-wait=sync`。 + ## 与 `--sglang-config` 的关系 `--rollout-external-engine-addrs` 和 `--sglang-config` 互斥,因为它们负责不同的边界: @@ -108,8 +131,9 @@ delta update 面向大模型、跨集群或跨数据中心训推解耦。它不 - external engine 可以使用独立 SGLang 环境;不需要安装 slime 或 Megatron 训练环境。 - disk transport 支持训练和 rollout 使用不同型号或不同厂家的 GPU,前提是 SGLang 支持对应硬件和模型格式。 - disk transport 要求训练端和 SGLang engine 看到同一个 `--update-weight-disk-dir` 路径;路径只在训练端可见是不够的。 -- external engine 当前不支持 slime 的 fault tolerance 恢复流程;engine 生命周期由外部系统负责。 -- `--sglang-config` 与 `--rollout-external-engine-addrs` 互斥。 +- external engine 当前不支持 slime 的 fault tolerance 恢复流程;engine 生命周期由外部系统负责。`--rollout-http-endpoint-url` 背后的集群同理。 +- `--sglang-config` 与 `--rollout-external-engine-addrs` 互斥;`--rollout-external-engine-addrs` 与 `--rollout-http-endpoint-url` 也互斥。 +- opaque HTTP endpoint 只需要提供生成路由;slime 不会调用任何 worker 管理 API。如果集群无法接收直接的权重更新 RPC,请使用 publish-only delta 同步。 - delta mode 不支持 `--colocate`,因为 colocated 权重同步通过 CUDA IPC 传句柄,delta 编码不会节省实际传输量。 ## 参考工作 diff --git a/docs/zh/developer_guide/ci.md b/docs/zh/developer_guide/ci.md index 44db10c0b4..d36f69c173 100644 --- a/docs/zh/developer_guide/ci.md +++ b/docs/zh/developer_guide/ci.md @@ -49,12 +49,11 @@ changed-test job 本身走 self-hosted Docker 路径。当 `NUM_GPUS = 0` 时, |---|---|---|---| | 自动运行 | `cpu-unittest` | CPU | 默认运行的 unit/contract tests,覆盖 argument validation、schedule、reward、sample、rollout validation、checkpoint utilities 和 plugin contracts。 | | 自动运行 | `agent-adapter-test` | CPU | 默认运行的 agent adapter tests,包含额外 provider SDK 依赖。 | -| `run-ci-short` | `e2e-test-short` | GPU | 小模型轻量级 smoke tests,用于快速 GPU 反馈。 | | `run-ci-sglang-config` | `e2e-test-sglang-config` | GPU | SGLang config 测试,覆盖高级 rollout engine deployment 和 mixed/offload 场景。 | | `run-ci-megatron` | `e2e-test-megatron` | GPU | 核心 Megatron 训练测试,覆盖 dense、MoE、PPO、MTP、OPD、async rollout、PD/Mooncake 和 debug replay 路径。 | | `run-ci-precision` | `e2e-test-precision` | GPU | 数值精度和并行一致性检查。 | | `run-ci-ckpt` | `e2e-test-ckpt` | GPU | Checkpoint save/load 正确性,包括 CPU/GPU optimizer state 和 async save。 | -| `run-ci-image` | `e2e-test-image` | GPU | 在 `slimerl/slime-test:latest` 上运行更完整的镜像验证套件。 | +| `run-ci-image` | `e2e-test-image` | GPU | 在 `slimerl/slime-test:latest` 上运行与 `run-ci-megatron` 相同的 matrix。 | | `run-ci-changed` | `e2e-test-changed` | Mixed | 只运行 changed tests,并使用每个文件中的 `NUM_GPUS`。 | 也可以在 Actions 页面通过 `workflow_dispatch` 手动验证;它会按照 workflow 条件运行注册的 jobs。 @@ -86,12 +85,11 @@ python -m pytest tests/test_megatron_argument_validation.py tests/plugin_contrac GPU e2e tests 验证 CPU tests 无法覆盖的集成训练/rollout 行为: -- `run-ci-short`:小模型 smoke coverage,用于快速 GPU 反馈。 - `run-ci-sglang-config`:高级 SGLang deployment path,包括 config-based engine layouts。 - `run-ci-megatron`:主要 Megatron backend coverage,包括 dense/MoE recipe、async rollout、OPD、PPO-style path、PD/Mooncake 和 debug rollout-then-train replay。 - `run-ci-precision`:不同并行设置下的数值一致性。 - `run-ci-ckpt`:checkpoint save/load 组合和 async save。 -- `run-ci-image`:release/test image 的较完整验证。 +- `run-ci-image`:与 `run-ci-megatron` 相同的 matrix,但运行在 release/test image 上。 日常 PR 优先使用 targeted labels。`run-ci-image` 消耗 GPU 时间较多,应谨慎使用。 diff --git a/docs/zh/get_started/customization.md b/docs/zh/get_started/customization.md index 5b95f05463..eb92095c08 100644 --- a/docs/zh/get_started/customization.md +++ b/docs/zh/get_started/customization.md @@ -28,6 +28,7 @@ slime 通过函数路径参数提供了广泛的自定义能力。这些参数 | [`--custom-megatron-init-path`](#17-megatron-hook) | Megatron 设置后的自定义初始化。 | | [`--custom-megatron-before-log-prob-hook-path`](#17-megatron-hook) | log probability 计算前的自定义逻辑。 | | [`--custom-megatron-before-train-step-hook-path`](#17-megatron-hook) | 每个训练步骤前的自定义逻辑。 | +| [`--custom-rollout-request-hook-path`](#19-rollout-request-hook---custom-rollout-request-hook-path) | 在默认 SGLang `/generate` 请求发出前自定义请求。 | ## 通过 customization 接口实现 agentic workflow @@ -459,6 +460,22 @@ def custom_hook(args, rollout_id, step_id, model, optimizer, opt_param_scheduler | `--use-routing-replay` | 训练中前向-反向路由一致性。([arXiv:2507.18071](https://arxiv.org/abs/2507.18071)) | | `--use-rollout-routing-replay` | R3:在训练时重放 rollout 阶段的路由。slime 默认的 `sglang_router` 路径支持该功能。([arXiv:2510.11370](https://arxiv.org/abs/2510.11370)) | +--- + +### 19. Rollout Request Hook (`--custom-rollout-request-hook-path`) + +**函数签名**: +```python +def hook(args, sample, request) -> None | dict +``` + +**用途**: 在默认 SGLang rollout `/generate` 请求发出前自定义该请求。`request` +包含 `url`、`payload`、`headers`、`max_retries`、`retry_sleep`、`rollout_id` +和 `evaluation`。可以原地修改,也可以返回一个 dict 覆盖字段。 + +这个 hook 适合外部 rollout provider 的请求级 admission control,例如加入 +`payload["weight_version"]`,让路由到落后副本的请求在生成不可用样本前失败并重试。 + ## 自定义函数路径的测试 slime 现在也提供了一组 CPU 契约测试,用于校验这些 customization 接口。测试会通过字符串形式的导入路径来动态加载组件,因此既能回归仓库内置 hook,也能验证用户通过和训练时完全相同的 CLI 参数传入的自定义实现。 @@ -472,7 +489,7 @@ slime 现在也提供了一组 CPU 契约测试,用于校验这些 customizati - `tests/plugin_contracts/test_plugin_path_loading_contracts.py` 覆盖 `--eval-function-path`、`--custom-rm-path`、`--dynamic-sampling-filter-path`、`--buffer-filter-path`、`--data-source-path`、`--rollout-sample-filter-path`、`--rollout-all-samples-process-path` - `tests/plugin_contracts/test_plugin_runtime_hook_contracts.py` - 覆盖 `--custom-rollout-log-function-path`、`--custom-eval-rollout-log-function-path`、`--custom-reward-post-process-path`、`--custom-convert-samples-to-train-data-path`、`--rollout-data-postprocess-path` + 覆盖 `--custom-rollout-log-function-path`、`--custom-eval-rollout-log-function-path`、`--custom-reward-post-process-path`、`--custom-convert-samples-to-train-data-path`、`--rollout-data-postprocess-path`、`--custom-rollout-request-hook-path` 本地运行全部 customization 契约测试: diff --git a/docs/zh/get_started/quick_start.md b/docs/zh/get_started/quick_start.md index 8af226801f..1daa0a502a 100644 --- a/docs/zh/get_started/quick_start.md +++ b/docs/zh/get_started/quick_start.md @@ -305,7 +305,7 @@ SGLANG_ARGS=( ### Colocated Actor and Rollout -在默认的配置下,训练(Actor)和推理(Rollout)的资源是分开指定的,通过 ray 给训练部分分配 `actor_num_nodes * actor_num_gpus_per_node` 张 GPU,给推理分配 `rollout_num_gpus` 张 GPU,也即训推分离。 +在默认的配置下,训练(Actor)和推理(Rollout)的资源是分开指定的,通过 ray 给训练部分分配 `actor_num_nodes * actor_num_gpus_per_node` 张 GPU,给推理分配 `rollout_num_gpus` 张 GPU,也即训推分离。将 `--rollout-num-gpus` 显式设置为 `0` 时,slime 仍会解析 SGLang 参数并启动 router,但不会启动本地 SGLang server。 **标准(分离)配置**: ```bash @@ -319,7 +319,7 @@ ray job submit ... \ 上述配置中,Actor 使用 4 张卡,Rollout 也使用 4 张卡,两者并行运行。 **训推一体化(Colocated)配置**: -要将训练和推理部署在同一组 GPU 上,请添加 `--colocate` 参数,开启后会忽略 `--rollout-num-gpus` 让训练和推理的卡数相等。 +要将训练和推理部署在同一组 GPU 上,请添加 `--colocate` 参数,开启后默认会让训练和推理的卡数相等;也可以显式设置一个不同的正数,例如让 rollout 卡数多于 actor,多出的 GPU 会作为 rollout-only 资源使用。如果显式设置 `--rollout-num-gpus 0`,则只启动 router,不启动本地 SGLang server。 ```bash diff --git a/docs/zh/get_started/usage.md b/docs/zh/get_started/usage.md index faa836bec6..a501f94634 100644 --- a/docs/zh/get_started/usage.md +++ b/docs/zh/get_started/usage.md @@ -19,7 +19,7 @@ - `--actor-num-gpus-per-node`:RL 的 actor 训练的每个节点有卡; -- `--rollout-num-gpus`:rollout (inference)一共需要多少卡; +- `--rollout-num-gpus`:rollout (inference)一共需要多少卡。设置为 `0` 时,slime 仍会解析 SGLang 参数并启动 router,但不会启动本地 SGLang server; - `--rollout-num-gpus-per-engine`:每个 inference engine 有多少卡,这个参数会比较像 sglang 的 `tp_size`,也就是在进行多机 serving 的时候,这个数值应该是总卡数,例如 2 机 16 卡 serving 一个模型,这里的值应该是 16。 @@ -29,7 +29,7 @@ 当需要训推一体的时候,还需要配置上: -- `--colocate`:开启训推一体。开启后会忽略 `--rollout-num-gpus` 让训练和推理的卡数相等。 +- `--colocate`:开启训推一体。开启后默认会让训练和推理的卡数相等;也可以显式设置一个不同的正数,例如让 rollout 卡数多于 actor,多出的 GPU 会作为 rollout-only 资源使用。如果显式设置 `--rollout-num-gpus 0`,则只启动 router,不启动本地 SGLang server。 此外,slime 支持 Prefill 和 Decode 的分离部署 (PD Disaggregation),可以通过设置 `--prefill-num-servers` 参数来指定用于 Prefill 的服务器数量。 @@ -188,6 +188,7 @@ sglang 的加载非常简单,只需要: - `--advantage-estimator`: 当前训练需要的 RL 算法,目前支持: - `grpo`(https://arxiv.org/abs/2402.03300); - `gspo`(https://arxiv.org/abs/2507.18071); + - `cispo`(https://arxiv.org/abs/2506.13585); - `reinforce_plus_plus` 与 `reinforce_plus_plus_baseline`(https://arxiv.org/abs/2501.03262); - `ppo`(https://arxiv.org/abs/1707.06347)。 diff --git a/examples/search-r1/generate_with_search.py b/examples/search-r1/generate_with_search.py index 2fb31feb16..16fc1fea04 100644 --- a/examples/search-r1/generate_with_search.py +++ b/examples/search-r1/generate_with_search.py @@ -157,6 +157,21 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: loss_mask = [] rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None + # BUGFIX: make the inference engine STOP at the tool/answer boundary. + # Without a stop, sglang keeps emitting tokens after / + # (junk, even fabricated new "Question:"s). The example only trimmed that junk + # via postprocess_responses when return_logprob=False; with return_logprob=True + # (TIS) trimming is disabled to keep token/logp aligned, so the junk stayed in + # the trajectory and got trained on (loss_mask=1) AND broke is_valid_sequence + # (trailing content after -> format invalid -> lower reward). + # Stopping at the tag avoids all of that and keeps token/logp aligned natively. + # slime already sets no_stop_trim=True, so the closing tag is kept in the output. + _stop_tags = ["", ""] + _existing_stop = sampling_params.get("stop") or [] + if isinstance(_existing_stop, str): + _existing_stop = [_existing_stop] + sampling_params = {**sampling_params, "stop": list(dict.fromkeys([*_existing_stop, *_stop_tags]))} + for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]): payload = { "text": prompt_text + response, diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 78056231df..0e68f8739b 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -5,7 +5,6 @@ from contextlib import nullcontext from pathlib import Path -import numpy as np import ray import torch import torch.distributed as dist @@ -229,29 +228,26 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: ) # TODO: this is ugly, move to somewhere else? # move tokens to GPU in advance + device = torch.cuda.current_device() rollout_data["tokens"] = [ - torch.tensor(t, dtype=torch.long, device=torch.cuda.current_device()) for t in rollout_data["tokens"] + t.to(device=device, dtype=torch.long, non_blocking=True) for t in rollout_data["tokens"] ] rollout_data["loss_masks"] = [ - torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"] + t.to(device=device, dtype=torch.int, non_blocking=True) for t in rollout_data["loss_masks"] ] if "rollout_mask_sums" in rollout_data: # Promote precomputed per-rollout mask totals to GPU tensors here # (matching loss_masks) so the loss reducer can just divide. - rollout_data["rollout_mask_sums"] = torch.tensor( - rollout_data["rollout_mask_sums"], dtype=torch.float32, device=torch.cuda.current_device() + rollout_data["rollout_mask_sums"] = rollout_data["rollout_mask_sums"].to( + device=device, dtype=torch.float32, non_blocking=True ) if "multimodal_train_inputs" in rollout_data: # Move multimodal training tensors to GPU in advance rollout_data["multimodal_train_inputs"] = [ ( { - key: ( - torch.from_numpy(v.copy()).to(device=torch.cuda.current_device()) - if isinstance(v, np.ndarray) - else v.to(device=torch.cuda.current_device()) - ) - for key, v in mm_dict.items() + key: value.to(device=device, non_blocking=True) if isinstance(value, torch.Tensor) else value + for key, value in mm_dict.items() } if mm_dict is not None else None @@ -273,16 +269,16 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: if key not in rollout_data: continue rollout_data[key] = [ - torch.tensor( - slice_log_prob_with_cp( - log_prob, - total_length, - response_length, - self.args.qkv_format, - rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, - ), - device=torch.cuda.current_device(), + slice_log_prob_with_cp( + log_prob, + total_length, + response_length, + self.args.qkv_format, + rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, + ).to( + device=device, dtype=torch.float32, + non_blocking=True, ) for i, (log_prob, total_length, response_length) in enumerate( zip( @@ -293,10 +289,6 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: ) ) ] - if "rollout_routed_experts" in rollout_data: - rollout_data["rollout_routed_experts"] = [ - torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"] - ] return rollout_data def _switch_model(self, target_tag: str) -> None: @@ -617,13 +609,19 @@ def update_weights(self) -> None: ) reconnect_rollout_engines = self.args.offload_train and self.args.use_critic and not self.args.colocate + force_connect_rollout_target = getattr(self.args, "update_weight_delta_publish_only", False) + + if not rollout_engines and not reconnect_rollout_engines and not force_connect_rollout_target: + if dist.get_rank() == 0: + logger.info("No updatable SGLang engines are running; skip weight update.") + return if reconnect_rollout_engines: self.wake_up() elif self.args.offload_train: reload_process_groups() - if num_new_engines > 0 or reconnect_rollout_engines: + if num_new_engines > 0 or reconnect_rollout_engines or force_connect_rollout_target: self.weight_updater.connect_rollout_engines( rollout_engines, rollout_engine_lock, diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index a63fa159f0..a456939e74 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -13,6 +13,7 @@ from slime.utils.ppo_utils import ( calculate_log_probs_and_entropy, compute_approx_kl, + compute_cispo_loss, compute_gspo_kl, compute_opsm_mask, compute_policy_loss, @@ -576,10 +577,10 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) This function extracts rewards, log-probs, values, and masks from `rollout_data`, computes KL divergences, then applies the chosen advantage - estimator. Supported methods: "grpo", "gspo", "ppo", "reinforce_plus_plus", - and "reinforce_plus_plus_baseline". When `args.normalize_advantages` is - True, advantages are whitened across the data-parallel group using masked - statistics. + estimator. Supported methods: "grpo", "gspo", "cispo", "ppo", + "reinforce_plus_plus", and "reinforce_plus_plus_baseline". When + `args.normalize_advantages` is True, advantages are whitened across the + data-parallel group using masked statistics. Early returns if both `log_probs` and `values` are None (intermediate pipeline stages). @@ -632,7 +633,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) custom_adv_fn(args, rollout_data) advantages, returns = rollout_data["advantages"], rollout_data["returns"] - elif args.advantage_estimator in ["grpo", "gspo"]: + elif args.advantage_estimator in ["grpo", "gspo", "cispo"]: rewards = torch.tensor(rewards, dtype=torch.float32, device=kl[0].device) returns = get_grpo_returns(rewards, kl) # TODO: is the copy necessary? @@ -895,7 +896,10 @@ def policy_loss_function( log_probs = torch.cat(log_probs, dim=0) ppo_kl = old_log_probs - log_probs - pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) + if args.advantage_estimator == "cispo": + pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip, args.eps_clip_high) + else: + pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) if args.use_opsm: pg_loss = pg_loss * opsm_mask diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py index fbe24bbc1c..0ca1d25556 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py @@ -509,13 +509,18 @@ def __init__( self.writer: AsyncSafetensorsWriter | None = None self.delta_dir: str | None = None self._pre_push_hook: Callable | None = None - # Disk transport: each pass boundary publishes its accumulated files - # (the only globally-synced flush points, since ``_publish_batch`` - # contains collectives). ``_pre_push_hook`` may return a Future, in - # which case the receiver RPC is deferred behind it via - # ``_rpc_executor`` so the main encode thread continues immediately. - # ``_pending_publishes`` holds the resulting Future[list[ObjectRef]] - # on rank 0; ``_finalize_sync`` awaits them at end of sync. + self._publish_hook: Callable | None = None + self._publish_only = bool(getattr(args, "update_weight_delta_publish_only", False)) + self._publish_wait = getattr(args, "update_weight_delta_publish_wait", "next-sync") + # Direct disk transport publishes at each pass boundary so receiver + # apply can overlap later encoding. Publish-only disk transport emits + # one complete version at finalize time, so external consumers never + # observe a partially published version. ``_pre_push_hook`` may return + # a Future; ``_pending_publishes`` holds Future[list[ObjectRef]] values. + # Direct transport drains them at end of sync (resume + cleanup depend + # on them); publish-only transport leaves the last publish in flight + # across the training step and drains it at the start of the next sync, + # so at most one version's publish is ever outstanding. self._pending_files: list[str] = [] self._pending_publishes: list = [] self._published_any: bool = False @@ -531,6 +536,10 @@ def __init__( from slime.utils.misc import load_function self._pre_push_hook = load_function(args.custom_delta_pre_push_path) + if getattr(args, "custom_delta_publish_path", None): + from slime.utils.misc import load_function + + self._publish_hook = load_function(args.custom_delta_publish_path) def connect_rollout_engines( self, @@ -561,6 +570,8 @@ def connect_rollout_engines( self._group_name = f"slime-pp_{pp_rank}" def disconnect_rollout_engines(self) -> None: + # A queued publish holds engine handles; settle it before dropping them. + self._drain_pending_publishes() if self.transport == "nccl": super().disconnect_rollout_engines() @@ -584,14 +595,17 @@ def update_weights(self) -> None: if self._is_pp_src_rank: os.makedirs(self._version_dir, exist_ok=True) - if dist.get_rank() == 0: + if dist.get_rank() == 0 and not self._publish_only: ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) self.density_nnz = self.density_numel = self.wire_bytes = self._flush_idx = 0 self._pending_files.clear() - self._pending_publishes.clear() + # Publish-only mode leaves the previous sync's publish in flight across + # the training step; settle it (and surface any publish failure, one + # sync late) before producing the next version. + self._drain_pending_publishes() self._published_any = False if self.writer is not None: self.writer.reset_counters() @@ -649,13 +663,15 @@ def _send_weights(self, pbar: tqdm | None) -> None: def _flush_and_publish(self, bucket: DeltaBucket, pbar: tqdm | None) -> None: """ - End-of-sub-pass: drain the in-flight bucket, barrier all PP ranks, then - (disk-only) fire one publish RPC for everything since the last call. + End-of-sub-pass: drain the in-flight bucket and barrier all PP ranks. + Direct disk transport also fires one receiver RPC for everything since + the last call; publish-only transport waits until finalize so the hook + sees a complete version. """ if bucket.has_updates: self._flush_bucket(bucket, pbar) dist.barrier(group=get_gloo_group()) - if self.transport == "disk": + if self.transport == "disk" and not self._publish_only: self._publish_batch() def _pipeline_pass( @@ -746,13 +762,15 @@ def _flush_bucket(self, bucket: DeltaBucket, pbar: tqdm | None) -> None: def _publish_batch(self) -> None: """ - Drain pending fsyncs, invoke the pre-push hook (may return a Future for an - async durability step on shared FS), then defer rank 0's - ``update_weights_from_disk`` RPC behind that Future via ``_rpc_executor``. - Each deferred dispatch lands in ``_pending_publishes`` as a - Future[list[ObjectRef]]; ``_finalize_sync`` awaits both layers. Safe to call - with empty ``_pending_files``: the all_gather still synchronizes and rank 0 - skips the dispatch when no rank produced files. + Drain pending fsyncs, invoke the pre-push hook (may return a Future for + an async durability step on shared FS), gather filenames, then defer + rank 0's publish/direct-update work behind that Future via + ``_rpc_executor``. Each deferred dispatch lands in + ``_pending_publishes`` as a Future[list[ObjectRef]]; direct disk + transport awaits both layers in ``_finalize_sync``, publish-only at the + start of the next sync. Safe to call with empty ``_pending_files``: + direct disk transport skips the dispatch, while publish-only still + calls the publish hook so a no-op version can be made visible. """ self.writer.drain() dist.barrier(group=get_gloo_group()) @@ -768,24 +786,31 @@ def _publish_batch(self) -> None: flat = [f for sub in all_files for f in sub] self._pending_files.clear() - if dist.get_rank() == 0 and flat: + if dist.get_rank() == 0 and (flat or self._publish_only): version_dir = self._version_dir engines = list(self.rollout_engines) weight_version = str(self.weight_version) self._published_any = True def _fire_when_committed() -> list: + refs = [] if commit_future is not None: commit_future.result() - return [ - engine.update_weights_from_disk.remote( - model_path=version_dir, - files=flat, - load_format="delta", - weight_version=weight_version, + if self._publish_hook is not None: + hook_refs = self._publish_hook(self.args, version_dir, flat, weight_version, engines) + if hook_refs is not None: + refs.extend(hook_refs) + if not self._publish_only: + refs.extend( + engine.update_weights_from_disk.remote( + model_path=version_dir, + files=flat, + load_format="delta", + weight_version=weight_version, + ) + for engine in engines ) - for engine in engines - ] + return refs self._pending_publishes.append(self._rpc_executor.submit(_fire_when_committed)) @@ -793,7 +818,14 @@ def _finalize_sync(self) -> None: """ Per-transport end-of-sync. NCCL: each flush already broadcasted; just resume. Disk: publish the trailing files, wait for all streamed applies to land, then - cleanup + resume. + cleanup + resume. Publish-only: dispatch the version's publish and return + without awaiting it by default — the hook runs concurrently with the + next training step and the start of the next sync settles it. With + ``--update-weight-delta-publish-wait=sync``, publish-only drains the + hook before returning so the next rollout dispatch starts only after + the hook's readiness contract has been satisfied. In both modes the + version dir must outlive the sync, which publish-only's no-cleanup + rule already ensures. """ if self.transport == "nccl": if dist.get_rank() == 0: @@ -801,15 +833,14 @@ def _finalize_sync(self) -> None: dist.barrier(group=get_gloo_group()) return - if self._pending_files: + if self._pending_files or self._publish_only: self._publish_batch() - if dist.get_rank() == 0: - # Each entry is a Future returning a list of ObjectRefs. Awaiting the - # Futures unblocks the (commit-then-RPC) chain; ray.get waits for the - # receivers' apply to finish. - object_refs = [ref for fut in self._pending_publishes for ref in fut.result()] - ray.get(object_refs) - self._pending_publishes.clear() + if dist.get_rank() == 0 and self._publish_only and self._publish_wait == "sync": + self._drain_pending_publishes() + if dist.get_rank() == 0 and not self._publish_only: + # Resume + cleanup must order after the receivers' apply, so drain + # the publish chain before either. + self._drain_pending_publishes() if not self._published_any: # No delta files needed publishing this sync (e.g. all-zero diff). # Engines never saw the new version via update_weights_from_disk, so @@ -821,6 +852,23 @@ def _finalize_sync(self) -> None: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) + def _drain_pending_publishes(self) -> None: + """ + Await every queued (commit-then-publish) Future, then the ObjectRefs it + returned. Awaiting the Futures unblocks the (commit-then-RPC) chain; + ray.get waits for the hook's work and the receivers' apply to finish. + Re-raises a failed publish here, on the draining rank (rank 0; the list + is empty elsewhere). In publish-only ``next-sync`` mode that is one + sync after the failing dispatch; in ``sync`` mode it is still inside the + publishing sync. + """ + if not self._pending_publishes: + return + object_refs = [ref for fut in self._pending_publishes for ref in fut.result()] + self._pending_publishes.clear() + if object_refs: + ray.get(object_refs) + def _record_metrics(self) -> None: """ Allreduce density/byte counters across PP-src ranks; stash on diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index 7499bb907d..d4d7867539 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -175,10 +175,11 @@ def external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: return [ExternalEngineInfo(**info) if isinstance(info, dict) else info for info in raw_infos] -def start_external_rollout_servers(args, *, start_router) -> dict[str, ExternalRolloutServer]: +def start_external_rollout_servers(args, *, start_router) -> tuple[dict[str, ExternalRolloutServer], list]: import ray from slime.backends.sglang_utils.sglang_engine import SGLangEngine + from slime.ray.utils import add_default_ray_env_vars infos = external_engine_infos_from_args(args) router_ip, router_port = start_router(args, has_pd_disaggregation=any(info.is_pd_worker for info in infos)) @@ -192,7 +193,11 @@ def start_external_rollout_servers(args, *, start_router) -> dict[str, ExternalR RolloutRayActor = ray.remote(SGLangEngine) gpu_offset = 0 for rank, info in enumerate(infos): - rollout_engine = RolloutRayActor.options(num_cpus=0.2, num_gpus=0).remote( + rollout_engine = RolloutRayActor.options( + num_cpus=0.2, + num_gpus=0, + runtime_env={"env_vars": add_default_ray_env_vars()}, + ).remote( args=args, rank=rank, worker_type=info.worker_type, @@ -211,11 +216,8 @@ def start_external_rollout_servers(args, *, start_router) -> dict[str, ExternalR ) ) - if init_handles: - ray.get(init_handles) - args.sglang_model_routers = {"default": (router_ip, router_port)} - return { + servers = { "default": ExternalRolloutServer( engines=engines, engine_gpu_counts=engine_gpu_counts, @@ -227,3 +229,4 @@ def start_external_rollout_servers(args, *, start_router) -> dict[str, ExternalR num_new_engines=len(engines), ) } + return servers, init_handles diff --git a/slime/backends/sglang_utils/http_endpoint.py b/slime/backends/sglang_utils/http_endpoint.py new file mode 100644 index 0000000000..e0b87d0a54 --- /dev/null +++ b/slime/backends/sglang_utils/http_endpoint.py @@ -0,0 +1,85 @@ +"""Helpers for rollout backends served by an opaque HTTP endpoint.""" + +from __future__ import annotations + +import dataclasses +import logging +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + + +def normalize_rollout_http_endpoint_url(url: str) -> str: + """Normalize an HTTP endpoint base URL used by rollout generation.""" + url = url.rstrip("/") + parsed = urlparse(url) + if parsed.scheme not in ("http", "https") or parsed.netloc == "": + raise ValueError(f"Invalid rollout HTTP endpoint URL {url!r}. Use an absolute http:// or https:// URL.") + return url + + +def uses_rollout_http_endpoint(args) -> bool: + return bool(getattr(args, "rollout_http_endpoint_url", None)) + + +def rollout_http_endpoint_url(args, endpoint: str = "/generate") -> str: + base = normalize_rollout_http_endpoint_url(args.rollout_http_endpoint_url) + if not endpoint.startswith("/"): + endpoint = f"/{endpoint}" + return f"{base}{endpoint}" + + +@dataclasses.dataclass +class HttpEndpointRolloutServer: + """Rollout server backed by an opaque HTTP endpoint. + + The endpoint is intentionally not assumed to be an SGLang router: it may not + expose worker-management APIs such as ``/workers`` and it may represent an + elastic fleet with no stable per-engine handles. + """ + + endpoint_url: str + model_name: str = "default" + update_weights: bool = True + router_ip: str | None = None + router_port: int | None = None + server_groups: list = dataclasses.field(default_factory=list) + engines: list = dataclasses.field(default_factory=list) + engine_gpu_counts: list[int] = dataclasses.field(default_factory=list) + engine_gpu_offsets: list[int] = dataclasses.field(default_factory=list) + num_new_engines: int = 0 + + @property + def all_engines(self): + return self.engines + + def recover(self): + logger.warning("Fault tolerance is not supported for opaque HTTP rollout endpoints; skip recover.") + + def offload(self): + return [] + + def onload(self, tags: list[str] | None = None): + return [] + + def onload_weights(self): + return [] + + def onload_kv(self): + return [] + + +def start_http_endpoint_rollout_servers(args) -> dict[str, HttpEndpointRolloutServer]: + endpoint_url = normalize_rollout_http_endpoint_url(args.rollout_http_endpoint_url) + args.rollout_http_endpoint_url = endpoint_url + args.sglang_model_routers = {} + if getattr(args, "rollout_num_engines", None) is None: + args.rollout_num_engines = 1 + logger.info("Using opaque HTTP rollout endpoint: %s", endpoint_url) + return { + "default": HttpEndpointRolloutServer( + endpoint_url=endpoint_url, + model_name="default", + update_weights=True, + ) + } diff --git a/slime/backends/sglang_utils/server_control.py b/slime/backends/sglang_utils/server_control.py new file mode 100644 index 0000000000..56da8e4c16 --- /dev/null +++ b/slime/backends/sglang_utils/server_control.py @@ -0,0 +1,67 @@ +import asyncio +import logging +from typing import Any + +from slime.utils.http_utils import get, post + +logger = logging.getLogger(__name__) + +ABORT_RETRY_INTERVAL_SECONDS = 3 + + +def num_requests_from_load(load: Any) -> int: + if isinstance(load, list): + return sum(num_requests_from_load(item) for item in load) + + if not isinstance(load, dict): + return 0 + + if "loads" in load: + return num_requests_from_load(load["loads"]) + + for key in ("num_reqs", "num_total_reqs", "total_reqs"): + value = load.get(key) + if isinstance(value, int): + return value + + running = load.get("num_running_reqs", load.get("total_running_reqs")) + waiting = load.get("num_waiting_reqs", load.get("total_waiting_reqs")) + return (running if isinstance(running, int) else 0) + (waiting if isinstance(waiting, int) else 0) + + +async def _abort_server_once(url: str) -> None: + try: + await post(f"{url}/abort_request", {"abort_all": True}) + except Exception as e: + logger.warning(f"Failed to abort SGLang server at {url}: {e}") + + +async def _get_server_num_requests(url: str) -> int: + return num_requests_from_load(await get(f"{url}/v1/loads?include=core")) + + +async def abort_server_until_idle(url: str, retry_interval: int = ABORT_RETRY_INTERVAL_SECONDS) -> None: + attempt = 1 + while True: + logger.info(f"Abort request for SGLang server {url}") + await _abort_server_once(url) + + try: + num_requests = await _get_server_num_requests(url) + except Exception as e: + logger.warning(f"Failed to get SGLang server load from {url}: {e}") + return + + if num_requests <= 0: + return + + logger.info( + f"SGLang server {url} still has {num_requests} requests after abort attempt {attempt}; " + f"retrying in {retry_interval} seconds." + ) + await asyncio.sleep(retry_interval) + attempt += 1 + + +async def abort_servers_until_idle(urls: list[str]) -> None: + await asyncio.gather(*(abort_server_until_idle(url) for url in urls)) diff --git a/slime/ray/actor_group.py b/slime/ray/actor_group.py index 27ad610ad9..9440404907 100644 --- a/slime/ray/actor_group.py +++ b/slime/ray/actor_group.py @@ -4,7 +4,7 @@ from ray.util.placement_group import PlacementGroup from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from slime.ray.utils import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST +from slime.ray.utils import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, add_default_ray_env_vars class RayTrainGroup: @@ -89,7 +89,13 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): actor_impl = MegatronTrainRayActor - TrainRayActor = ray.remote(num_gpus=1, runtime_env={"env_vars": env_vars})(actor_impl) + actor_options = { + "num_gpus": 1, + "runtime_env": {"env_vars": add_default_ray_env_vars(env_vars)}, + } + if getattr(self.args, "rollout_data_transport", "object-store") == "nixl": + actor_options["enable_tensor_transport"] = True + TrainRayActor = ray.remote(**actor_options)(actor_impl) # Create worker actors self._actor_handlers = [] diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index 32a13a7e71..108d723263 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -7,6 +7,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from .actor_group import RayTrainGroup +from .utils import add_default_ray_env_vars logger = logging.getLogger(__name__) @@ -73,7 +74,7 @@ def _create_placement_group(num_gpus): scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, placement_group_bundle_index=i, - ) + ), ).remote() ) gpu_ids = ray.get([actor.get_ip_and_gpu_id.remote() for actor in info_actors]) @@ -102,6 +103,11 @@ def _get_placement_group_layout(args) -> tuple[int, int]: if args.debug_train_only: return actor_num_gpus, 0 + if getattr(args, "rollout_http_endpoint_url", None): + if args.debug_rollout_only: + return 0, 0 + return actor_num_gpus, actor_num_gpus + if args.rollout_external: if args.debug_rollout_only: return 0, 0 @@ -111,7 +117,7 @@ def _get_placement_group_layout(args) -> tuple[int, int]: return args.rollout_num_gpus, 0 if args.colocate: - return actor_num_gpus, 0 + return max(actor_num_gpus, args.rollout_num_gpus), 0 return actor_num_gpus + args.rollout_num_gpus, actor_num_gpus @@ -214,10 +220,14 @@ def create_training_models(args, pgs, rollout_manager): def create_rollout_manager(args, pg): from .rollout import RolloutManager - rollout_manager = RolloutManager.options( - num_cpus=1, - num_gpus=0, - ).remote(args, pg) + rollout_manager_options = { + "num_cpus": 1, + "num_gpus": 0, + "runtime_env": {"env_vars": add_default_ray_env_vars()}, + } + if getattr(args, "rollout_data_transport", "object-store") == "nixl": + rollout_manager_options["enable_tensor_transport"] = True + rollout_manager = RolloutManager.options(**rollout_manager_options).remote(args, pg) # calculate num_rollout from num_epoch num_rollout_per_epoch = None diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index f772bf037a..b85f34a5e7 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -15,6 +15,7 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from slime.backends.sglang_utils.external import start_external_rollout_servers +from slime.backends.sglang_utils.http_endpoint import start_http_endpoint_rollout_servers, uses_rollout_http_endpoint from slime.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig from slime.backends.sglang_utils.sglang_engine import SGLangEngine from slime.rollout.base_types import call_rollout_fn @@ -29,13 +30,21 @@ from ..utils.metric_utils import has_repetition from .rollout_validation import validate_server_group_gpu_indices -from .utils import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, Lock +from .utils import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, Lock, add_default_ray_env_vars logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) logger = logging.getLogger(__name__) +_ROLLOUT_DATA_TENSOR_DTYPES = { + "tokens": torch.long, + "loss_masks": torch.int, + "rollout_log_probs": torch.float32, + "teacher_log_probs": torch.float32, + "rollout_routed_experts": None, +} + _SGLANG_REQUEST_PERF_FIELDS = ( ("request/e2e_latency", "e2e_latency"), ("request/queue_time", "queue_time"), @@ -60,6 +69,38 @@ ) +def _cpu_tensor(value, dtype: torch.dtype | None = None) -> torch.Tensor: + if isinstance(value, np.ndarray) and not value.flags.writeable: + value = value.copy() + tensor = torch.as_tensor(value, dtype=dtype) if dtype is not None else torch.as_tensor(value) + return tensor.detach().cpu().contiguous() + + +def _tensorize_rollout_data_for_training(rollout_data: dict[str, Any]) -> None: + for key, dtype in _ROLLOUT_DATA_TENSOR_DTYPES.items(): + if key in rollout_data: + rollout_data[key] = [_cpu_tensor(value, dtype=dtype) for value in rollout_data[key]] + + if "multimodal_train_inputs" in rollout_data: + rollout_data["multimodal_train_inputs"] = [ + ( + { + key: _cpu_tensor(value) if isinstance(value, (np.ndarray, torch.Tensor)) else value + for key, value in mm_dict.items() + } + if mm_dict is not None + else None + ) + for mm_dict in rollout_data["multimodal_train_inputs"] + ] + + if "rollout_mask_sums" in rollout_data: + rollout_data["rollout_mask_sums"] = _cpu_tensor( + rollout_data["rollout_mask_sums"], + dtype=torch.float32, + ) + + @dataclasses.dataclass class ServerGroup: """A group of homogeneous SGLang engines with the same configuration. @@ -163,7 +204,7 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, runtime_env={ - "env_vars": env_vars, + "env_vars": add_default_ray_env_vars(env_vars), }, ).remote( self.args, @@ -386,6 +427,13 @@ def __init__(self, args, pg): self.pg = pg self.args = args + rollout_init_handles: list[Any] = [] + if self.args.debug_train_only: + self.servers: dict[str, Any] = {} + else: + init_http_client(args) + self.servers, rollout_init_handles = start_rollout_servers(args, pg) + data_source_cls = load_function(self.args.data_source_path) self.data_source = data_source_cls(args) @@ -402,14 +450,15 @@ def __init__(self, args, pg): logger.info(f"import {self.args.rollout_function_path} as generate_rollout function.") logger.info(f"import {self.args.eval_function_path} as eval_generate_rollout function.") - if self.args.debug_train_only: - self.servers: dict[str, Any] = {} - else: - init_http_client(args) - self.servers = start_rollout_servers(args, pg) + if rollout_init_handles: + ray.get(rollout_init_handles) init_tracking(args, primary=False) - self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() + self.rollout_engine_lock = Lock.options( + num_cpus=1, + num_gpus=0, + runtime_env={"env_vars": add_default_ray_env_vars()}, + ).remote() self.rollout_id = -1 self._health_monitors = [] @@ -421,23 +470,6 @@ def __init__(self, args, pg): self._health_monitors.append(monitor) self._ci_fault_injection_pending = self.args.ci_test # Flag for CI fault injection - def _get_metrics_router_addr(self) -> str | None: - """Return the router address for scraping SGLang engine metrics. - - The sglang_router gateway exposes ``/engine_metrics`` on its main port, - which aggregates Prometheus metrics from all backend sglang servers. - Returns ``http://{ip}:{port}`` for the first server, or ``None`` when - metrics are disabled or no servers are running. - """ - srv = self.server - if srv is None or srv.router_ip is None: - return None - return f"http://{srv.router_ip}:{srv.router_port}" - - def get_metrics_router_addr(self) -> str | None: - """Public wrapper for remote calls from the driver process.""" - return self._get_metrics_router_addr() - def _try_ci_fault_injection(self): """Try to inject fault during generate (when health monitor is running).""" if not self._ci_fault_injection_pending: @@ -656,7 +688,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]): raw_rewards = [sample.get_reward_value(self.args) for sample in samples] if ( - self.args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"] + self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce_plus_plus_baseline"] and self.args.rewards_normalization ): # group norm @@ -669,7 +701,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]): mean = rewards.mean(dim=-1, keepdim=True) rewards = rewards - mean - if self.args.advantage_estimator in ["grpo", "gspo"] and self.args.grpo_std_normalization: + if self.args.advantage_estimator in ["grpo", "gspo", "cispo"] and self.args.grpo_std_normalization: std = rewards.std(dim=-1, keepdim=True) rewards = rewards / (std + 1e-6) @@ -833,7 +865,14 @@ def _split_train_data_by_dp(self, data): rollout_data["global_batch_sizes"] = global_batch_sizes rollout_data["num_microbatches"] = num_microbatches rollout_data["micro_batch_indices"] = micro_batch_indices[r] - rollout_data_refs.append(Box(ray.put(rollout_data))) + _tensorize_rollout_data_for_training(rollout_data) + transport = getattr(self.args, "rollout_data_transport", "object-store") + if transport == "nixl": + rollout_data_refs.append(Box(ray.put(rollout_data, _tensor_transport="nixl"))) + elif transport == "object-store": + rollout_data_refs.append(Box(ray.put(rollout_data))) + else: + raise ValueError(f"Unsupported rollout data transport: {transport!r}") return rollout_data_refs @@ -1028,25 +1067,34 @@ def _compute_megatron_num_gpus(args) -> int: return num -def start_rollout_servers(args, pg) -> dict[str, Any]: - """Start rollout servers: one per model, each with its own router. +def start_rollout_servers(args, pg) -> tuple[dict[str, Any], list[Any]]: + """Start rollout servers without waiting for final engine initialization. Each model defined in the sglang config gets its own router and set of server groups. Server groups within a model may have different ``num_gpus_per_engine`` (e.g. for PD disaggregation where prefill and decode use different TP sizes). - Returns a dict mapping model name → ``RolloutServer``. + Returns ``(servers, init_handles)`` where servers maps model name to + ``RolloutServer`` and init_handles contains pending ``engine.init`` refs. Note: ``init_http_client`` should be called separately before this, as the HTTP client is shared across all servers. """ + if uses_rollout_http_endpoint(args): + # HTTP endpoints have no local engines to initialize, so there are no + # pending init handles. Return the (servers, init_handles) tuple the + # caller (RolloutManager.__init__) and this function's annotation expect, + # matching the other branches below. + return start_http_endpoint_rollout_servers(args), [] + if args.rollout_external: return start_external_rollout_servers(args, start_router=_start_router) config = _resolve_sglang_config(args) servers: dict[str, RolloutServer] = {} + pending_init_handles: list[Any] = [] gpu_offset = 0 engine_offset = 0 @@ -1110,6 +1158,8 @@ def _make_group(group_cfg, router_ip, router_port, overrides_extra=None): if has_epd: # --- Phase 1: start encoder groups, wait, collect URLs --- + # Encoder URLs are injected into the non-encoder workers' server args, + # so this phase must stay synchronous even though final LLM init is deferred. encoder_urls: list[str] = [] for group_cfg in model_cfg.server_groups: if group_cfg.worker_type != "encoder": @@ -1140,8 +1190,7 @@ def _make_group(group_cfg, router_ip, router_port, overrides_extra=None): non_encoder_handles.extend(handles) server_groups.append(group) - if non_encoder_handles: - ray.get(non_encoder_handles) + pending_init_handles.extend(non_encoder_handles) else: # No EPD — start all groups in one pass (original path). all_init_handles: list = [] @@ -1151,8 +1200,7 @@ def _make_group(group_cfg, router_ip, router_port, overrides_extra=None): all_init_handles.extend(handles) server_groups.append(group) - if all_init_handles: - ray.get(all_init_handles) + pending_init_handles.extend(all_init_handles) servers[model_cfg.name] = RolloutServer( server_groups=server_groups, @@ -1165,7 +1213,7 @@ def _make_group(group_cfg, router_ip, router_port, overrides_extra=None): # Expose per-model router info for custom rollout functions. args.sglang_model_routers = {name: (srv.router_ip, srv.router_port) for name, srv in servers.items()} - return servers + return servers, pending_init_handles def _resolve_sglang_config(args) -> SglangConfig: @@ -1178,6 +1226,9 @@ def _resolve_sglang_config(args) -> SglangConfig: assert actual == expected, f"sglang_config total GPUs ({actual}) != rollout_num_gpus ({expected})" return config + if args.rollout_num_gpus == 0: + return SglangConfig(models=[ModelConfig(name="default", server_groups=[])]) + if args.prefill_num_servers is not None: return SglangConfig.from_prefill_num_servers(args) diff --git a/slime/ray/utils.py b/slime/ray/utils.py index dc22466b62..7b3a3d5262 100644 --- a/slime/ray/utils.py +++ b/slime/ray/utils.py @@ -5,7 +5,6 @@ import torch from slime.ray.ray_actor import RayActor - # Refer to # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 @@ -24,6 +23,15 @@ "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", ] +RAY_DEFAULT_ENV_VARS = { + # Ray's uvloop integration has caused intermittent async actor issues. + "RAY_USE_UVLOOP": "0", +} + + +def add_default_ray_env_vars(env_vars: dict[str, str] | None = None) -> dict[str, str]: + return RAY_DEFAULT_ENV_VARS | (env_vars or {}) + def ray_noset_visible_devices(env_vars=os.environ): return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) diff --git a/slime/rollout/rm_hub/__init__.py b/slime/rollout/rm_hub/__init__.py index 0991e559e5..a2739fbec7 100644 --- a/slime/rollout/rm_hub/__init__.py +++ b/slime/rollout/rm_hub/__init__.py @@ -53,6 +53,11 @@ async def remote_rm(args, sample: Sample, max_retries: int = 10): async def async_rm(args, sample: Sample, **kwargs): + # Per-sample custom_rm_path (from eval dataset config) takes priority + if sample.custom_rm_path: + rm_function = load_function(sample.custom_rm_path) + return await rm_function(args, sample, **kwargs) + if args.custom_rm_path is not None: rm_function = load_function(args.custom_rm_path) return await rm_function(args, sample, **kwargs) diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index eee3a13680..b1a730e760 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -14,6 +14,8 @@ from packaging.version import parse from tqdm import tqdm +from slime.backends.sglang_utils.http_endpoint import rollout_http_endpoint_url, uses_rollout_http_endpoint +from slime.backends.sglang_utils.server_control import abort_servers_until_idle from slime.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput from slime.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter from slime.utils.async_utils import run @@ -32,11 +34,12 @@ from .rm_hub import async_rm, batched_async_rm -__all__ = ["generate_rollout", "get_model_url"] +__all__ = ["generate_rollout", "get_model_url", "rollout_request_context"] logger = logging.getLogger(__name__) _PROCESSOR_PROMPT_KEYS = {"input_ids", "attention_mask"} +_MISSING = object() def _prepare_prompt_ids(sample: Sample, tokenizer, processor: Any) -> list[int]: @@ -62,7 +65,7 @@ def _prepare_prompt_ids(sample: Sample, tokenizer, processor: Any) -> list[int]: def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") -> str: - """Return the router URL for a named model. + """Return the rollout URL for a named model. Use this in custom rollout functions to route requests to a specific model when multiple models are deployed via ``--sglang-config``:: @@ -70,9 +73,14 @@ def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") url = get_model_url(args, "ref", "/generate") resp = await post(url, json=payload) - Falls back to the default router if *model_name* is not found or - ``sglang_model_routers`` is not set. + If ``--rollout-http-endpoint-url`` is set, returns that opaque endpoint + with *endpoint* appended and does not assume SGLang router APIs exist. + Otherwise, falls back to the default router if *model_name* is not found + or ``sglang_model_routers`` is not set. """ + if uses_rollout_http_endpoint(args): + return rollout_http_endpoint_url(args, endpoint) + routers = getattr(args, "sglang_model_routers", None) if routers and model_name in routers: ip, port = routers[model_name] @@ -80,6 +88,67 @@ def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") return f"http://{args.sglang_router_ip}:{args.sglang_router_port}{endpoint}" +@contextmanager +def rollout_request_context(args: Namespace, rollout_id: int, *, evaluation: bool = False): + old_rollout_id = getattr(args, "_rollout_request_rollout_id", _MISSING) + old_evaluation = getattr(args, "_rollout_request_evaluation", _MISSING) + args._rollout_request_rollout_id = int(rollout_id) + args._rollout_request_evaluation = bool(evaluation) + + try: + yield + finally: + _restore_context_attr(args, "_rollout_request_rollout_id", old_rollout_id) + _restore_context_attr(args, "_rollout_request_evaluation", old_evaluation) + + +def _restore_context_attr(args: Namespace, name: str, old_value: Any) -> None: + if old_value is _MISSING: + if hasattr(args, name): + delattr(args, name) + else: + setattr(args, name, old_value) + + +async def _post_generate( + args: Namespace, + url: str, + payload: dict[str, Any], + *, + headers: dict | None, + sample: Sample, +): + request = { + "url": url, + "payload": payload, + "headers": headers, + "max_retries": 60, + "retry_sleep": 1.0, + "rollout_id": getattr(args, "_rollout_request_rollout_id", None), + "evaluation": getattr(args, "_rollout_request_evaluation", False), + } + + if (hook_path := getattr(args, "custom_rollout_request_hook_path", None)) is not None: + hook = load_function(hook_path) + result = hook(args, sample, request) + if inspect.isawaitable(result): + result = await result + if result is not None: + if not isinstance(result, dict): + raise TypeError( + f"{hook_path} must return None or a dict of request updates, got {type(result).__name__}" + ) + request.update(result) + + return await post( + request["url"], + request["payload"], + max_retries=request["max_retries"], + headers=request["headers"], + retry_sleep=request["retry_sleep"], + ) + + class GenerateState(metaclass=SingletonMeta): """ The global state for the generation process. @@ -153,7 +222,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A assert isinstance(sample.prompt, str) state = GenerateState(args) - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + url = get_model_url(args, "default", "/generate") assert ( sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED @@ -196,7 +265,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A headers = {"X-SMG-Routing-Key": sample.session_id} with trace_span(sample, "sglang_generate", attrs={"max_new_tokens": sampling_params["max_new_tokens"]}) as span: - output = await post(url, payload, headers=headers) + output = await _post_generate(args, url, payload, headers=headers, sample=sample) span.update(build_sglang_meta_trace_attrs(output["meta_info"])) if "output_token_logprobs" in output["meta_info"]: @@ -348,12 +417,13 @@ async def generate_and_rm_group( async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: - aborted_samples = [] - state = GenerateState(args) assert not state.aborted state.aborted = True + if getattr(args, "rollout_http_endpoint_abort_strategy", None) == "cancel-only": + return await _cancel_pending_tasks(state) + if parse(sglang_router.__version__) <= parse("0.2.1"): response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") urls = response["urls"] @@ -361,12 +431,31 @@ async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") urls = [worker["url"] for worker in response["workers"]] - logger.info(f"Abort request for {urls}") - abort_tasks = [post(f"{url}/abort_request", {"abort_all": True}) for url in urls] - abort_results = await asyncio.gather(*abort_tasks, return_exceptions=True) - for url, result in zip(urls, abort_results, strict=False): + await abort_servers_until_idle(urls) + + return await _drain_aborted_pending_tasks(args, rollout_id, state) + + +async def _cancel_pending_tasks(state: GenerateState) -> list[list[Sample]]: + if not state.pendings: + return [] + pending = list(state.pendings) + for task in pending: + task.cancel() + results = await asyncio.gather(*pending, return_exceptions=True) + for result in results: if isinstance(result, Exception): - logger.warning(f"Failed to abort worker at {url}: {result}") + logger.warning("Pending rollout task ended during cancel-only abort: %s", result) + state.pendings.clear() + return [] + + +async def _drain_aborted_pending_tasks( + args: Namespace, + rollout_id: int, + state: GenerateState, +) -> list[list[Sample]]: + aborted_samples = [] # make sure all the pending tasks are finished count = 0 @@ -619,11 +708,12 @@ def generate_rollout( RolloutFnTrainOutput | RolloutFnEvalOutput: the output of the rollout """ assert args.rollout_global_dataset - if evaluation: - output, _ = run(eval_rollout(args, rollout_id)) + with rollout_request_context(args, rollout_id, evaluation=evaluation): + if evaluation: + output, _ = run(eval_rollout(args, rollout_id)) + return output + + output, aborted_samples = run(generate_rollout_async(args, rollout_id, data_source.get_samples)) + if aborted_samples: + data_source.add_samples(aborted_samples) return output - - output, aborted_samples = run(generate_rollout_async(args, rollout_id, data_source.get_samples)) - if aborted_samples: - data_source.add_samples(aborted_samples) - return output diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 302eba99af..0713b8654b 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -7,11 +7,11 @@ from typing import Any import yaml -from sglang_router.launch_router import RouterArgs from slime.backends.sglang_utils.arguments import sglang_parse_args from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args from slime.backends.sglang_utils.external import apply_external_engine_info_to_args +from slime.backends.sglang_utils.http_endpoint import normalize_rollout_http_endpoint_url from slime.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from slime.utils.logging_utils import configure_logger @@ -49,8 +49,9 @@ def add_cluster_arguments(parser): default=None, help=( "Number of GPUs for inference. Note that when using --colocate, " - "i.e. the training and the inference engines are on the same gpus, this param will be ignored and will be set as " - "actor_num_gpus_per_node * actor_num_nodes." + "i.e. the training and the inference engines are on the same gpus, this param will be set as " + "actor_num_gpus_per_node * actor_num_nodes unless it is explicitly set. " + "Set it to 0 to launch routers without local SGLang engines." ), ) parser.add_argument( @@ -202,6 +203,15 @@ def add_train_arguments(parser): "release. Prefer the transport-level directory flag for both full and delta disk sync." ), ) + parser.add_argument( + "--update-weight-delta-root", + type=str, + default=None, + help=( + "Optional root directory for publish-based disk delta metadata. " + "Defaults to the parent of --update-weight-delta-dir when omitted." + ), + ) parser.add_argument( "--update-weight-delta-keep-files", action="store_true", @@ -219,6 +229,43 @@ def add_train_arguments(parser): "Called from every trainer rank; the hook gates itself." ), ) + parser.add_argument( + "--custom-delta-publish-path", + type=str, + default=None, + help=( + "Path to a custom rank-0 function called after disk delta filenames are gathered " + "and the pre-push hook has completed. Signature: " + "``def hook(args, version_dir: str, files: list[str], weight_version: str, " + "rollout_engines) -> list | None``. Returned Ray ObjectRefs are awaited before " + "the sync completes. With --update-weight-delta-publish-only, " + "--update-weight-delta-publish-wait controls whether this happens in the same " + "sync or at the start of the next sync." + ), + ) + parser.add_argument( + "--update-weight-delta-publish-wait", + type=str, + choices=["next-sync", "sync"], + default="next-sync", + help=( + "When --update-weight-delta-publish-only is set, choose when rank 0 waits for " + "--custom-delta-publish-path to finish. 'next-sync' pipelines publish work " + "across the next training step and surfaces failures one sync late. 'sync' " + "blocks update_weights until the publish hook returns, useful when the hook " + "polls rollout-fleet readiness before allowing the next rollout dispatch." + ), + ) + parser.add_argument( + "--update-weight-delta-publish-only", + action="store_true", + default=False, + help=( + "For disk delta transport, publish gathered delta files through " + "--custom-delta-publish-path without issuing direct rollout-engine update RPCs. " + "Useful for elastic HTTP rollout endpoints that consume published versions." + ), + ) parser.add_argument( "--custom-model-provider-path", type=str, @@ -527,6 +574,17 @@ def add_rollout_arguments(parser): "It may be helpful for updating loss mask." ), ) + parser.add_argument( + "--rollout-data-transport", + type=str, + choices=["object-store", "nixl"], + default="object-store", + help=( + "Transport for rollout data refs sent from rollout manager to trainer. Large rollout " + "fields are tensorized on CPU before the refs are stored. Set to nixl to transfer " + "those torch tensors via Ray NIXL." + ), + ) parser.add_argument( "--rollout-external-engine-addrs", type=str, @@ -534,6 +592,38 @@ def add_rollout_arguments(parser): nargs="+", help="Address and ports of the external engines.", ) + parser.add_argument( + "--rollout-http-endpoint-url", + type=str, + default=None, + help=( + "Opaque HTTP endpoint base URL for rollout generation. " + "When set, slime sends /generate requests to this endpoint " + "without launching or registering SGLang workers." + ), + ) + parser.add_argument( + "--rollout-http-endpoint-abort-strategy", + type=str, + choices=["cancel-only", "router-workers"], + default=None, + help=( + "Abort behavior for the default SGLang rollout. " + "'cancel-only' cancels local pending tasks and does not call router /workers; " + "'router-workers' uses the SGLang router worker list." + ), + ) + parser.add_argument( + "--custom-rollout-request-hook-path", + type=str, + default=None, + help=( + "Path to a hook called before each default SGLang rollout /generate request. " + "Signature: ``def hook(args, sample, request) -> None | dict``. " + "The request dict contains url, payload, headers, max_retries, retry_sleep, " + "rollout_id, and evaluation. Mutate it in place or return a dict of updates." + ), + ) return parser def add_fault_tolerance_arguments(parser): @@ -906,6 +996,7 @@ def add_algo_arguments(parser): choices=[ "grpo", "gspo", + "cispo", "reinforce_plus_plus", "reinforce_plus_plus_baseline", "ppo", @@ -1110,16 +1201,6 @@ def add_on_policy_distillation_arguments(parser): ) return parser - def add_router_arguments(parser): - parser.add_argument( - "--use-slime-router", - action="store_true", - default=False, - help="Whether to use SlimeRouter for text-based routing instead of SGLang token-based routing", - ) - RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) - return parser - # wandb def add_wandb_arguments(parser): # wandb parameters @@ -1481,7 +1562,6 @@ def add_ci_arguments(parser): parser = add_on_policy_distillation_arguments(parser) parser = add_wandb_arguments(parser) parser = add_tensorboard_arguments(parser) - parser = add_router_arguments(parser) parser = add_debug_arguments(parser) parser = add_network_arguments(parser) parser = add_reward_model_arguments(parser) @@ -1728,6 +1808,9 @@ def _resolve_update_weight_disk_dir(args) -> None: def _validate_update_weight_args(args) -> None: _resolve_update_weight_disk_dir(args) + if args.update_weight_delta_publish_only and args.update_weight_mode != "delta": + raise ValueError("--update-weight-delta-publish-only requires --update-weight-mode=delta.") + if args.update_weight_mode == "delta": if args.update_weight_transport not in ("nccl", "disk"): raise ValueError( @@ -1740,18 +1823,20 @@ def _validate_update_weight_args(args) -> None: "weights via CUDA IPC (only a handle crosses processes), so the delta bookkeeping " "(snapshot + diff + sparse encode) is pure overhead." ) + if args.update_weight_transport == "disk" and args.update_weight_delta_root is None: + args.update_weight_delta_root = os.path.dirname(os.path.abspath(args.update_weight_disk_dir)) + if args.update_weight_delta_publish_only: + if args.update_weight_transport != "disk": + raise ValueError("--update-weight-delta-publish-only requires --update-weight-transport=disk.") + if not args.custom_delta_publish_path: + raise ValueError("--update-weight-delta-publish-only requires --custom-delta-publish-path.") + if not args.update_weight_delta_keep_files: + raise ValueError("--update-weight-delta-publish-only requires --update-weight-delta-keep-files.") def slime_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) - if args.use_slime_router: - logger.warning( - "--use-slime-router is deprecated and ignored. slime now always uses sglang_router " - "built from https://github.com/zhuzilin/sgl-router." - ) - args.use_slime_router = False - if args.kl_coef != 0 or args.use_kl_loss: if not os.path.exists(args.ref_load): raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.") @@ -1860,6 +1945,14 @@ def slime_validate_args(args): if args.eps_clip_high is None: args.eps_clip_high = args.eps_clip + if args.advantage_estimator == "cispo" and args.eps_clip < 1.0: + logger.warning( + "CISPO is canonically single-sided, but --eps-clip=%s keeps the lower clip bound %s active. " + "Set --eps-clip 1.0 (and tune --eps-clip-high, e.g. 4.0) for the canonical wide setting.", + args.eps_clip, + 1.0 - args.eps_clip, + ) + if args.eval_reward_key is None: args.eval_reward_key = args.reward_key @@ -1874,6 +1967,18 @@ def slime_validate_args(args): ) args.debug_train_only = True + if args.rollout_http_endpoint_url is not None: + args.rollout_http_endpoint_url = normalize_rollout_http_endpoint_url(args.rollout_http_endpoint_url) + if args.rollout_http_endpoint_abort_strategy is None: + args.rollout_http_endpoint_abort_strategy = "cancel-only" + if getattr(args, "rollout_num_engines", None) is None: + args.rollout_num_engines = 1 + elif args.rollout_http_endpoint_abort_strategy is None: + args.rollout_http_endpoint_abort_strategy = "router-workers" + + if args.rollout_http_endpoint_url is not None and args.rollout_external_engine_addrs is not None: + raise ValueError("--rollout-http-endpoint-url and --rollout-external-engine-addrs are mutually exclusive.") + args.rollout_external = args.rollout_external_engine_addrs is not None if args.rollout_external and not args.debug_train_only: @@ -1890,8 +1995,11 @@ def slime_validate_args(args): del args.offload if args.debug_rollout_only: - if args.colocate and (not args.rollout_num_gpus): + if args.colocate and args.rollout_num_gpus is None: args.rollout_num_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes + elif args.rollout_num_gpus == 0: + args.actor_num_gpus_per_node = 0 + args.actor_num_nodes = 0 else: args.actor_num_gpus_per_node = min(8, args.rollout_num_gpus) args.actor_num_nodes = args.rollout_num_gpus // args.actor_num_gpus_per_node @@ -1911,12 +2019,10 @@ def slime_validate_args(args): args.offload_train = True if args.offload_rollout is None: args.offload_rollout = True - if args.rollout_num_gpus != args.actor_num_gpus_per_node * args.actor_num_nodes: - logger.info( - f"rollout_num_gpus {args.rollout_num_gpus} != actor_num_gpus_per_node {args.actor_num_gpus_per_node} " - f"* actor_num_nodes {args.actor_num_nodes}, overriding rollout_num_gpus to match actor_num_gpus_per_node * actor_num_nodes." - ) + if args.rollout_num_gpus is None: args.rollout_num_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes + elif args.rollout_num_gpus == 0: + logger.info("rollout_num_gpus is 0 under colocate; no local SGLang engines will be launched.") if args.offload_train is None: args.offload_train = False diff --git a/slime/utils/external_utils/command_utils.py b/slime/utils/external_utils/command_utils.py index af32de59fe..2f003e1f11 100644 --- a/slime/utils/external_utils/command_utils.py +++ b/slime/utils/external_utils/command_utils.py @@ -138,6 +138,7 @@ def execute_train( { "env_vars": { "PYTHONPATH": "/root/Megatron-LM/", + "RAY_USE_UVLOOP": "0", "CUDA_DEVICE_MAX_CONNECTIONS": "1", "NCCL_NVLS_ENABLE": str(int(check_has_nvlink())), "no_proxy": f"127.0.0.1,{master_addr}", diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index ced387e573..ae3e57565c 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -162,7 +162,7 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60, headers=None): +async def _post(client, url, payload, max_retries=60, headers=None, retry_sleep: float = 1.0): retry_count = 0 while retry_count < max_retries: response = None @@ -188,7 +188,7 @@ async def _post(client, url, payload, max_retries=60, headers=None): if retry_count >= max_retries: logger.info(f"Max retries ({max_retries}) reached, failing... (url={url})") raise e - await asyncio.sleep(1) + await asyncio.sleep(retry_sleep) continue finally: if response is not None: @@ -244,6 +244,8 @@ def _init_ray_distributed_post(args): import ray from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + from slime.ray.utils import add_default_ray_env_vars + # Discover alive nodes nodes = [n for n in ray.nodes() if n.get("Alive")] if not nodes: @@ -260,8 +262,8 @@ def __init__(self, concurrency: int): trust_env=False, # internal SGLang comm only — never route through system proxy ) - async def do_post(self, url, payload, max_retries=60, headers=None): - return await _post(self._client, url, payload, max_retries, headers=headers) + async def do_post(self, url, payload, max_retries=60, headers=None, retry_sleep: float = 1.0): + return await _post(self._client, url, payload, max_retries, headers=headers, retry_sleep=retry_sleep) # Create actors per node created = [] @@ -275,6 +277,7 @@ async def do_post(self, url, payload, max_retries=60, headers=None): actor = _HttpPosterActor.options( name=None, lifetime="detached", + runtime_env={"env_vars": add_default_ray_env_vars()}, scheduling_strategy=scheduling, max_concurrency=per_actor_conc, # Use tiny CPU to schedule @@ -285,7 +288,7 @@ async def do_post(self, url, payload, max_retries=60, headers=None): _post_actors = created -async def post(url, payload, max_retries=60, headers=None): +async def post(url, payload, max_retries=60, headers=None, retry_sleep: float = 1.0): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: @@ -297,13 +300,13 @@ async def post(url, payload, max_retries=60, headers=None): # `min(32, cpu+4)`), which becomes a hard upper bound on the # number of in-flight POSTs that can be waited on in parallel # and produces large tail latencies under high concurrency. - obj_ref = actor.do_post.remote(url, payload, max_retries, headers=headers) + obj_ref = actor.do_post.remote(url, payload, max_retries, headers=headers, retry_sleep=retry_sleep) return await obj_ref except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries, headers=headers) + return await _post(_http_client, url, payload, max_retries, headers=headers, retry_sleep=retry_sleep) async def get(url): diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 2a858e7a3f..327dec2de6 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -148,6 +148,29 @@ def compute_policy_loss( return pg_losses, clipfrac +@torch.compile(dynamic=True) +def compute_cispo_loss( + ppo_kl: torch.Tensor, + log_probs: torch.Tensor, + advantages: torch.Tensor, + eps_clip: float, + eps_clip_high: float, +): + """CISPO loss from MiniMax-M1 (https://arxiv.org/abs/2506.13585, Eq. 4-5): + ``-sg(clip(ratio, 1 - eps_clip, 1 + eps_clip_high)) * advantages * log_probs``. + + Unlike PPO, the IS ratio is clipped under stop-gradient and the gradient flows + through ``log_probs``, so clipped tokens still contribute gradient. The bounds + reuse the delta-from-1 convention of ``compute_policy_loss``; canonical CISPO + disables the lower bound (``eps_clip >= 1.0``). + """ + ratio = (-ppo_kl).exp() + ratio_truncated = torch.clamp(ratio, min=1.0 - eps_clip, max=1.0 + eps_clip_high) + pg_losses = -ratio_truncated.detach() * advantages * log_probs + clipfrac = (ratio_truncated != ratio).float() + return pg_losses, clipfrac + + def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None): # TODO: when megatron is not installed, fall back to naive implementation from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy diff --git a/slime/utils/types.py b/slime/utils/types.py index 54092680e2..2052ad5269 100644 --- a/slime/utils/types.py +++ b/slime/utils/types.py @@ -24,6 +24,7 @@ class Sample: tokens: list[int] = field(default_factory=list) multimodal_inputs: dict[str, Any] | None = None # raw multimodal data, e.g. images, videos, etc. multimodal_train_inputs: dict[str, Any] | None = None # processed multimodal data, e.g. pixel_values, etc. + apply_chat_template_kwargs: dict = field(default_factory=dict) # response response: str = "" response_length: int = 0 @@ -50,6 +51,7 @@ class Status(Enum): metadata: dict = field(default_factory=dict) generate_function_path: str | None = None + custom_rm_path: str | None = None # metadata used during training, e.g., what loss to use for this sample. train_metadata: dict | None = None diff --git a/slime_plugins/models/glm5/glm5.py b/slime_plugins/models/glm5/glm5.py index 273b5a7452..c45ae4f443 100644 --- a/slime_plugins/models/glm5/glm5.py +++ b/slime_plugins/models/glm5/glm5.py @@ -29,6 +29,28 @@ from .ops.indexer import generate_varlen_mask_params, lighting_indexer from .ops.sparse_mla import SparseMLA +# Names of the indexer submodules. On a DSA model with *cross-layer index +# sharing* these only exist on "computing" layers; "skip" layers drop them. +_INDEXER_SUBMODULE_NAMES = ("wq_b", "wk", "k_norm", "weights_proj") + + +def is_skip_topk_layer(layer_number: int, skip_topk_offset: int, topk_freq: int) -> bool: + """Whether the (1-indexed) Megatron ``layer_number`` reuses a previous layer's top-k. + + Mirrors ``glm-train-prod``'s ``_get_skip_topk_flags``: a layer *computes* its + own top-k when ``max(layer_number - offset, 0) % freq == 0``; otherwise it is a + skip layer that reuses the most recent computing layer's indices. + """ + return (max(layer_number - skip_topk_offset, 0) % topk_freq) != 0 + + +def source_compute_layer(layer_number: int, skip_topk_offset: int, topk_freq: int) -> int: + """The computing layer whose ``topk_indices`` a skip layer reuses.""" + layer = layer_number + while is_skip_topk_layer(layer, skip_topk_offset, topk_freq): + layer -= 1 + return layer + @dataclass class DSASelfAttentionSubmodules: @@ -137,6 +159,30 @@ def __init__( self.index_topk = 2048 + # Cross-layer index sharing (optional). When the HF config provides + # ``index_topk_freq`` / ``index_skip_topk_offset`` (see ``get_glm5_spec``), + # only a subset of "computing" layers run the indexer top-k; the remaining + # "skip" layers reuse the most recent computing layer's ``topk_indices``. + # When those attrs are absent (``freq`` defaults to 1) every layer computes + # its own top-k and ``skip_topk`` is always False -- i.e. the plain DSA path. + self.index_topk_freq = getattr(config, "index_topk_freq", 1) or 1 + self.skip_topk_offset = getattr(config, "index_skip_topk_offset", 0) or 0 + self.index_share = self.index_topk_freq > 1 + self.skip_topk = self.index_share and is_skip_topk_layer( + layer_number, self.skip_topk_offset, self.index_topk_freq + ) + self._source_layer = ( + source_compute_layer(layer_number, self.skip_topk_offset, self.index_topk_freq) + if self.index_share + else layer_number + ) + + # Attribute name of the per-microbatch top-k holder we attach to the + # ``packed_seq_params`` object (a plain dict: source layer_number -> topk_indices). + # Used only on index-share models; see ``forward`` for why it lives on + # ``packed_seq_params`` (per-microbatch isolation + recompute safety under PP). + _HOLDER_ATTR = "_dsa_index_share_topk_holder" + def forward( self, hidden_states, @@ -204,13 +250,51 @@ def fused_select_topk(index_q, index_k, w, starts, ends, block_size=8192): topk_indices.append(topk_indices_block) return torch.cat(indexer_topk_scores, dim=0), torch.cat(topk_indices, dim=0).unsqueeze(1) - starts, ends = generate_varlen_mask_params(packed_seq_params.cu_seqlens_q) - index_key = index_key.squeeze(1) - head_weights = head_weights.unsqueeze(-1) + if self.index_share: + # Cross-layer index sharing. The top-k holder lives on the per-microbatch + # ``packed_seq_params`` object: it is constructed fresh per microbatch in + # ``get_batch`` and is closure-captured by Megatron's activation-checkpoint + # ``custom_forward``, so the same instance is reused at recompute time. + # That gives per-microbatch isolation (no cross-microbatch clobber under + # PP 1F1B) AND recompute safety (the computing layer's entry written in the + # original forward is still present when a skip layer's chunk recomputes). + # Note: this never crosses a PP boundary -- a stage always starts on a + # computing layer (asserted in ``get_glm5_spec``), so a skip layer's source + # is always in-stage. + holder = getattr(packed_seq_params, self._HOLDER_ATTR, None) + if holder is None: + holder = {} + setattr(packed_seq_params, self._HOLDER_ATTR, holder) + + if self.skip_topk: + if self._source_layer not in holder: + raise AssertionError( + "DSA index-share: skip layer " + f"(layer_number={self.layer_number}) needs the top-k of its source " + f"computing layer (layer_number={self._source_layer}), but that layer " + "did not run in this pipeline stage's forward. Cross-PP top-k sharing " + "is not supported; ensure every pipeline stage starts on a computing " + f"layer (index_topk_freq={self.index_topk_freq}, " + f"index_skip_topk_offset={self.skip_topk_offset}). " + f"Holder has layers {sorted(holder)}." + ) + topk_indices = holder[self._source_layer] + else: + starts, ends = generate_varlen_mask_params(packed_seq_params.cu_seqlens_q) + index_key = index_key.squeeze(1) + head_weights = head_weights.unsqueeze(-1) + starts = scatter_to_sequence_parallel_region(starts, group=parallel_state.get_context_parallel_group()) + ends = scatter_to_sequence_parallel_region(ends, group=parallel_state.get_context_parallel_group()) + _, topk_indices = fused_select_topk(index_query, index_key, head_weights, starts, ends) + holder[self.layer_number] = topk_indices + else: + starts, ends = generate_varlen_mask_params(packed_seq_params.cu_seqlens_q) + index_key = index_key.squeeze(1) + head_weights = head_weights.unsqueeze(-1) - starts = scatter_to_sequence_parallel_region(starts, group=parallel_state.get_context_parallel_group()) - ends = scatter_to_sequence_parallel_region(ends, group=parallel_state.get_context_parallel_group()) - _, topk_indices = fused_select_topk(index_query, index_key, head_weights, starts, ends) + starts = scatter_to_sequence_parallel_region(starts, group=parallel_state.get_context_parallel_group()) + ends = scatter_to_sequence_parallel_region(ends, group=parallel_state.get_context_parallel_group()) + _, topk_indices = fused_select_topk(index_query, index_key, head_weights, starts, ends) core_attn_out, _ = SparseMLA.apply(q, kv, topk_indices, self.softmax_scale) core_attn_out = torch.einsum("thm,hdm->thd", core_attn_out, wv) @@ -403,6 +487,15 @@ def __init__( ) self.weights_proj.weight._skip_gather = True + # Index-share skip layers carry no indexer weights -- drop the modules the + # base path built unconditionally so the parameter set matches the + # checkpoint (which only stores indexer weights on computing layers) and so + # weight export to HF naturally omits them on skip layers. + if self.skip_topk: + for name in _INDEXER_SUBMODULE_NAMES: + if hasattr(self, name): + delattr(self, name) + def get_absorb_query_key_value_tensors( self, hidden_states, @@ -427,8 +520,11 @@ def get_absorb_query_key_value_tensors( rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( inference_context, None, hidden_states, self.config, packed_seq_params ) - # TODO: support apply_rope_fusion - rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq_params=packed_seq_params) + # YarnRotaryEmbedding/RotaryEmbedding.forward is wrapped in lru_cache, so it + # only accepts hashable args: pass the packed-sequence flag, not the + # (unhashable) PackedSeqParams object. + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd" + rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) cu_seqlens_q = packed_seq_params.cu_seqlens_q cu_seqlens_kv = packed_seq_params.cu_seqlens_kv @@ -508,6 +604,11 @@ def fuse_rope(q, cu_seqlens, gathered=False): query = query.contiguous() key = key.contiguous() + if self.skip_topk: + # Index-share skip layer: reuse a previous layer's top-k, so the indexer + # projections are not run here. Return None for the index tensors. + return query, key, w_vc, None, None, None + # ========================================= # Indexer # ========================================= @@ -607,6 +708,13 @@ def get_glm5_spec(args, config, vp_stage): hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True) config.index_num_attention_heads = hf_config.index_n_heads config.index_head_dim = hf_config.index_head_dim + # Optional cross-layer index-sharing schedule. Present on DSA checkpoints that + # only store indexer weights on a subset of "computing" layers (e.g. GLM5.2 / + # DeepSeek-V3.2 750B). When absent, every layer computes its own top-k (plain + # DSA) and ``DSAMLASelfAttention`` runs the non-shared path. + config.index_topk_freq = getattr(hf_config, "index_topk_freq", 1) or 1 + config.index_skip_topk_offset = getattr(hf_config, "index_skip_topk_offset", 0) or 0 + # Define the decoder block spec kwargs = { "use_transformer_engine": True, @@ -615,6 +723,32 @@ def get_glm5_spec(args, config, vp_stage): kwargs["vp_stage"] = vp_stage transformer_layer_spec = get_gpt_decoder_block_spec(config, **kwargs) num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage) + + # Cross-layer index sharing keeps the shared top-k in a per-microbatch holder + # on ``packed_seq_params``, which does not cross PP boundaries. A skip layer + # therefore must run in the same pipeline stage as the computing layer it + # reuses. If a (virtual) pipeline stage *starts* with a skip layer, its source + # computing layer lives on a previous stage and the lookup would miss. Forbid + # that split here (supporting it would need PP send/recv of the top-k). + if config.index_topk_freq > 1: + from megatron.core.transformer.transformer_block import get_transformer_layer_offset + + layer_offset = get_transformer_layer_offset(config, vp_stage=vp_stage) + for local_id in range(num_layers_to_build): + layer_number = local_id + layer_offset + 1 # Megatron layer_number is 1-indexed + if local_id == 0 and is_skip_topk_layer( + layer_number, config.index_skip_topk_offset, config.index_topk_freq + ): + src = source_compute_layer(layer_number, config.index_skip_topk_offset, config.index_topk_freq) + raise AssertionError( + "DSA index-share pipeline split is invalid: this stage starts at global " + f"layer_number={layer_number} which is a skip layer whose source computing " + f"layer={src} is on a previous pipeline stage. Cross-layer top-k sharing does " + "not cross PP boundaries. Choose a pipeline layout where every stage begins on " + "a computing layer (index_topk_freq=" + f"{config.index_topk_freq}, index_skip_topk_offset={config.index_skip_topk_offset})." + ) + backend = TESpecProvider() self_attn_module_spec = ModuleSpec( diff --git a/tests/plugin_contracts/test_plugin_runtime_hook_contracts.py b/tests/plugin_contracts/test_plugin_runtime_hook_contracts.py index a8380feecc..3fb6d365fc 100644 --- a/tests/plugin_contracts/test_plugin_runtime_hook_contracts.py +++ b/tests/plugin_contracts/test_plugin_runtime_hook_contracts.py @@ -37,6 +37,7 @@ def run_contract_test_file() -> None: "custom-reward-post-process-path", "custom-convert-samples-to-train-data-path", "rollout-data-postprocess-path", + "custom-rollout-request-hook-path", ], ) @@ -73,6 +74,11 @@ def reference_rollout_data_postprocess(args, rollout_id, rollout_data) -> None: args.rollout_data_postprocess_called = True +def reference_rollout_request_hook(args, sample, request) -> None: + args.rollout_request_hook_called = True + request["payload"]["hooked"] = sample.index + + def make_sample(index: int, reward: float = 1.0) -> Sample: return Sample( index=index, @@ -128,6 +134,14 @@ def invoke_rollout_data_postprocess(fn): assert args.rollout_data_postprocess_called is True +def invoke_rollout_request_hook(fn): + args = type("Args", (), {})() + request = {"payload": {}} + assert fn(args, Sample(index=7), request) is None + assert args.rollout_request_hook_called is True + assert request["payload"]["hooked"] == 7 + + HOOK_CASES = [ HookCase( "custom_rollout_log", @@ -174,6 +188,15 @@ def invoke_rollout_data_postprocess(fn): ("args", "rollout_id", "rollout_data"), invoke_rollout_data_postprocess, ), + HookCase( + "rollout_request_hook", + "CUSTOM_ROLLOUT_REQUEST_HOOK_PATH", + "plugin_contracts.test_plugin_runtime_hook_contracts.reference_rollout_request_hook", + "slime/rollout/sglang_rollout.py", + "hook(args, sample, request)", + ("args", "sample", "request"), + invoke_rollout_request_hook, + ), ] diff --git a/tests/test_cispo_loss.py b/tests/test_cispo_loss.py new file mode 100644 index 0000000000..9f2e86c89c --- /dev/null +++ b/tests/test_cispo_loss.py @@ -0,0 +1,52 @@ +"""CPU tests for compute_cispo_loss (MiniMax-M1, https://arxiv.org/abs/2506.13585).""" + +import math + +import pytest +import torch + +from slime.utils.ppo_utils import compute_cispo_loss + +NUM_GPUS = 0 + +ADVANTAGES = torch.tensor([1.0, -0.5, 2.0, -1.0]) +LOG_PROBS = torch.tensor([-0.7, -1.2, -0.4, -2.1]) + +# (eps_clip, eps_clip_high, raw IS ratios, ratios after clamp to [1 - eps_clip, 1 + eps_clip_high]) +CLIP_CASES = [ + pytest.param(0.2, 0.28, [1.0, 1.14, 1.56, 0.4], [1.0, 1.14, 1.28, 0.8], id="ppo_band"), + pytest.param(1.0, 4.0, [1.0, 3.0, 9.0, 0.4], [1.0, 3.0, 5.0, 0.4], id="wide_minimax_band"), +] + + +@pytest.mark.parametrize("eps_clip, eps_clip_high, ratios, clamped", CLIP_CASES) +def test_compute_cispo_loss_matches_closed_form_surrogate(eps_clip, eps_clip_high, ratios, clamped): + ppo_kl = -torch.tensor([math.log(r) for r in ratios]) + + pg_losses, clipfrac = compute_cispo_loss(ppo_kl, LOG_PROBS, ADVANTAGES, eps_clip, eps_clip_high) + + expected_losses = -torch.tensor(clamped) * ADVANTAGES * LOG_PROBS + torch.testing.assert_close(pg_losses, expected_losses, rtol=1e-6, atol=1e-6) + expected_clipfrac = torch.tensor([float(c != r) for c, r in zip(clamped, ratios, strict=True)]) + torch.testing.assert_close(clipfrac, expected_clipfrac) + + +@pytest.mark.parametrize("eps_clip, eps_clip_high, ratios, clamped", CLIP_CASES) +def test_compute_cispo_loss_gradient_flows_only_through_log_probs(eps_clip, eps_clip_high, ratios, clamped): + # ratio = exp(-ppo_kl) = exp(log_ratios): if CISPO failed to stop-gradient the + # clipped IS ratio, backward would populate log_ratios.grad. + log_ratios = torch.tensor([math.log(r) for r in ratios], requires_grad=True) + ppo_kl = -log_ratios + log_probs = LOG_PROBS.clone().requires_grad_() + + pg_losses, _ = compute_cispo_loss(ppo_kl, log_probs, ADVANTAGES, eps_clip, eps_clip_high) + pg_losses.sum().backward() + + torch.testing.assert_close(log_probs.grad, -torch.tensor(clamped) * ADVANTAGES, rtol=1e-6, atol=1e-6) + assert log_ratios.grad is None or torch.all( + log_ratios.grad == 0 + ), f"CISPO must stop-gradient on the IS ratio; log_ratios.grad={log_ratios.grad}" + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_delta_publish_only.py b/tests/test_delta_publish_only.py new file mode 100644 index 0000000000..62fbde37db --- /dev/null +++ b/tests/test_delta_publish_only.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +import os +import sys +import types +from argparse import Namespace +from dataclasses import dataclass +from enum import Enum +from pathlib import Path + +import pytest + +torch = pytest.importorskip("torch") + + +def _install_import_stubs() -> None: + if "safetensors.torch" not in sys.modules: + try: + import safetensors.torch # noqa: F401 + except ImportError: + safetensors = types.ModuleType("safetensors") + safetensors_torch = types.ModuleType("safetensors.torch") + safetensors_torch.save = lambda tensors, metadata=None: b"" + safetensors.torch = safetensors_torch + sys.modules["safetensors"] = safetensors + sys.modules["safetensors.torch"] = safetensors_torch + + if "ray" not in sys.modules: + ray = types.ModuleType("ray") + actor = types.ModuleType("ray.actor") + + class ActorHandle: + pass + + class ObjectRef: + pass + + actor.ActorHandle = ActorHandle + ray.actor = actor + ray.ObjectRef = ObjectRef + ray.get = lambda refs: refs + sys.modules["ray"] = ray + sys.modules["ray.actor"] = actor + + if "megatron" not in sys.modules: + megatron = types.ModuleType("megatron") + core = types.ModuleType("megatron.core") + mpu = types.ModuleType("megatron.core.mpu") + mpu.get_data_parallel_rank = lambda with_context_parallel=False: 0 + mpu.get_tensor_model_parallel_rank = lambda: 0 + mpu.get_pipeline_model_parallel_rank = lambda: 0 + mpu.get_expert_model_parallel_world_size = lambda: 1 + mpu.get_expert_model_parallel_group = lambda: None + mpu.get_expert_tensor_parallel_world_size = lambda: 1 + mpu.get_expert_tensor_parallel_group = lambda: None + mpu.get_tensor_model_parallel_world_size = lambda: 1 + mpu.get_tensor_model_parallel_group = lambda: None + mpu.get_expert_model_parallel_rank = lambda: 0 + transformer = types.ModuleType("megatron.core.transformer") + transformer_layer = types.ModuleType("megatron.core.transformer.transformer_layer") + transformer_layer.get_transformer_layer_offset = lambda config, *args, **kwargs: 0 + core.mpu = mpu + core.transformer = transformer + megatron.core = core + sys.modules["megatron"] = megatron + sys.modules["megatron.core"] = core + sys.modules["megatron.core.mpu"] = mpu + sys.modules["megatron.core.transformer"] = transformer + sys.modules["megatron.core.transformer.transformer_layer"] = transformer_layer + + megatron_to_hf = types.ModuleType("slime.backends.megatron_utils.megatron_to_hf") + megatron_to_hf.convert_to_hf = lambda args, model_name, name, param, quantization_config: [(name, param)] + sys.modules.setdefault("slime.backends.megatron_utils.megatron_to_hf", megatron_to_hf) + + if "sglang" not in sys.modules: + sglang = types.ModuleType("sglang") + srt = types.ModuleType("sglang.srt") + sys.modules["sglang"] = sglang + sys.modules["sglang.srt"] = srt + + if "sglang.srt.layers.quantization.fp8_utils" not in sys.modules: + fp8_utils = types.ModuleType("sglang.srt.layers.quantization.fp8_utils") + fp8_utils.quant_weight_ue8m0 = None + fp8_utils.transform_scale_ue8m0 = None + sys.modules["sglang.srt.layers"] = types.ModuleType("sglang.srt.layers") + sys.modules["sglang.srt.layers.quantization"] = types.ModuleType("sglang.srt.layers.quantization") + sys.modules["sglang.srt.layers.quantization.fp8_utils"] = fp8_utils + + if "sglang.srt.model_loader.utils" not in sys.modules: + model_loader_utils = types.ModuleType("sglang.srt.model_loader.utils") + model_loader_utils.should_deepgemm_weight_requant_ue8m0 = None + sys.modules["sglang.srt.model_loader"] = types.ModuleType("sglang.srt.model_loader") + sys.modules["sglang.srt.model_loader.utils"] = model_loader_utils + + utils = sys.modules.get("sglang.srt.utils") + if utils is None: + utils = types.ModuleType("sglang.srt.utils") + utils.__path__ = [] + sys.modules["sglang.srt.utils"] = utils + utils.MultiprocessingSerializer = object + + patch_torch = types.ModuleType("sglang.srt.utils.patch_torch") + patch_torch.monkey_patch_torch_reductions = lambda: None + sys.modules.setdefault("sglang.srt.utils.patch_torch", patch_torch) + sys.modules.setdefault("sglang.srt.patch_torch", patch_torch) + + if "sglang.srt.managers.io_struct" not in sys.modules: + io_struct = types.ModuleType("sglang.srt.managers.io_struct") + + class DeltaEncoding(Enum): + INDICES = "indices" + DELTAS = "deltas" + DELTAS_ZSTD = "deltas_zstd" + + @dataclass + class DeltaParam: + name: str + dtype: str + shape: list[int] + pos_start: int + pos_end: int + pos_width: int + val_start: int + val_end: int + + @dataclass + class DeltaSpec: + encoding: DeltaEncoding + params: list[DeltaParam] + checksum: int + + io_struct.DeltaEncoding = DeltaEncoding + io_struct.DeltaParam = DeltaParam + io_struct.DeltaSpec = DeltaSpec + sys.modules["sglang.srt.managers"] = types.ModuleType("sglang.srt.managers") + sys.modules["sglang.srt.managers.io_struct"] = io_struct + + tensor_bucket = types.ModuleType("sglang.srt.weight_sync.tensor_bucket") + tensor_bucket.FlattenedTensorBucket = object + sys.modules.setdefault("sglang.srt.weight_sync", types.ModuleType("sglang.srt.weight_sync")) + sys.modules.setdefault("sglang.srt.weight_sync.tensor_bucket", tensor_bucket) + + +_install_import_stubs() + +from slime.backends.megatron_utils.update_weight import update_weight_from_distributed_delta as delta_mod # noqa: E402 + + +class _InlineFuture: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + +class _InlineExecutor: + def submit(self, fn): + return _InlineFuture(fn()) + + +class _FakeWriter: + def __init__(self): + self.drain_calls = 0 + + def drain(self): + self.drain_calls += 1 + + +class _RemoteMethod: + def __init__(self, owner, name): + self._owner = owner + self._name = name + + def remote(self, **kwargs): + self._owner.calls.append((self._name, kwargs)) + return f"{self._name}-ref" + + +class _FakeEngine: + def __init__(self): + self.calls = [] + self.update_weights_from_disk = _RemoteMethod(self, "update_weights_from_disk") + self.set_weight_version = _RemoteMethod(self, "set_weight_version") + self.continue_generation = _RemoteMethod(self, "continue_generation") + + +def _patch_single_rank_dist(monkeypatch): + barrier_calls = [] + gathered = [] + + monkeypatch.setattr(delta_mod, "get_gloo_group", lambda: None) + monkeypatch.setattr(delta_mod.dist, "get_rank", lambda: 0) + monkeypatch.setattr(delta_mod.dist, "get_world_size", lambda: 1) + monkeypatch.setattr(delta_mod.dist, "barrier", lambda group=None: barrier_calls.append(group)) + + def all_gather_object(outputs, value, group=None): + gathered.append((list(value), group)) + outputs[0] = list(value) + + monkeypatch.setattr(delta_mod.dist, "all_gather_object", all_gather_object) + return barrier_calls, gathered + + +def _make_publish_only_updater(tmp_path: Path, publish_hook, *, publish_wait: str = "next-sync"): + updater = delta_mod.UpdateWeightFromDistributedDelta.__new__(delta_mod.UpdateWeightFromDistributedDelta) + updater.args = Namespace(update_weight_delta_keep_files=True, update_weight_delta_publish_wait=publish_wait) + updater.transport = "disk" + updater._publish_only = True + updater._publish_wait = publish_wait + updater._pending_files = [] + updater._pending_publishes = [] + updater._published_any = False + updater._pre_push_hook = None + updater._publish_hook = publish_hook + updater._rpc_executor = _InlineExecutor() + updater.writer = _FakeWriter() + updater.weight_version = 7 + updater._version_dir = os.path.join(tmp_path, "weight_v000007") + os.makedirs(updater._version_dir, exist_ok=True) + updater.rollout_engines = [_FakeEngine()] + return updater + + +def test_publish_only_finalize_calls_publish_hook_without_engine_rpcs_or_cleanup(monkeypatch, tmp_path): + _patch_single_rank_dist(monkeypatch) + ray_get_calls = [] + monkeypatch.setattr(delta_mod.ray, "get", lambda refs: ray_get_calls.append(refs)) + monkeypatch.setattr( + delta_mod.shutil, "rmtree", lambda *_args, **_kwargs: pytest.fail("publish-only must keep files") + ) + + hook_calls = [] + + def publish_hook(args, version_dir, files, weight_version, engines): + hook_calls.append((args, version_dir, files, weight_version, engines)) + return ["publish-ref"] + + updater = _make_publish_only_updater(tmp_path, publish_hook) + updater._pending_files = ["rank0000_flush000000.safetensors"] + + updater._finalize_sync() + + assert updater.writer.drain_calls == 1 + assert updater._pending_files == [] + assert updater._published_any is True + assert hook_calls == [ + ( + updater.args, + updater._version_dir, + ["rank0000_flush000000.safetensors"], + "7", + updater.rollout_engines, + ) + ] + assert updater.rollout_engines[0].calls == [] + # The publish stays in flight across the training step: finalize must not + # await its refs; the next sync (or disconnect) drains it. + assert len(updater._pending_publishes) == 1 + assert ray_get_calls == [] + assert os.path.isdir(updater._version_dir) + + updater._drain_pending_publishes() + + assert updater._pending_publishes == [] + assert ray_get_calls == [["publish-ref"]] + + +def test_publish_only_finalize_publishes_noop_version(monkeypatch, tmp_path): + _patch_single_rank_dist(monkeypatch) + ray_get_calls = [] + monkeypatch.setattr(delta_mod.ray, "get", lambda refs: ray_get_calls.append(refs)) + + hook_calls = [] + + def publish_hook(args, version_dir, files, weight_version, engines): + hook_calls.append((version_dir, files, weight_version, engines)) + return None + + updater = _make_publish_only_updater(tmp_path, publish_hook) + + updater._finalize_sync() + + assert updater.writer.drain_calls == 1 + assert updater._published_any is True + assert hook_calls == [(updater._version_dir, [], "7", updater.rollout_engines)] + assert updater.rollout_engines[0].calls == [] + assert len(updater._pending_publishes) == 1 + + updater._drain_pending_publishes() + + assert updater._pending_publishes == [] + assert ray_get_calls == [] + + +def test_publish_only_sync_wait_drains_publish_before_return(monkeypatch, tmp_path): + _patch_single_rank_dist(monkeypatch) + ray_get_calls = [] + monkeypatch.setattr(delta_mod.ray, "get", lambda refs: ray_get_calls.append(refs)) + + updater = _make_publish_only_updater(tmp_path, lambda *a: ["publish-ref"], publish_wait="sync") + updater._pending_files = ["rank0000_flush000000.safetensors"] + + updater._finalize_sync() + + assert updater._pending_publishes == [] + assert ray_get_calls == [["publish-ref"]] + + +def test_disconnect_drains_pending_publish(monkeypatch, tmp_path): + _patch_single_rank_dist(monkeypatch) + ray_get_calls = [] + monkeypatch.setattr(delta_mod.ray, "get", lambda refs: ray_get_calls.append(refs)) + + updater = _make_publish_only_updater(tmp_path, lambda *a: ["publish-ref"]) + updater._finalize_sync() + assert len(updater._pending_publishes) == 1 + + updater.disconnect_rollout_engines() + + assert updater._pending_publishes == [] + assert ray_get_calls == [["publish-ref"]] + + +def test_publish_only_flush_defers_publish_until_finalize(monkeypatch, tmp_path): + barrier_calls, _gathered = _patch_single_rank_dist(monkeypatch) + updater = _make_publish_only_updater(tmp_path, publish_hook=None) + + class FakeBucket: + has_updates = True + + flush_calls = [] + monkeypatch.setattr(updater, "_flush_bucket", lambda bucket, pbar: flush_calls.append((bucket, pbar))) + monkeypatch.setattr( + updater, + "_publish_batch", + lambda: pytest.fail("publish-only should publish once from _finalize_sync"), + ) + + bucket = FakeBucket() + updater._flush_and_publish(bucket, pbar=None) + + assert flush_calls == [(bucket, None)] + assert len(barrier_calls) == 1 + assert updater._pending_publishes == [] diff --git a/tests/test_delta_weight_update.py b/tests/test_delta_weight_update.py deleted file mode 100644 index 2c76858a30..0000000000 --- a/tests/test_delta_weight_update.py +++ /dev/null @@ -1,143 +0,0 @@ -"""E2E smoke test for disk-backed delta weight updates. - -Runs a tiny Qwen3.5-0.8B job so the first weight update seeds the delta -snapshot and the post-train update publishes sparse delta files through -``update_weights_from_disk(load_format="delta", files=...)``. -""" - -import os -import tempfile -from pathlib import Path - -import slime.utils.external_utils.command_utils as U - - -MODEL_NAME = "Qwen3.5-0.8B" -MODEL_TYPE = "qwen3.5-0.8B" -NUM_GPUS = 4 -TORCH_DIST_CKPT = f"/dev/shm/{MODEL_NAME}_torch_dist" - - -def prepare(): - U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") - U.hf_download_dataset("zhuzilin/gsm8k") - U.convert_checkpoint( - model_name=MODEL_NAME, - megatron_model_type=MODEL_TYPE, - num_gpus_per_node=NUM_GPUS, - dir_dst="/dev/shm", - ) - - -def execute(): - with tempfile.TemporaryDirectory(prefix="slime_delta_weight_update_") as delta_dir: - ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load {TORCH_DIST_CKPT} " - - rollout_args = ( - "--prompt-data /root/datasets/gsm8k/train.parquet " - "--input-key messages " - "--label-key label " - "--apply-chat-template " - "--rollout-shuffle " - "--rm-type math " - "--num-rollout 1 " - "--rollout-batch-size 4 " - "--n-samples-per-prompt 4 " - "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " - "--over-sampling-batch-size 8 " - "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " - "--global-batch-size 16 " - ) - - perf_args = ( - "--tensor-model-parallel-size 1 " - "--sequence-parallel " - "--pipeline-model-parallel-size 1 " - "--context-parallel-size 1 " - "--expert-model-parallel-size 1 " - "--expert-tensor-parallel-size 1 " - "--use-dynamic-batch-size " - "--max-tokens-per-gpu 9216 " - ) - - grpo_args = ( - "--advantage-estimator grpo " - "--use-kl-loss " - "--kl-loss-coef 0.00 " - "--kl-loss-type low_var_kl " - # Nonzero entropy coef guarantees a nonzero gradient even when all - # rewards in a group tie (advantages=0), so the delta sync writes - # real sparse files instead of an empty no-op. - "--entropy-coef 0.01 " - "--eps-clip 0.2 " - "--eps-clip-high 0.28 " - ) - - optimizer_args = ( - "--optimizer adam " - "--lr 1e-6 " - "--lr-decay-style constant " - "--weight-decay 0.1 " - "--adam-beta1 0.9 " - "--adam-beta2 0.98 " - ) - - sglang_args = ( - "--rollout-num-gpus-per-engine 1 " - "--rollout-num-gpus 3 " - "--sglang-mem-fraction-static 0.7 " - "--sglang-cuda-graph-max-bs 32 " - "--sglang-enable-metrics " - ) - - delta_args = ( - "--update-weight-mode delta " - "--update-weight-transport disk " - "--update-weight-encoding deltas " - f"--update-weight-disk-dir {delta_dir} " - "--update-weight-delta-keep-files " - ) - - ci_args = "--ci-test " - - misc_args = ( - "--attention-dropout 0.0 " - "--hidden-dropout 0.0 " - "--accumulate-allreduce-grads-in-fp32 " - "--attention-softmax-in-fp32 " - "--attention-backend flash " - "--loss-mask-type qwen3_5 " - "--actor-num-nodes 1 " - "--actor-num-gpus-per-node 1 " - ) - - train_args = ( - f"{ckpt_args} " - f"{rollout_args} " - f"{optimizer_args} " - f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " - f"{perf_args} " - f"{sglang_args} " - f"{delta_args} " - f"{ci_args} " - f"{misc_args} " - ) - - U.execute_train( - train_args=train_args, - num_gpus_per_node=NUM_GPUS, - megatron_model_type=MODEL_TYPE, - ) - - delta_files = list(Path(delta_dir).glob("weight_v*/*.safetensors")) - assert delta_files, f"No disk delta safetensors were written under {delta_dir}" - - -if __name__ == "__main__": - prepare() - for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): - os.environ.pop(proxy_var, None) - execute() diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index 9ec17d3a7a..5178a82143 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -224,6 +224,137 @@ def test_update_weight_disk_dir_rejects_conflicting_alias(monkeypatch): module._resolve_update_weight_disk_dir(args) +def make_slime_validate_args(**overrides): + values = dict( + eval_config=None, + eval_prompt_data=None, + kl_coef=0, + use_kl_loss=False, + ref_load=None, + use_opd=False, + opd_type=None, + opd_teacher_load=None, + megatron_to_hf_mode="raw", + load=None, + hf_checkpoint="/tmp/hf", + ref_ckpt_step=None, + ckpt_step=None, + no_load_optim=False, + no_load_rng=False, + finetune=False, + start_rollout_id=None, + eval_interval=None, + save_interval=None, + save=None, + kl_loss_coef=0, + advantage_estimator="grpo", + normalize_advantages=False, + use_rollout_logprobs=False, + use_tis=False, + get_mismatch_metrics=False, + custom_tis_function_path=None, + use_dynamic_batch_size=False, + max_tokens_per_gpu=None, + log_probs_max_tokens_per_gpu=None, + balance_by_flops=False, + balance_data=False, + eps_clip_high=None, + eps_clip=0.2, + eval_reward_key=None, + reward_key="reward", + dump_details=None, + save_debug_rollout_data=None, + save_debug_train_data=None, + load_debug_rollout_data=None, + rollout_external_engine_addrs=None, + rollout_http_endpoint_url=None, + rollout_http_endpoint_abort_strategy=None, + update_weight_delta_publish_only=False, + debug_train_only=False, + actor_num_gpus_per_node=8, + actor_num_nodes=1, + offload=False, + offload_train=None, + offload_rollout=None, + debug_rollout_only=False, + colocate=False, + rollout_num_gpus=8, + train_memory_margin_bytes=0, + eval_function_path=None, + rollout_function_path="custom.rollout", + num_steps_per_rollout=None, + rollout_batch_size=1, + n_samples_per_prompt=1, + global_batch_size=None, + grpo_std_normalization=True, + over_sampling_batch_size=None, + num_epoch=None, + num_rollout=1, + rollout_global_dataset=False, + enable_mtp_training=False, + mtp_num_layers=None, + use_rollout_routing_replay=False, + use_routing_replay=False, + custom_config_path=None, + eval_max_context_len=None, + rollout_max_context_len=None, + rollout_max_prompt_len=None, + qkv_format="thd", + train_backend="megatron", + only_train_params_name_list=None, + freeze_params_name_list=None, + update_weight_transport="nccl", + update_weight_disk_dir=None, + update_weight_delta_dir=None, + update_weight_mode="full", + ) + values.update(overrides) + return types.SimpleNamespace(**values) + + +@pytest.mark.unit +def test_slime_validate_args_preserves_zero_rollout_gpus_under_colocate(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(colocate=True, rollout_num_gpus=0) + + module.slime_validate_args(args) + + assert args.rollout_num_gpus == 0 + assert args.offload_train is True + assert args.offload_rollout is True + + +@pytest.mark.unit +def test_slime_validate_args_preserves_larger_rollout_gpus_under_colocate(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args( + colocate=True, + actor_num_gpus_per_node=8, + actor_num_nodes=1, + rollout_num_gpus=12, + ) + + module.slime_validate_args(args) + + assert args.rollout_num_gpus == 12 + assert args.offload_train is True + assert args.offload_rollout is True + + +@pytest.mark.unit +def test_slime_validate_args_preserves_zero_rollout_gpus_without_colocate(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args(colocate=False, rollout_num_gpus=0) + + module.slime_validate_args(args) + + assert args.rollout_num_gpus == 0 + assert args.actor_num_gpus_per_node == 8 + assert args.actor_num_nodes == 1 + assert args.offload_train is False + assert args.offload_rollout is False + + @pytest.mark.unit def test_update_weight_delta_rejects_colocate(monkeypatch): module = load_slime_arguments_module(monkeypatch) @@ -232,6 +363,7 @@ def test_update_weight_delta_rejects_colocate(monkeypatch): update_weight_transport="nccl", update_weight_disk_dir=None, update_weight_delta_dir=None, + update_weight_delta_publish_only=False, colocate=True, ) @@ -247,6 +379,7 @@ def test_update_weight_delta_rejects_unknown_transport(monkeypatch): update_weight_transport="tensor", update_weight_disk_dir=None, update_weight_delta_dir=None, + update_weight_delta_publish_only=False, colocate=False, ) diff --git a/tests/test_placement_group.py b/tests/test_placement_group.py index 8f918d4a74..ea5da55e7c 100644 --- a/tests/test_placement_group.py +++ b/tests/test_placement_group.py @@ -10,7 +10,6 @@ from slime.ray.placement_group import _create_placement_group, _get_placement_group_layout - NUM_GPUS = 0 @@ -34,9 +33,19 @@ def _args(**overrides): pytest.param({}, (48, 16), id="normal_non_colocate"), pytest.param({"debug_train_only": True}, (16, 0), id="debug_train_only"), pytest.param({"debug_rollout_only": True}, (32, 0), id="debug_rollout_only"), - pytest.param({"colocate": True}, (16, 0), id="colocate"), + pytest.param({"colocate": True, "rollout_num_gpus": 8}, (16, 0), id="colocate_rollout_less_than_actor"), + pytest.param({"colocate": True, "rollout_num_gpus": 16}, (16, 0), id="colocate_rollout_equals_actor"), + pytest.param({"colocate": True, "rollout_num_gpus": 32}, (32, 0), id="colocate_rollout_more_than_actor"), + pytest.param({"rollout_num_gpus": 0}, (16, 16), id="zero_rollout_gpus"), + pytest.param({"colocate": True, "rollout_num_gpus": 0}, (16, 0), id="colocate_zero_rollout_gpus"), pytest.param({"rollout_external": True}, (16, 16), id="external"), pytest.param({"rollout_external": True, "debug_rollout_only": True}, (0, 0), id="external_debug_rollout"), + pytest.param({"rollout_http_endpoint_url": "https://rollout.example"}, (16, 16), id="http_endpoint"), + pytest.param( + {"rollout_http_endpoint_url": "https://rollout.example", "debug_rollout_only": True}, + (0, 0), + id="http_endpoint_debug_rollout", + ), ], ) def test_placement_group_layout(overrides, expected): diff --git a/tests/test_qwen2.5_0.5B_ppo_critic_only_short.py b/tests/test_qwen2.5_0.5B_ppo_critic_only_short.py deleted file mode 100644 index 88bbc07735..0000000000 --- a/tests/test_qwen2.5_0.5B_ppo_critic_only_short.py +++ /dev/null @@ -1,129 +0,0 @@ -import os -import tempfile - -import slime.utils.external_utils.command_utils as U - -MODEL_NAME = "Qwen2.5-0.5B-Instruct" -MODEL_TYPE = "qwen2.5-0.5B" -NUM_GPUS = 4 - - -def prepare(): - U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") - U.hf_download_dataset("zhuzilin/dapo-math-17k") - - -def execute(): - megatron_config = tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) - megatron_config.write( - """ -megatron: - - name: default - role: critic - overrides: - lr: 1e-5 - - name: default - role: actor - overrides: - lr: 1e-6 -""" - ) - megatron_config.close() - - ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " - - rollout_args = ( - "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " - "--input-key prompt " - "--label-key label " - "--apply-chat-template " - "--rollout-shuffle " - "--rm-type deepscaler " - "--num-rollout 2 " - "--rollout-batch-size 4 " - "--n-samples-per-prompt 4 " - "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " - "--global-batch-size 16 " - "--balance-data " - ) - - perf_args = ( - "--tensor-model-parallel-size 1 " - "--sequence-parallel " - "--pipeline-model-parallel-size 1 " - "--context-parallel-size 1 " - "--expert-model-parallel-size 1 " - "--expert-tensor-parallel-size 1 " - "--use-dynamic-batch-size " - "--max-tokens-per-gpu 9216 " - ) - - ppo_args = ( - "--advantage-estimator ppo " - "--kl-loss-coef 0.00 " - "--kl-loss-type k1 " - "--kl-coef 0.00 " - "--entropy-coef 0.00 " - "--eps-clip 4e-4 " - "--num-critic-only-steps 2 " - "--normalize-advantages " - ) - - optimizer_args = ( - "--optimizer adam " - "--lr 1e-6 " - "--lr-decay-style constant " - "--weight-decay 0.1 " - "--adam-beta1 0.9 " - "--adam-beta2 0.98 " - ) - - sglang_args = ( - "--rollout-num-gpus-per-engine 1 " - "--rollout-num-gpus 2 " - "--sglang-mem-fraction-static 0.7 " - "--sglang-cuda-graph-max-bs 16 " - "--sglang-enable-metrics " - ) - - ci_args = "--ci-test " - - misc_args = ( - "--attention-dropout 0.0 " - "--hidden-dropout 0.0 " - "--accumulate-allreduce-grads-in-fp32 " - "--attention-softmax-in-fp32 " - "--attention-backend flash " - "--actor-num-nodes 1 " - "--actor-num-gpus-per-node 4 " - "--megatron-to-hf-mode bridge " - "--colocate " - ) - - train_args = ( - f"--megatron-config-path {megatron_config.name} " - f"{ckpt_args} " - f"{rollout_args} " - f"{optimizer_args} " - f"{ppo_args} " - f"{U.get_default_wandb_args(__file__)} " - f"{perf_args} " - f"{sglang_args} " - f"{ci_args} " - f"{misc_args} " - ) - - U.execute_train( - train_args=train_args, - num_gpus_per_node=NUM_GPUS, - megatron_model_type=MODEL_TYPE, - ) - - -if __name__ == "__main__": - prepare() - for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): - os.environ.pop(proxy_var, None) - execute() diff --git a/tests/test_qwen3.5_0.8B_gsm8k_short.py b/tests/test_qwen3.5_0.8B_gsm8k_short.py index 0b493919a5..8bf2c46d0c 100644 --- a/tests/test_qwen3.5_0.8B_gsm8k_short.py +++ b/tests/test_qwen3.5_0.8B_gsm8k_short.py @@ -36,6 +36,7 @@ def execute(): "--n-samples-per-prompt 4 " "--rollout-max-response-len 1024 " "--rollout-temperature 0.8 " + "--rollout-data-transport nixl " "--over-sampling-batch-size 8 " "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 16 " diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index 5f836e28a6..bb8509f690 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -93,7 +93,7 @@ def execute(): sglang_args = ( "--rollout-num-gpus-per-engine 8 " "--sglang-mem-fraction-static 0.8 " - "--sglang-cuda-graph-max-bs 16 " + "--sglang-cuda-graph-max-bs 32 " "--sglang-max-running-requests 512 " "--sglang-enable-metrics " ) diff --git a/tests/test_qwen3_30B_A3B_r3.py b/tests/test_qwen3_30B_A3B_r3.py index b1d1e42da2..a0e3cc6da1 100644 --- a/tests/test_qwen3_30B_A3B_r3.py +++ b/tests/test_qwen3_30B_A3B_r3.py @@ -40,6 +40,7 @@ def execute(): "--n-samples-per-prompt 4 " "--rollout-max-response-len 8192 " "--rollout-temperature 1 " + "--rollout-data-transport nixl " "--global-batch-size 16 " "--balance-data " ) @@ -93,7 +94,7 @@ def execute(): sglang_args = ( "--rollout-num-gpus-per-engine 8 " "--sglang-mem-fraction-static 0.8 " - "--sglang-cuda-graph-max-bs 16 " + "--sglang-cuda-graph-max-bs 32 " "--sglang-max-running-requests 512 " "--sglang-enable-metrics " ) diff --git a/tests/test_qwen3_4B_ppo_disaggregate.py b/tests/test_qwen3_4B_ppo_disaggregate.py index d29ebd010f..bafb6260ff 100644 --- a/tests/test_qwen3_4B_ppo_disaggregate.py +++ b/tests/test_qwen3_4B_ppo_disaggregate.py @@ -51,6 +51,7 @@ def execute(): "--n-samples-per-prompt 4 " "--rollout-max-response-len 8192 " "--rollout-temperature 0.8 " + "--rollout-data-transport nixl " "--global-batch-size 16 " "--balance-data " ) diff --git a/tests/test_rollout_http_endpoint.py b/tests/test_rollout_http_endpoint.py new file mode 100644 index 0000000000..187cd7887d --- /dev/null +++ b/tests/test_rollout_http_endpoint.py @@ -0,0 +1,326 @@ +import asyncio +import sys +from argparse import Namespace +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +try: + import ray # noqa: F401 +except ImportError: + pass + +try: + from tests.plugin_contracts._shared import install_stubs +except ImportError: + from plugin_contracts._shared import install_stubs + +install_stubs(with_sglang_router=True, with_transformers=True) + +from slime.backends.sglang_utils.http_endpoint import ( # noqa: E402 + normalize_rollout_http_endpoint_url, + start_http_endpoint_rollout_servers, +) +from slime.rollout import sglang_rollout # noqa: E402 +from slime.rollout.sglang_rollout import abort, generate, get_model_url # noqa: E402 +from slime.utils.types import Sample # noqa: E402 + +NUM_GPUS = 0 + + +def _args(**overrides): + values = { + "ci_test": False, + "rollout_http_endpoint_url": None, + "rollout_http_endpoint_abort_strategy": "router-workers", + "sglang_router_ip": "10.0.0.1", + "sglang_router_port": 30000, + "sglang_model_routers": None, + "router_policy": None, + "use_rollout_routing_replay": False, + "partial_rollout": False, + "mask_offpolicy_in_partial_rollout": False, + "sglang_speculative_algorithm": None, + "custom_rollout_request_hook_path": None, + } + values.update(overrides) + return Namespace(**values) + + +class _Tokenizer: + def encode(self, prompt, add_special_tokens=False): + assert add_special_tokens is False + return [101, len(prompt)] + + +class _GenerateState: + def __init__(self, args): + self.args = args + self.tokenizer = _Tokenizer() + self.processor = None + self.pendings = set() + self.aborted = False + + +def test_normalize_rollout_http_endpoint_url_requires_absolute_http_url(): + assert normalize_rollout_http_endpoint_url("https://rollout.example/") == "https://rollout.example" + with pytest.raises(ValueError, match="absolute http"): + normalize_rollout_http_endpoint_url("rollout.example") + + +def test_get_model_url_prefers_http_endpoint(): + args = _args( + rollout_http_endpoint_url="https://rollout.example/base/", + sglang_model_routers={"default": ("10.0.0.2", 30001)}, + ) + + assert get_model_url(args, "default", "/generate") == "https://rollout.example/base/generate" + assert get_model_url(args, "reward", "score") == "https://rollout.example/base/score" + + +def test_get_model_url_uses_model_router_without_http_endpoint(): + args = _args(sglang_model_routers={"reward": ("10.0.0.3", 30002)}) + + assert get_model_url(args, "reward", "/generate") == "http://10.0.0.3:30002/generate" + assert get_model_url(args, "missing", "/generate") == "http://10.0.0.1:30000/generate" + + +def test_generate_posts_to_http_endpoint(monkeypatch): + captured = {} + + async def fake_post(url, payload, headers=None, **_kwargs): + captured["url"] = url + captured["payload"] = payload + captured["headers"] = headers + return { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 1, + }, + } + + monkeypatch.setattr(sglang_rollout, "GenerateState", _GenerateState) + monkeypatch.setattr(sglang_rollout, "post", fake_post) + + args = _args(rollout_http_endpoint_url="https://rollout.example") + sample = asyncio.run(generate(args, Sample(index=0, prompt="hi"), {"max_new_tokens": 8})) + + assert captured["url"] == "https://rollout.example/generate" + assert captured["payload"]["input_ids"] == [101, 2] + assert captured["payload"]["return_logprob"] is True + assert sample.response == " answer" + assert sample.tokens == [101, 2, 42] + assert sample.status == Sample.Status.COMPLETED + + +def test_generate_request_hook_can_add_exact_weight_version(monkeypatch): + captured = {} + + async def fake_post(url, payload, headers=None, max_retries=60, retry_sleep=1.0): + captured["url"] = url + captured["payload"] = payload + captured["max_retries"] = max_retries + captured["retry_sleep"] = retry_sleep + return { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 1, + }, + } + + monkeypatch.setattr(sglang_rollout, "GenerateState", _GenerateState) + monkeypatch.setattr(sglang_rollout, "post", fake_post) + + def hook(args, sample, request): + assert args.rollout_http_endpoint_url == "https://rollout.example" + assert sample.index == 0 + assert request["rollout_id"] == 9 + assert request["evaluation"] is False + request["payload"]["weight_version"] = {"exact_version": request["rollout_id"]} + request["max_retries"] = 123 + request["retry_sleep"] = 0.25 + + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: hook) + + args = _args( + rollout_http_endpoint_url="https://rollout.example", + custom_rollout_request_hook_path="example.hook", + ) + with sglang_rollout.rollout_request_context(args, rollout_id=9): + sample = asyncio.run(generate(args, Sample(index=0, prompt="hi"), {"max_new_tokens": 8})) + + assert captured["url"] == "https://rollout.example/generate" + assert captured["payload"]["weight_version"] == {"exact_version": 9} + assert captured["max_retries"] == 123 + assert captured["retry_sleep"] == 0.25 + assert sample.status == Sample.Status.COMPLETED + + +def test_generate_request_hook_can_return_request_updates(monkeypatch): + captured = {} + + async def fake_post(url, payload, headers=None, max_retries=60, retry_sleep=1.0): + captured["url"] = url + captured["payload"] = payload + captured["max_retries"] = max_retries + captured["retry_sleep"] = retry_sleep + return { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 1, + }, + } + + monkeypatch.setattr(sglang_rollout, "GenerateState", _GenerateState) + monkeypatch.setattr(sglang_rollout, "post", fake_post) + + async def hook(_args, _sample, request): + payload = dict(request["payload"]) + payload["weight_version"] = {"min_required_version": request["rollout_id"]} + return { + "payload": payload, + "max_retries": 123, + "retry_sleep": 0.25, + } + + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: hook) + + args = _args( + rollout_http_endpoint_url="https://rollout.example", + custom_rollout_request_hook_path="example.hook", + ) + with sglang_rollout.rollout_request_context(args, rollout_id=9): + sample = asyncio.run(generate(args, Sample(index=0, prompt="hi"), {"max_new_tokens": 8})) + + assert captured["url"] == "https://rollout.example/generate" + assert captured["payload"]["weight_version"] == {"min_required_version": 9} + assert captured["max_retries"] == 123 + assert captured["retry_sleep"] == 0.25 + assert sample.status == Sample.Status.COMPLETED + + +def test_generate_retries_until_exact_weight_version_is_available(monkeypatch): + aiohttp_web = pytest.importorskip("aiohttp.web") + httpx = pytest.importorskip("httpx") + + async def run(): + from slime.utils import http_utils + + attempts = [] + + async def handle_generate(request): + payload = await request.json() + attempts.append(payload) + assert payload["weight_version"] == {"exact_version": 11} + if len(attempts) == 1: + raise aiohttp_web.HTTPNotFound(text="weight version not loaded") + if len(attempts) == 2: + raise aiohttp_web.HTTPConflict(text="weight version still loading") + return aiohttp_web.json_response( + { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 0, + }, + } + ) + + app = aiohttp_web.Application() + app.router.add_post("/generate", handle_generate) + runner = aiohttp_web.AppRunner(app) + await runner.setup() + site = aiohttp_web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] + + old_client = http_utils._http_client + old_distributed = http_utils._distributed_post_enabled + old_post_actors = http_utils._post_actors + client = httpx.AsyncClient(timeout=httpx.Timeout(None), trust_env=False) + http_utils._http_client = client + http_utils._distributed_post_enabled = False + http_utils._post_actors = [] + try: + monkeypatch.setattr(sglang_rollout, "GenerateState", _GenerateState) + + def hook(_args, _sample, request): + request["payload"]["weight_version"] = {"exact_version": request["rollout_id"]} + request["max_retries"] = 5 + request["retry_sleep"] = 0.01 + + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: hook) + + args = _args( + rollout_http_endpoint_url=f"http://127.0.0.1:{port}", + custom_rollout_request_hook_path="example.hook", + ) + with sglang_rollout.rollout_request_context(args, rollout_id=11): + sample = await generate(args, Sample(index=0, prompt="hi"), {"max_new_tokens": 8}) + finally: + await client.aclose() + http_utils._http_client = old_client + http_utils._distributed_post_enabled = old_distributed + http_utils._post_actors = old_post_actors + await runner.cleanup() + + assert len(attempts) == 3 + assert sample.status == Sample.Status.COMPLETED + assert sample.tokens == [101, 2, 42] + + asyncio.run(run()) + + +def test_cancel_only_abort_does_not_query_router_workers(monkeypatch): + async def run(): + async def never_finishes(): + await asyncio.sleep(60) + + task = asyncio.create_task(never_finishes()) + state = _GenerateState(_args()) + state.pendings.add(task) + + def fake_state(_args): + return state + + async def fail_get(_url): + raise AssertionError("cancel-only abort must not query router workers") + + monkeypatch.setattr(sglang_rollout, "GenerateState", fake_state) + monkeypatch.setattr(sglang_rollout, "get", fail_get) + + result = await abort(_args(rollout_http_endpoint_abort_strategy="cancel-only"), rollout_id=7) + + assert result == [] + assert state.pendings == set() + assert task.cancelled() + + asyncio.run(run()) + + +def test_start_http_endpoint_rollout_servers_returns_no_engine_server(): + args = _args(rollout_http_endpoint_url="https://rollout.example/", rollout_num_engines=None) + + servers = start_http_endpoint_rollout_servers(args) + + server = servers["default"] + assert args.rollout_http_endpoint_url == "https://rollout.example" + assert args.rollout_num_engines == 1 + assert server.engines == [] + assert server.server_groups == [] + assert server.router_ip is None diff --git a/tests/utils/test_sglang_config.py b/tests/utils/test_sglang_config.py index 5015a74558..543c84d98e 100644 --- a/tests/utils/test_sglang_config.py +++ b/tests/utils/test_sglang_config.py @@ -1,10 +1,17 @@ """Unit tests for SglangConfig multi-model parsing with update_weights.""" +import sys import tempfile +from argparse import Namespace +from pathlib import Path import pytest import yaml +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + def _write_yaml(data: dict) -> str: f = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) @@ -29,6 +36,7 @@ def test_update_weights_default_true(self): } ) config = SglangConfig.from_yaml(path) + config.models[0].resolve(Namespace(hf_checkpoint="/tmp/hf", rollout_num_gpus_per_engine=1)) assert len(config.models) == 1 assert config.models[0].update_weights is True @@ -83,6 +91,199 @@ def test_multi_model_total_gpus(self): config = SglangConfig.from_yaml(path) assert config.total_num_gpus == 12 + def test_config_allows_model_with_no_server_groups(self): + """A model with no server groups can expose a router without local engines.""" + from slime.backends.sglang_utils.sglang_config import SglangConfig + + path = _write_yaml({"sglang": [{"name": "default", "server_groups": []}]}) + + config = SglangConfig.from_yaml(path) + + assert len(config.models) == 1 + assert config.models[0].name == "default" + assert config.models[0].server_groups == [] + assert config.total_num_gpus == 0 + + +class TestZeroGpuRolloutConfig: + def test_resolve_default_zero_gpu_config_has_no_server_groups(self): + from slime.ray.rollout import _resolve_sglang_config + + args = Namespace(sglang_config=None, prefill_num_servers=None, rollout_num_gpus=0) + + config = _resolve_sglang_config(args) + + assert len(config.models) == 1 + assert config.models[0].name == "default" + assert config.models[0].server_groups == [] + assert config.total_num_gpus == 0 + + def test_zero_gpu_config_takes_precedence_over_prefill_num_servers(self): + from slime.ray.rollout import _resolve_sglang_config + + args = Namespace(sglang_config=None, prefill_num_servers=1, rollout_num_gpus=0) + + config = _resolve_sglang_config(args) + + assert config.models[0].server_groups == [] + assert config.total_num_gpus == 0 + + def test_start_rollout_servers_zero_gpu_starts_router_without_engines(self, monkeypatch): + from slime.ray import rollout as rollout_module + + def fake_start_router(args, *, has_pd_disaggregation=False, force_new=False): + assert has_pd_disaggregation is False + assert force_new is False + return "127.0.0.1", 3456 + + monkeypatch.setattr(rollout_module, "_start_router", fake_start_router) + args = Namespace( + rollout_external=False, + sglang_config=None, + prefill_num_servers=None, + rollout_num_gpus=0, + rollout_num_gpus_per_engine=1, + num_gpus_per_node=8, + debug_train_only=False, + debug_rollout_only=False, + colocate=False, + actor_num_nodes=1, + actor_num_gpus_per_node=8, + offload_rollout=False, + hf_checkpoint="/tmp/hf", + ) + + servers, init_handles = rollout_module.start_rollout_servers(args, pg=(None, [], [])) + + assert list(servers) == ["default"] + assert init_handles == [] + server = servers["default"] + assert server.router_ip == "127.0.0.1" + assert server.router_port == 3456 + assert server.server_groups == [] + assert server.engines == [] + assert args.sglang_router_ip == "127.0.0.1" + assert args.sglang_router_port == 3456 + assert args.sglang_model_routers == {"default": ("127.0.0.1", 3456)} + + def test_start_rollout_servers_defers_engine_wait(self, monkeypatch): + from slime.ray import rollout as rollout_module + + def fake_start_router(args, *, has_pd_disaggregation=False, force_new=False): + assert has_pd_disaggregation is False + assert force_new is False + return "127.0.0.1", 3456 + + def fake_start_engines(self, port_cursors=None): + self.all_engines = [object() for _ in self.all_engines] + return [f"init-{self.rank_offset}"], port_cursors or {} + + ray_get_calls = [] + + def fake_ray_get(refs): + ray_get_calls.append(refs) + + monkeypatch.setattr(rollout_module, "_start_router", fake_start_router) + monkeypatch.setattr(rollout_module.ServerGroup, "start_engines", fake_start_engines) + monkeypatch.setattr(rollout_module.ray, "get", fake_ray_get) + + args = Namespace( + rollout_external=False, + sglang_config=None, + prefill_num_servers=None, + rollout_num_gpus=2, + rollout_num_gpus_per_engine=1, + num_gpus_per_node=8, + debug_train_only=False, + debug_rollout_only=False, + colocate=False, + actor_num_nodes=1, + actor_num_gpus_per_node=8, + offload_rollout=False, + hf_checkpoint="/tmp/hf", + ) + + servers, init_handles = rollout_module.start_rollout_servers(args, pg=(None, [], [])) + + assert list(servers) == ["default"] + assert init_handles == ["init-0"] + assert ray_get_calls == [] + + def test_start_rollout_servers_waits_for_epd_encoder_before_non_encoder(self, monkeypatch): + from slime.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig + from slime.ray import rollout as rollout_module + + class FakeRemoteMethod: + def __init__(self, value): + self.value = value + + def remote(self): + return self.value + + class FakeEngine: + def __init__(self, url_ref): + self.get_url = FakeRemoteMethod(url_ref) + + def fake_start_router(args, *, has_pd_disaggregation=False, force_new=False): + assert has_pd_disaggregation is False + assert force_new is False + return "127.0.0.1", 3456 + + def fake_resolve_sglang_config(args): + return SglangConfig( + models=[ + ModelConfig( + name="default", + server_groups=[ + ServerGroupConfig(worker_type="encoder", num_gpus=1), + ServerGroupConfig(worker_type="regular", num_gpus=1), + ], + ) + ] + ) + + def fake_start_engines(self, port_cursors=None): + if self.worker_type == "encoder": + self.all_engines = [FakeEngine("encoder-url-ref") for _ in self.all_engines] + else: + self.all_engines = [object() for _ in self.all_engines] + return [f"{self.worker_type}-init-{self.rank_offset}"], port_cursors or {} + + ray_get_calls = [] + + def fake_ray_get(refs): + ray_get_calls.append(refs) + if refs == ["encoder-url-ref"]: + return ["http://encoder"] + return None + + monkeypatch.setattr(rollout_module, "_start_router", fake_start_router) + monkeypatch.setattr(rollout_module, "_resolve_sglang_config", fake_resolve_sglang_config) + monkeypatch.setattr(rollout_module.ServerGroup, "start_engines", fake_start_engines) + monkeypatch.setattr(rollout_module.ray, "get", fake_ray_get) + + args = Namespace( + rollout_external=False, + rollout_num_gpus_per_engine=1, + num_gpus_per_node=8, + debug_train_only=False, + debug_rollout_only=False, + colocate=False, + actor_num_nodes=1, + actor_num_gpus_per_node=8, + offload_rollout=False, + hf_checkpoint="/tmp/hf", + ) + + servers, init_handles = rollout_module.start_rollout_servers(args, pg=(None, [], [])) + + groups = servers["default"].server_groups + assert [group.worker_type for group in groups] == ["encoder", "regular"] + assert groups[1].sglang_overrides["language_only"] is True + assert groups[1].sglang_overrides["encoder_urls"] == ["http://encoder"] + assert init_handles == ["regular-init-1"] + assert ray_get_calls == [["encoder-init-0"], ["encoder-url-ref"]] + class TestGetModelUrl: def test_get_model_url_basic(self):