From b339d55723ebdecbbbf71b9b19f9d892b5ba6324 Mon Sep 17 00:00:00 2001 From: Jason Mancuso <7891333+jvmncs@users.noreply.github.com> Date: Fri, 22 May 2026 15:27:22 -0400 Subject: [PATCH 1/2] add modal example for delta compression transport over Volume to flash rollout server --- examples/delta_weight_sync/README.md | 25 + examples/delta_weight_sync/__init__.py | 1 + .../delta_weight_sync/modal_delta_sync.py | 791 ++++++++++++++++++ slime/backends/megatron_utils/sglang.py | 28 +- slime/backends/sglang_utils/arguments.py | 22 + .../sglang_utils/modal_volume_hooks.py | 28 + slime/backends/sglang_utils/sglang_engine.py | 46 +- slime/ray/placement_group.py | 3 + slime/ray/rollout.py | 70 +- slime/rollout/rm_hub/__init__.py | 6 + slime/rollout/sglang_rollout.py | 27 +- slime/utils/arguments.py | 5 +- slime/utils/url_utils.py | 88 ++ tests/utils/test_modal_delta_sync.py | 229 +++++ tests/utils/test_sglang_config.py | 50 +- uv.lock | 3 + 16 files changed, 1354 insertions(+), 68 deletions(-) create mode 100644 examples/delta_weight_sync/__init__.py create mode 100644 examples/delta_weight_sync/modal_delta_sync.py create mode 100644 slime/backends/sglang_utils/modal_volume_hooks.py create mode 100644 slime/utils/url_utils.py create mode 100644 tests/utils/test_modal_delta_sync.py create mode 100644 uv.lock diff --git a/examples/delta_weight_sync/README.md b/examples/delta_weight_sync/README.md index 3782b20bb4..109148868d 100644 --- a/examples/delta_weight_sync/README.md +++ b/examples/delta_weight_sync/README.md @@ -10,6 +10,7 @@ Both modes are lossless by construction (selective overwrite via NaN sentinel; n ## Files - `run-glm4.7-355B-A32B-delta.sh`: 16-node (8 actor + 8 rollout) GLM-4.7-355B-A32B launcher. Disk transport active by default; NCCL block commented below it. +- `modal_delta_sync.py`: single-file Modal deployment for Qwen3-4B. 
It runs an autoinference-style SGLang rollout container behind `@modal.experimental.http_server`, applies the local SGLang delta patch, mounts the delta Volume, and includes the H100 slime trainer function plus local entrypoints. ## Usage @@ -28,6 +29,30 @@ DELTA_ARGS=( ) ``` +**Modal Volume + `http_server`:** + +Deploy the Qwen3-4B rollout server with one warm container: + +```bash +MIN_CONTAINERS=1 uv run modal deploy examples/delta_weight_sync/modal_delta_sync.py +``` + +Launch a two-step end-to-end training smoke test against the deployed app: + +```bash +uv run --with requests modal run examples/delta_weight_sync/modal_delta_sync.py::launch_run --num-rollout 2 +``` + +For custom trainer invocations, use the deployed URL as both the generation router and the external engine admin endpoint: + +```bash +--rollout-external +--sglang-router-url https://your-rollout-url.modal.run +--rollout-external-engine-addrs https://your-rollout-url.modal.run +--update-weight-delta-dir /delta +--custom-delta-pre-push-path slime.backends.sglang_utils.modal_volume_hooks.commit_modal_delta_volume +``` + **NCCL (baseline):** ```bash diff --git a/examples/delta_weight_sync/__init__.py b/examples/delta_weight_sync/__init__.py new file mode 100644 index 0000000000..fea2c10e09 --- /dev/null +++ b/examples/delta_weight_sync/__init__.py @@ -0,0 +1 @@ +"""Delta weight sync examples.""" diff --git a/examples/delta_weight_sync/modal_delta_sync.py b/examples/delta_weight_sync/modal_delta_sync.py new file mode 100644 index 0000000000..a49a725c30 --- /dev/null +++ b/examples/delta_weight_sync/modal_delta_sync.py @@ -0,0 +1,791 @@ +"""Modal Qwen3-4B rollout + slime trainer with disk delta weight sync. 
+ +Deploy: + MIN_CONTAINERS=1 uv run modal deploy examples/delta_weight_sync/modal_delta_sync.py + +Launch a deployed end-to-end run: + uv run --with requests modal run examples/delta_weight_sync/modal_delta_sync.py::launch_run --num-rollout 2 +""" + +from __future__ import annotations + +import json +import os +import re +import shlex +import subprocess +import threading +import time +import asyncio +import glob +from collections import deque +from pathlib import Path +from typing import Any + +import modal +import modal.experimental + + +def _local_repo_root() -> Path: + file_path = Path(__file__).resolve() + candidates = [file_path.parent, *file_path.parents] + for candidate in candidates: + if (candidate / "docker/patch/latest/sglang.patch").exists(): + return candidate + raise RuntimeError(f"Could not locate slime repo root from {file_path}") + + +IS_LOCAL = modal.is_local() +REPO_ROOT = _local_repo_root() if IS_LOCAL else Path("/root/slime") + +MINUTES = 60 +HOURS = 60 * MINUTES + +APP_NAME = os.environ.get("SLIME_MODAL_APP_NAME", "slime-qwen3-4b-delta-sync") +MODEL_NAME = os.environ.get("SLIME_MODAL_MODEL_NAME", "Qwen/Qwen3-4B") +MODEL_REVISION = os.environ.get("SLIME_MODAL_MODEL_REVISION", "main") + +SLIME_COMMIT = "0a664bc5eb776a785b4e035ddd57866f921d0cdc" +SGLANG_IMAGE_TAG = "v0.5.10.post1" +MEGATRON_COMMIT = "1dcf0dafa884ad52ffb243625717a3471643e087" + +ROLLOUT_BASE_IMAGE = os.environ.get("SLIME_MODAL_ROLLOUT_BASE_IMAGE", f"slimerl/sglang:{SGLANG_IMAGE_TAG}") +TRAINER_BASE_IMAGE = os.environ.get("SLIME_MODAL_TRAINER_BASE_IMAGE", "slimerl/slime-test:nightly-dev-20260429b") +AUTOINFERENCE_UTILS_VERSION = os.environ.get("AUTOINFERENCE_UTILS_VERSION", "0.2.0") + +HF_CACHE_PATH = "/root/.cache/huggingface" +HF_CACHE_VOLUME_NAME = os.environ.get("HF_CACHE_VOLUME_NAME", "huggingface-cache") +DELTA_MOUNT_PATH = os.environ.get("SLIME_DELTA_MOUNT_PATH", "/delta") +DELTA_VOLUME_NAME = os.environ.get("SLIME_DELTA_VOLUME_NAME", "slime-qwen3-4b-deltas") + 
+SGLANG_INTERNAL_PORT = int(os.environ.get("SLIME_SGLANG_INTERNAL_PORT", "8001")) +SIDECAR_PORT = int(os.environ.get("SLIME_MODAL_SIDECAR_PORT", "8000")) +PROXY_REGIONS = os.environ.get("PROXY_REGIONS", "us-west").split(",") +REGION = os.environ.get("SLIME_MODAL_REGION", "us") +STARTUP_TIMEOUT = int(os.environ.get("SLIME_MODAL_STARTUP_TIMEOUT", str(45 * MINUTES))) +SCALEDOWN_WINDOW = int(os.environ.get("SLIME_MODAL_SCALEDOWN_WINDOW", str(15 * MINUTES))) +MIN_CONTAINERS = int(os.environ.get("MIN_CONTAINERS", os.environ.get("SLIME_MODAL_ROLLOUT_MIN_CONTAINERS", "0"))) +MAX_CONTAINERS = 1 +TARGET_INPUTS = int(os.environ.get("SLIME_MODAL_TARGET_INPUTS", "16")) + +GPU_TYPE = os.environ.get("SLIME_MODAL_GPU_TYPE", "H100") +GPU = f"{GPU_TYPE}:1" +MEMORY_MB = int(os.environ.get("SLIME_MODAL_MEMORY_MB", "131072")) + +SGLANG_CONTEXT_LENGTH = int(os.environ.get("SLIME_SGLANG_CONTEXT_LENGTH", "2048")) +SGLANG_MEM_FRACTION_STATIC = os.environ.get("SLIME_SGLANG_MEM_FRACTION_STATIC", "0.50") + +HF_IMAGE_ENV = { + "HF_HOME": HF_CACHE_PATH, + "HF_XET_HIGH_PERFORMANCE": "1", + "HF_HUB_ENABLE_HF_TRANSFER": "1", +} +RUNTIME_ENV = { + "PYTHONUNBUFFERED": "1", + "PYTHONPATH": "/root:/root/Megatron-LM:/root/slime", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "SLIME_DELTA_VOLUME_NAME": DELTA_VOLUME_NAME, +} + +APPLY_SGLANG_PATCH = ( + "set -eux; " + 'SGLANG_DIR="${SGLANG_DIR:-/sgl-workspace/sglang}"; ' + '[ -d "$SGLANG_DIR" ] || SGLANG_DIR="/root/src/sglang"; ' + 'test -d "$SGLANG_DIR"; ' + 'cd "$SGLANG_DIR"; ' + "git update-index --refresh || true; " + "if git apply --check /tmp/sglang.patch; then " + "git apply --3way /tmp/sglang.patch; " + "elif git apply --reverse --check /tmp/sglang.patch; then " + 'echo "sglang.patch is already applied"; ' + "else " + "git apply --3way /tmp/sglang.patch; " + "fi; " + "if grep -R -n '^<<<<<<< ' python/sglang; then " + 'echo "SGLang patch conflict markers remain after sglang.patch"; ' + "exit 1; " + "fi" +) + +app = modal.App(name=APP_NAME) 
+hf_cache_volume = modal.Volume.from_name(HF_CACHE_VOLUME_NAME, create_if_missing=True) +delta_volume = modal.Volume.from_name(DELTA_VOLUME_NAME, create_if_missing=True) + + +def _join_url(base_url: str, path: str) -> str: + return f"{base_url.rstrip('/')}/{path.lstrip('/')}" + + +def _normalize_base_url(base_url: str) -> str: + if base_url.startswith(("http://", "https://")): + return base_url.rstrip("/") + return f"http://{base_url.rstrip('/')}" + + +_DELTA_VERSION_RE = re.compile(r"^weight_v\d{6}$") +_HOP_BY_HOP_HEADERS = { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", + "content-length", +} + + +def validate_delta_update_payload(payload: dict[str, Any], *, delta_mount_path: str = "/delta") -> str: + if payload.get("load_format") != "delta": + raise ValueError("update_weights_from_disk sidecar only accepts load_format='delta'") + + model_path = payload.get("model_path") + if not isinstance(model_path, str): + raise ValueError("update_weights_from_disk payload requires string model_path") + + mount = os.path.realpath(delta_mount_path) + real_model_path = os.path.realpath(model_path) + if os.path.commonpath([mount, real_model_path]) != mount: + raise ValueError(f"model_path must be under {delta_mount_path}") + + if os.path.dirname(real_model_path) != mount or not _DELTA_VERSION_RE.fullmatch(os.path.basename(real_model_path)): + raise ValueError(f"model_path must match {delta_mount_path}/weight_vNNNNNN") + + return real_model_path + + +def verify_delta_dir_ready(model_path: str) -> None: + done_path = os.path.join(model_path, "DONE") + if not os.path.isfile(done_path): + raise FileNotFoundError(f"missing delta DONE marker: {done_path}") + if not glob.glob(os.path.join(model_path, "*.safetensors")): + raise FileNotFoundError(f"missing delta safetensors files under: {model_path}") + + +def _delta_dir_summary(model_path: str) -> tuple[int, int]: + safetensors_paths = 
glob.glob(os.path.join(model_path, "*.safetensors")) + total_bytes = sum(os.path.getsize(path) for path in safetensors_paths if os.path.isfile(path)) + return len(safetensors_paths), total_bytes + + +async def _reload_volume(delta_volume_obj: Any) -> None: + if delta_volume_obj is None: + return + reload_fn = getattr(delta_volume_obj, "reload", None) + if reload_fn is None: + raise TypeError("delta_volume must expose a reload() method") + if asyncio.iscoroutinefunction(reload_fn): + await reload_fn() + else: + await asyncio.to_thread(reload_fn) + + +def _forward_response_headers(headers: Any) -> dict[str, str]: + return {k: v for k, v in headers.items() if k.lower() not in _HOP_BY_HOP_HEADERS} + + +async def _proxy_request( + request: Any, + *, + target_base_url: str, + delta_volume_obj: Any, + delta_mount_path: str, + update_lock: Any, +) -> Any: + import aiohttp + from aiohttp import web + + endpoint = request.match_info["tail"] + endpoint = f"/{endpoint}" if endpoint else "/" + target_url = _join_url(target_base_url, endpoint) + if request.query_string: + target_url = f"{target_url}?{request.query_string}" + + excluded_headers = _HOP_BY_HOP_HEADERS | {"host"} + headers = {k: v for k, v in request.headers.items() if k.lower() not in excluded_headers} + update_version = None + update_started_at = None + + if endpoint == "/update_weights_from_disk": + if request.method != "POST": + return web.json_response({"error": "method not allowed"}, status=405) + async with update_lock: + try: + payload = await request.json() + model_path = validate_delta_update_payload(payload, delta_mount_path=delta_mount_path) + update_version = os.path.basename(model_path) + update_started_at = time.time() + print( + f"delta sidecar update start: version={update_version} model_path={model_path}", + flush=True, + ) + await _reload_volume(delta_volume_obj) + verify_delta_dir_ready(model_path) + file_count, total_bytes = _delta_dir_summary(model_path) + print( + "delta sidecar update ready: 
" + f"version={update_version} safetensors={file_count} bytes={total_bytes} " + f"reload_verify_s={time.time() - update_started_at:.2f}", + flush=True, + ) + except json.JSONDecodeError as exc: + print(f"delta sidecar update rejected: invalid JSON payload: {exc}", flush=True) + return web.json_response({"error": f"invalid JSON payload: {exc}"}, status=400) + except Exception as exc: # noqa: BLE001 - sidecar validation errors should become HTTP errors + print(f"delta sidecar update rejected: {exc}", flush=True) + return web.json_response({"error": str(exc)}, status=400) + body = json.dumps(payload).encode("utf-8") + headers["content-type"] = "application/json" + else: + body = await request.read() + + try: + async with aiohttp.ClientSession() as session: + async with session.request(request.method, target_url, data=body, headers=headers) as response: + content = await response.read() + if update_version is not None: + elapsed_s = time.time() - update_started_at if update_started_at is not None else -1.0 + print( + "delta sidecar update forwarded: " + f"version={update_version} upstream_status={response.status} total_s={elapsed_s:.2f}", + flush=True, + ) + return web.Response( + body=content, + status=response.status, + headers=_forward_response_headers(response.headers), + ) + except aiohttp.ClientError as exc: + return web.json_response({"error": f"SGLang upstream unavailable: {exc}"}, status=503) + + +def create_delta_proxy_app( + *, + target_base_url: str = "http://127.0.0.1:8001", + delta_volume_obj: Any = None, + delta_mount_path: str = "/delta", +) -> Any: + from aiohttp import web + + proxy_app = web.Application() + proxy_app["target_base_url"] = _normalize_base_url(target_base_url) + proxy_app["delta_volume"] = delta_volume_obj + proxy_app["delta_mount_path"] = delta_mount_path + proxy_app["update_lock"] = asyncio.Lock() + + async def handler(request: Any) -> Any: + return await _proxy_request( + request, + target_base_url=proxy_app["target_base_url"], + 
delta_volume_obj=proxy_app["delta_volume"], + delta_mount_path=proxy_app["delta_mount_path"], + update_lock=proxy_app["update_lock"], + ) + + proxy_app.router.add_route("*", "/{tail:.*}", handler) + return proxy_app + + +def run_delta_proxy( + *, + host: str = "0.0.0.0", + port: int = 8000, + target_base_url: str = "http://127.0.0.1:8001", + delta_volume_obj: Any = None, + delta_mount_path: str = "/delta", +) -> None: + from aiohttp import web + + proxy_app = create_delta_proxy_app( + target_base_url=target_base_url, + delta_volume_obj=delta_volume_obj, + delta_mount_path=delta_mount_path, + ) + web.run_app(proxy_app, host=host, port=port, handle_signals=False) + + +def _build_rollout_image() -> modal.Image: + image = modal.Image.from_registry(ROLLOUT_BASE_IMAGE) + if IS_LOCAL: + image = ( + image.add_local_file( + REPO_ROOT / "docker/patch/latest/sglang.patch", + "/tmp/sglang.patch", + copy=True, + ) + .run_commands(APPLY_SGLANG_PATCH) + .run_commands( + "sed -i 's/timeout_keep_alive=5/timeout_keep_alive=300/g' " + "/sgl-workspace/sglang/python/sglang/srt/entrypoints/http_server.py || true", + "apt-get update -qq && apt-get install -y -qq libcudart12 2>/dev/null " + "|| pip install nvidia-cuda-runtime-cu12 --quiet", + r"sed -i 's/self_named_buffers\[name\]\[\.\.\.] 
= tensor/self_named_buffers[name].data.copy_(tensor)/g' " + "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler_update_weights_mixin.py || true", + ) + ) + image = image.uv_pip_install( + f"autoinference-utils=={AUTOINFERENCE_UTILS_VERSION}", + "aiohttp", + "hf_transfer", + "requests", + ) + if IS_LOCAL: + image = ( + image.add_local_dir(REPO_ROOT / "slime", "/root/slime/slime", copy=True) + .add_local_file(REPO_ROOT / "examples/delta_weight_sync/modal_delta_sync.py", "/root/modal_delta_sync.py", copy=True) + ) + return image.env(HF_IMAGE_ENV | RUNTIME_ENV) + + +def _build_trainer_image() -> modal.Image: + image = modal.Image.from_registry(TRAINER_BASE_IMAGE).uv_pip_install("hf_transfer", "modal", "requests") + if IS_LOCAL: + image = ( + image.add_local_dir(REPO_ROOT / "slime", "/root/slime/slime", copy=True) + .add_local_file(REPO_ROOT / "train.py", "/root/slime/train.py", copy=True) + .add_local_file(REPO_ROOT / "examples/delta_weight_sync/modal_delta_sync.py", "/root/modal_delta_sync.py", copy=True) + ) + return image.env(HF_IMAGE_ENV | RUNTIME_ENV) + + +rollout_image = _build_rollout_image() +trainer_image = _build_trainer_image() + +with rollout_image.imports(): + from autoinference_utils.endpoint import SGLangEndpoint, warmup_chat_completions + +SERVER_ARGS = { + "--revision": MODEL_REVISION, + "--served-model-name": MODEL_NAME, + "--context-length": str(SGLANG_CONTEXT_LENGTH), + "--mem-fraction-static": SGLANG_MEM_FRACTION_STATIC, + "--reasoning-parser": "qwen3", + "--trust-remote-code": "", + "--cuda-graph-bs": "1 2 4 8 16", + "--cuda-graph-max-bs": str(TARGET_INPUTS), + "--max-running-requests": str(TARGET_INPUTS), + "--disable-piecewise-cuda-graph": "", + "--skip-server-warmup": "", +} + +WARMUP_PAYLOAD = { + "model": MODEL_NAME, + "messages": [{"role": "user", "content": "Reply with exactly OK."}], + "max_tokens": 8, + "temperature": 0, + "chat_template_kwargs": {"enable_thinking": False}, +} + +MODEL_ARGS = [ + "--swiglu", + "--num-layers", + 
"36", + "--hidden-size", + "2560", + "--ffn-hidden-size", + "9728", + "--num-attention-heads", + "32", + "--group-query-attention", + "--num-query-groups", + "8", + "--use-rotary-position-embeddings", + "--disable-bias-linear", + "--normalization", + "RMSNorm", + "--norm-epsilon", + "1e-6", + "--rotary-base", + "1000000", + "--vocab-size", + "151936", + "--kv-channels", + "128", + "--qk-layernorm", +] + + +def _run(cmd: list[str], *, cwd: str | None = None, env: dict[str, str] | None = None, check: bool = True) -> int: + print("+ " + shlex.join(cmd), flush=True) + proc = subprocess.Popen( + cmd, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + last_output: deque[str] = deque(maxlen=240) + if proc.stdout is not None: + for line in proc.stdout: + line = line.rstrip("\n") + last_output.append(line) + print(line, flush=True) + return_code = proc.wait() + if check and return_code != 0: + tail = "\n".join(last_output) + raise RuntimeError(f"Command failed with exit code {return_code}: {shlex.join(cmd)}\n\nLast output:\n{tail}") + return return_code + + +def _detect_nvlink() -> str: + try: + output = subprocess.check_output(["nvidia-smi", "topo", "-m"], text=True, stderr=subprocess.DEVNULL) + except Exception: + return "0" + return "1" if "NV" in output else "0" + + +def _write_smoke_dataset(path: str, rows: int = 8) -> None: + os.makedirs(os.path.dirname(path), exist_ok=True) + prompts = [ + ("What is 1 + 1? Answer with a single integer.", "2"), + ("What is 3 + 4? Answer with a single integer.", "7"), + ("What is 9 - 5? Answer with a single integer.", "4"), + ("What is 6 / 2? 
Answer with a single integer.", "3"), + ] + with open(path, "w", encoding="utf-8") as f: + for i in range(rows): + prompt, label = prompts[i % len(prompts)] + f.write(json.dumps({"prompt": prompt, "label": label}) + "\n") + + +def _ensure_model_cached() -> str: + from huggingface_hub import snapshot_download + + model_path = snapshot_download(repo_id=MODEL_NAME, revision=MODEL_REVISION) + hf_cache_volume.commit() + return model_path + + +def _build_train_args( + rollout_url: str, + *, + model_path: str, + num_rollout: int, + response_len: int, +) -> list[str]: + rollout_batch_size = 1 + n_samples_per_prompt = 1 + global_batch_size = rollout_batch_size * n_samples_per_prompt + dataset_path = "/root/slime-data/qwen3_4b_smoke.jsonl" + + return [ + "--actor-num-nodes", + "1", + "--actor-num-gpus-per-node", + "1", + "--megatron-to-hf-mode", + "bridge", + "--hf-checkpoint", + model_path, + "--load", + model_path, + *MODEL_ARGS, + "--prompt-data", + dataset_path, + "--input-key", + "prompt", + "--label-key", + "label", + "--apply-chat-template", + "--apply-chat-template-kwargs", + json.dumps({"enable_thinking": False}), + "--custom-rm-path", + "slime.rollout.rm_hub.constant_reward", + "--disable-rewards-normalization", + "--num-rollout", + str(num_rollout), + "--rollout-batch-size", + str(rollout_batch_size), + "--n-samples-per-prompt", + str(n_samples_per_prompt), + "--num-steps-per-rollout", + "1", + "--global-batch-size", + str(global_batch_size), + "--rollout-max-context-len", + str(SGLANG_CONTEXT_LENGTH), + "--rollout-max-prompt-len", + "512", + "--rollout-max-response-len", + str(response_len), + "--rollout-temperature", + "0.7", + "--skip-eval-before-train", + "--advantage-estimator", + "grpo", + "--kl-coef", + "0.0", + "--kl-loss-coef", + "0.0", + "--entropy-coef", + "0.0", + "--eps-clip", + "0.2", + "--eps-clip-high", + "0.28", + "--optimizer", + "adam", + "--lr", + "1e-6", + "--lr-decay-style", + "constant", + "--weight-decay", + "0.1", + "--adam-beta1", + 
"0.9", + "--adam-beta2", + "0.98", + "--tensor-model-parallel-size", + "1", + "--pipeline-model-parallel-size", + "1", + "--context-parallel-size", + "1", + "--expert-model-parallel-size", + "1", + "--expert-tensor-parallel-size", + "1", + "--micro-batch-size", + "1", + "--use-dynamic-batch-size", + "--max-tokens-per-gpu", + "2048", + "--rollout-external", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--sglang-router-url", + rollout_url, + "--rollout-external-engine-addrs", + rollout_url, + "--sglang-mem-fraction-static", + SGLANG_MEM_FRACTION_STATIC, + "--update-weight-mode", + "delta", + "--update-weight-transport", + "disk", + "--update-weight-encoding", + "deltas", + "--update-weight-delta-dir", + DELTA_MOUNT_PATH, + "--update-weight-delta-keep-files", + "--custom-delta-pre-push-path", + "slime.backends.sglang_utils.modal_volume_hooks.commit_modal_delta_volume", + "--attention-dropout", + "0.0", + "--hidden-dropout", + "0.0", + "--accumulate-allreduce-grads-in-fp32", + "--attention-softmax-in-fp32", + "--attention-backend", + "flash", + ] + + +@app.cls( + include_source=False, + image=rollout_image, + gpu=GPU, + volumes={ + HF_CACHE_PATH: hf_cache_volume, + DELTA_MOUNT_PATH: delta_volume, + }, + min_containers=MIN_CONTAINERS, + max_containers=MAX_CONTAINERS, + timeout=24 * HOURS, + startup_timeout=STARTUP_TIMEOUT, + scaledown_window=SCALEDOWN_WINDOW, + memory=MEMORY_MB, + region=REGION, +) +@modal.experimental.http_server( + port=SIDECAR_PORT, + proxy_regions=PROXY_REGIONS, + exit_grace_period=25, + startup_timeout=STARTUP_TIMEOUT, +) +@modal.concurrent(target_inputs=TARGET_INPUTS) +class RolloutServer: + @modal.enter() + def startup(self) -> None: + self.endpoint = SGLangEndpoint( + model_path=MODEL_NAME, + worker_port=SGLANG_INTERNAL_PORT, + tp=1, + extra_server_args=SERVER_ARGS, + health_timeout=STARTUP_TIMEOUT, + health_poll_interval=5.0, + ) + self.endpoint.start() + warmup_chat_completions( + port=SGLANG_INTERNAL_PORT, + 
payload=WARMUP_PAYLOAD, + successful_requests=2, + request_timeout=90.0, + ) + hf_cache_volume.commit() + + self.sidecar_thread = threading.Thread( + target=run_delta_proxy, + kwargs={ + "host": "0.0.0.0", + "port": SIDECAR_PORT, + "target_base_url": f"http://127.0.0.1:{SGLANG_INTERNAL_PORT}", + "delta_volume_obj": delta_volume, + "delta_mount_path": DELTA_MOUNT_PATH, + }, + daemon=True, + ) + self.sidecar_thread.start() + print(f"{MODEL_NAME} rollout is serving through the delta sidecar on port {SIDECAR_PORT}.", flush=True) + + @modal.exit() + def stop(self) -> None: + if hasattr(self, "endpoint"): + self.endpoint.stop() + + +@app.function( + name="train_qwen3_4b", + include_source=False, + image=trainer_image, + gpu=GPU, + volumes={ + HF_CACHE_PATH: hf_cache_volume, + DELTA_MOUNT_PATH: delta_volume, + }, + max_containers=1, + timeout=24 * HOURS, + memory=MEMORY_MB, + region=REGION, +) +def train_qwen3_4b(rollout_url: str, num_rollout: int = 2, response_len: int = 16) -> dict[str, Any]: + rollout_url = rollout_url.rstrip("/") + model_path = _ensure_model_cached() + dataset_path = "/root/slime-data/qwen3_4b_smoke.jsonl" + _write_smoke_dataset(dataset_path) + + train_env = os.environ.copy() + train_env.update(RUNTIME_ENV) + train_env["NCCL_NVLS_ENABLE"] = _detect_nvlink() + train_env["MASTER_ADDR"] = "127.0.0.1" + train_env["RAY_DEDUP_LOGS"] = "0" + + _run(["ray", "stop", "--force"], check=False) + _run( + [ + "ray", + "start", + "--head", + "--node-ip-address", + "127.0.0.1", + "--num-gpus", + "1", + "--disable-usage-stats", + "--dashboard-host", + "0.0.0.0", + "--dashboard-port", + "8265", + ], + env=train_env, + ) + + runtime_env_json = json.dumps( + { + "env_vars": { + "PYTHONPATH": train_env["PYTHONPATH"], + "CUDA_DEVICE_MAX_CONNECTIONS": train_env["CUDA_DEVICE_MAX_CONNECTIONS"], + "NCCL_NVLS_ENABLE": train_env["NCCL_NVLS_ENABLE"], + "RAY_DEDUP_LOGS": train_env["RAY_DEDUP_LOGS"], + } + } + ) + train_args = _build_train_args( + rollout_url, + 
model_path=model_path, + num_rollout=num_rollout, + response_len=response_len, + ) + job_cmd = [ + "ray", + "job", + "submit", + "--address=http://127.0.0.1:8265", + f"--runtime-env-json={runtime_env_json}", + "--", + "python3", + "/root/slime/train.py", + *train_args, + ] + started = time.time() + try: + _run(job_cmd, cwd="/root/slime", env=train_env) + finally: + _run(["ray", "stop", "--force"], check=False) + + delta_volume.commit() + delta_volume.reload() + version_dirs = [] + if os.path.isdir(DELTA_MOUNT_PATH): + version_dirs = sorted(name for name in os.listdir(DELTA_MOUNT_PATH) if name.startswith("weight_v")) + return { + "rollout_url": rollout_url, + "num_rollout": num_rollout, + "elapsed_s": round(time.time() - started, 1), + "delta_versions": version_dirs, + } + + +def _deployed_rollout_url() -> str: + rollout_cls = modal.Cls.from_name(APP_NAME, "RolloutServer") + urls = rollout_cls._experimental_get_flash_urls() + if not urls: + raise RuntimeError(f"No http_server URL found for deployed Modal class {APP_NAME}::RolloutServer") + return urls[0].rstrip("/") + + +def _assert_rollout_ready(rollout_url: str) -> None: + import requests + + deadline = time.time() + STARTUP_TIMEOUT + last_error = "not checked" + while time.time() < deadline: + try: + response = requests.get(_join_url(rollout_url, "/health_generate"), timeout=30) + if response.status_code == 200: + break + last_error = f"{response.status_code} {response.text[:300]}" + except requests.RequestException as exc: + last_error = repr(exc) + print(f"Waiting for deployed rollout /health_generate: {last_error}", flush=True) + time.sleep(10) + else: + raise TimeoutError(f"Rollout did not become healthy within {STARTUP_TIMEOUT}s: {last_error}") + + bad_update = requests.post( + _join_url(rollout_url, "/update_weights_from_disk"), + json={"load_format": "delta", "model_path": "/tmp/not-a-delta"}, + timeout=30, + ) + if bad_update.status_code != 400: + raise RuntimeError(f"sidecar accepted bad delta path: 
{bad_update.status_code} {bad_update.text}") + + +@app.local_entrypoint() +def rollout_url() -> None: + print(_deployed_rollout_url()) + + +@app.local_entrypoint() +def launch_run(num_rollout: int = 2, response_len: int = 16) -> None: + rollout_url_value = _deployed_rollout_url() + print(f"Using deployed rollout URL: {rollout_url_value}", flush=True) + _assert_rollout_ready(rollout_url_value) + train_fn = modal.Function.from_name(APP_NAME, "train_qwen3_4b") + result = train_fn.remote(rollout_url_value, num_rollout=num_rollout, response_len=response_len) + print(json.dumps(result, indent=2, sort_keys=True)) diff --git a/slime/backends/megatron_utils/sglang.py b/slime/backends/megatron_utils/sglang.py index 801217310d..382b5f719e 100644 --- a/slime/backends/megatron_utils/sglang.py +++ b/slime/backends/megatron_utils/sglang.py @@ -1,4 +1,7 @@ # the file to manage all sglang deps in the megatron actor +from dataclasses import dataclass +from enum import Enum + try: from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0, transform_scale_ue8m0 from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 @@ -16,12 +19,27 @@ try: from sglang.srt.managers.io_struct import DeltaEncoding, DeltaParam, DeltaSpec except ImportError: - # Older sglang images don't have delta-sync io_struct. Only --update-weight-mode=delta - # needs these; the default full-sync path runs without them. 
- DeltaEncoding = None - DeltaParam = None - DeltaSpec = None + class DeltaEncoding(str, Enum): + INDICES = "indices" + DELTAS = "deltas" + DELTAS_ZSTD = "deltas_zstd" + + @dataclass + class DeltaParam: + name: str + dtype: str + shape: list[int] + pos_start: int + pos_end: int + pos_width: int + val_start: int + val_end: int + + @dataclass + class DeltaSpec: + encoding: DeltaEncoding + params: list[DeltaParam] from sglang.srt.utils import MultiprocessingSerializer diff --git a/slime/backends/sglang_utils/arguments.py b/slime/backends/sglang_utils/arguments.py index 0a4801743f..219e332c8e 100644 --- a/slime/backends/sglang_utils/arguments.py +++ b/slime/backends/sglang_utils/arguments.py @@ -2,6 +2,7 @@ from sglang.srt.server_args import ServerArgs from slime.utils.http_utils import _wrap_ipv6 +from slime.utils.url_utils import normalize_base_url, parse_external_engine_addr # TODO: use all sglang router arguments with `--sglang-router` prefix @@ -21,6 +22,15 @@ def add_sglang_router_arguments(parser): default=None, help="Port of the SGLang router", ) + parser.add_argument( + "--sglang-router-url", + type=str, + default=None, + help=( + "Full http(s) base URL for an externally managed SGLang router or single-engine endpoint. " + "When set, rollout generation uses this URL instead of --sglang-router-ip/port." + ), + ) parser.add_argument( "--sglang-router-request-timeout-secs", type=int, @@ -156,6 +166,18 @@ def validate_args(args): if getattr(args, "sglang_router_ip", None): args.sglang_router_ip = _wrap_ipv6(args.sglang_router_ip) + if getattr(args, "sglang_router_url", None): + args.sglang_router_url = normalize_base_url(args.sglang_router_url) + assert getattr(args, "rollout_external", False), "--sglang-router-url requires --rollout-external." 
+ + if getattr(args, "rollout_external", False): + external_engine_addrs = getattr(args, "rollout_external_engine_addrs", None) + assert external_engine_addrs, ( + "--rollout-external requires --rollout-external-engine-addrs so external admin and update RPCs " + "target explicit engine endpoints." + ) + for addr in external_engine_addrs: + parse_external_engine_addr(addr) # Mutual-exclusion checks for PD disaggregation / sglang-config. assert not ( diff --git a/slime/backends/sglang_utils/modal_volume_hooks.py b/slime/backends/sglang_utils/modal_volume_hooks.py new file mode 100644 index 0000000000..cc7938d55f --- /dev/null +++ b/slime/backends/sglang_utils/modal_volume_hooks.py @@ -0,0 +1,28 @@ +import os +from typing import Any + + +def _distributed_rank() -> int: + try: + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() + except Exception: + pass + return int(os.environ.get("RANK", "0")) + + +def commit_modal_delta_volume(args: Any, version_dir: str, rollout_engines: Any) -> None: + """Commit a Modal Volume after slime writes a disk delta version directory.""" + if _distributed_rank() != 0: + return + + volume_name = os.environ.get("SLIME_DELTA_VOLUME_NAME") + if not volume_name: + raise RuntimeError("SLIME_DELTA_VOLUME_NAME must be set to commit a Modal delta volume") + + import modal + + modal.Volume.from_name(volume_name, create_if_missing=True).commit() + print(f"Committed Modal delta volume {volume_name} for {version_dir}", flush=True) diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 0564915e41..52125818fc 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -15,6 +15,7 @@ from slime.ray.ray_actor import RayActor from slime.utils.http_utils import get_host_info +from slime.utils.url_utils import join_url, make_http_base_url logger = logging.getLogger(__name__) @@ -123,9 
+124,11 @@ def init( disaggregation_bootstrap_port=None, router_ip=None, router_port=None, + server_url=None, + external_addr_is_url: bool = False, ): - self.router_ip = router_ip if router_ip is not None else self.args.sglang_router_ip - self.router_port = router_port if router_port is not None else self.args.sglang_router_port + self.router_ip = router_ip if router_ip is not None else getattr(self.args, "sglang_router_ip", None) + self.router_port = router_port if router_port is not None else getattr(self.args, "sglang_router_port", None) host = host or get_host_info()[1] @@ -160,6 +163,12 @@ def _format_v6_uri(addr): self.node_rank = server_args_dict["node_rank"] self.server_host = server_args_dict["host"] # with [] if ipv6 self.server_port = server_args_dict["port"] + self.server_url = server_url or make_http_base_url(self.server_host, self.server_port) + + if external_addr_is_url: + external_engine_need_check_fields = [ + name for name in external_engine_need_check_fields if name not in {"host", "port"} + ] if self.args.rollout_external: self._init_external(server_args_dict, external_engine_need_check_fields=external_engine_need_check_fields) @@ -170,7 +179,7 @@ def _init_external(self, expect_server_args, external_engine_need_check_fields): logger.info(f"Use external SGLang engine (rank={self.rank}, expect_server_args={expect_server_args})") def _get_actual_server_args(): - response = requests.get(f"http://{self.server_host}:{self.server_port}/get_server_info") + response = requests.get(join_url(self.server_url, "/get_server_info")) response.raise_for_status() return response.json() @@ -183,7 +192,7 @@ def _sanity_check_server_args(actual_server_args, expect_server_args): ), f"{name=} {expect_value=} {actual_value=} {expect_server_args=} {actual_server_args=}" _wait_server_healthy( - base_url=f"http://{self.server_host}:{self.server_port}", + base_url=self.server_url, api_key=None, is_process_alive=lambda: True, ) @@ -198,14 +207,15 @@ def _init_normal(self, 
server_args_dict): return if self.node_rank == 0 and self.router_ip and self.router_port: + worker_url = self.server_url if parse(sglang_router.__version__) <= parse("0.2.1"): assert self.worker_type == "regular", "pd disaggregation is not supported in old router." response = requests.post( - f"http://{self.router_ip}:{self.router_port}/add_worker?url=http://{self.server_host}:{self.server_port}", + f"http://{self.router_ip}:{self.router_port}/add_worker?url={worker_url}", ) else: payload = { - "url": f"http://{self.server_host}:{self.server_port}", + "url": worker_url, "worker_type": self.worker_type, } if self.worker_type == "prefill": @@ -229,7 +239,7 @@ def _make_request(self, endpoint: str, payload: dict | None = None): if self.node_rank != 0: return - url = f"http://{self.server_host}:{self.server_port}/{endpoint}" + url = join_url(self.server_url, endpoint) response = requests.post(url, json=payload or {}) try: response.raise_for_status() @@ -254,7 +264,7 @@ def health_generate(self, timeout: float = 5.0) -> bool: return True response = requests.get( - f"http://{self.server_host}:{self.server_port}/health_generate", + join_url(self.server_url, "/health_generate"), timeout=timeout, ) response.raise_for_status() @@ -292,7 +302,7 @@ def flush_cache(self): # flush cache will not return status_code 200 when there are pending requests for _ in range(60): try: - response = requests.get(f"http://{self.server_host}:{self.server_port}/flush_cache") + response = requests.get(join_url(self.server_url, "/flush_cache")) if response.status_code == 200: break except NewConnectionError as e: @@ -307,7 +317,7 @@ def flush_cache(self): def get_url(self): if self.node_rank != 0: return None - return f"http://{self.server_host}:{self.server_port}" + return self.server_url def shutdown(self): if self.args.rollout_external: @@ -315,12 +325,10 @@ def shutdown(self): logger.info(f"Shutdown engine {self.server_host}:{self.server_port}...") if self.worker_type != "encoder" and 
self.node_rank == 0: - worker_url = f"http://{self.server_host}:{self.server_port}" + worker_url = self.server_url response = None if parse(sglang_router.__version__) <= parse("0.2.1"): - response = requests.post( - f"http://{self.router_ip}:{self.router_port}/remove_worker?url=http://{self.server_host}:{self.server_port}" - ) + response = requests.post(f"http://{self.router_ip}:{self.router_port}/remove_worker?url={worker_url}") elif parse(sglang_router.__version__) < parse("0.3.0"): worker_url = quote(worker_url, safe="") response = requests.delete(f"http://{self.router_ip}:{self.router_port}/workers/{worker_url}") @@ -346,7 +354,7 @@ def shutdown(self): def get_weight_version(self): if self.node_rank != 0: return - url = f"http://{self.server_host}:{self.server_port}/get_weight_version" + url = join_url(self.server_url, "/get_weight_version") response = requests.get(url) response.raise_for_status() return response.json()["weight_version"] @@ -457,12 +465,12 @@ def update_weights_from_distributed( ) def pause_generation(self): - response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) + response = requests.post(join_url(self.server_url, "/pause_generation"), json={}) response.raise_for_status() return response def continue_generation(self): - response = requests.post(f"http://{self.server_host}:{self.server_port}/continue_generation", json={}) + response = requests.post(join_url(self.server_url, "/continue_generation"), json={}) response.raise_for_status() return response @@ -500,7 +508,7 @@ def start_profile( record_shapes: bool | None = None, ): response = requests.post( - f"http://{self.server_host}:{self.server_port}/start_profile", + join_url(self.server_url, "/start_profile"), json={ "output_dir": output_dir, "start_step": start_step, @@ -515,7 +523,7 @@ def start_profile( return response def stop_profile(self): - response = requests.post(f"http://{self.server_host}:{self.server_port}/stop_profile", json={}) + 
response = requests.post(join_url(self.server_url, "/stop_profile"), json={}) response.raise_for_status() return response diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index edc723dcef..b30da65abb 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -89,6 +89,9 @@ def create_placement_groups(args): elif args.colocate: num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node rollout_offset = 0 + elif getattr(args, "rollout_external", False): + num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + rollout_offset = num_gpus else: num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + args.rollout_num_gpus rollout_offset = args.actor_num_nodes * args.actor_num_gpus_per_node diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 69c786cf84..25b2b18955 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -25,6 +25,7 @@ from slime.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix from slime.utils.misc import Box, group_by, load_function from slime.utils.types import Sample +from slime.utils.url_utils import parse_external_engine_addr from ..utils.metric_utils import has_repetition from .rollout_validation import validate_server_group_gpu_indices @@ -58,6 +59,7 @@ class ServerGroup: model_path: str | None = None # checkpoint path for update_weights_from_disk router_ip: str | None = None router_port: int | None = None + router_url: str | None = None @property def nodes_per_engine(self): @@ -109,16 +111,22 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis global_rank = self.rank_offset + i num_gpus = 0.2 num_cpus = num_gpus + scheduling_strategy = None + base_gpu_id = 0 - # Get the base GPU ID from placement group using gpu_offset. 
- gpu_index = self.gpu_offset + i * num_gpu_per_engine - base_gpu_id = int(reordered_gpu_ids[gpu_index]) - - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=reordered_bundle_indices[gpu_index], - ) + if self.args.rollout_external: + num_gpus = 0 + num_cpus = 0.2 + else: + # Get the base GPU ID from placement group using gpu_offset. + gpu_index = self.gpu_offset + i * num_gpu_per_engine + base_gpu_id = int(reordered_gpu_ids[gpu_index]) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=reordered_bundle_indices[gpu_index], + ) env_vars = {name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST} | { key: os.environ.get(key, default_val) @@ -134,13 +142,18 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis "SLIME_ENABLE_PROFILING": "true", }.items() } - rollout_engine = RolloutRayActor.options( - num_cpus=num_cpus, - num_gpus=num_gpus, - scheduling_strategy=scheduling_strategy, - runtime_env={ + actor_options = { + "num_cpus": num_cpus, + "num_gpus": num_gpus, + "runtime_env": { "env_vars": env_vars, }, + } + if scheduling_strategy is not None: + actor_options["scheduling_strategy"] = scheduling_strategy + + rollout_engine = RolloutRayActor.options( + **actor_options ).remote( self.args, rank=global_rank, @@ -230,6 +243,7 @@ class RolloutServer: server_groups: list[ServerGroup] router_ip: str | None = None router_port: int | None = None + router_url: str | None = None model_name: str = "default" update_weights: bool = True @@ -411,7 +425,11 @@ def _get_metrics_router_addr(self) -> str | None: metrics are disabled or no servers are running. 
""" srv = self.server - if srv is None or srv.router_ip is None: + if srv is None: + return None + if srv.router_url is not None: + return srv.router_url + if srv.router_ip is None: return None return f"http://{srv.router_ip}:{srv.router_port}" @@ -843,13 +861,14 @@ def _validate_rollout_id_annotated(node, depth=0): def _allocate_rollout_engine_addr_and_ports_external(args, rollout_engines): addr_and_ports = {} for rank, _ in rollout_engines: - addr = args.rollout_external_engine_addrs[rank] - [host, port] = addr.split(":") + addr = parse_external_engine_addr(args.rollout_external_engine_addrs[rank]) addr_and_ports[rank] = dict( - dist_init_addr=addr, + dist_init_addr=addr.dist_init_addr, nccl_port=None, - host=host, - port=int(port), + host=addr.host, + port=addr.port, + server_url=addr.base_url, + external_addr_is_url=addr.is_url, ) return addr_and_ports @@ -950,6 +969,10 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool ``force_new`` is False, skip launching and return the existing values. When ``force_new`` is True (multi-model), always allocate a fresh port. """ + if not force_new and getattr(args, "sglang_router_url", None): + addr = parse_external_engine_addr(args.sglang_router_url) + return addr.host, addr.port + if not force_new and args.sglang_router_ip is not None: return args.sglang_router_ip, args.sglang_router_port @@ -1041,6 +1064,7 @@ def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: has_pd = model_cfg.has_pd_disaggregation router_ip, router_port = _start_router(args, has_pd_disaggregation=has_pd, force_new=(model_idx > 0)) + router_url = getattr(args, "sglang_router_url", None) if model_idx == 0 else None # Write back for backward compat (first model only). 
if model_idx == 0: @@ -1085,6 +1109,7 @@ def _make_group(group_cfg, router_ip, router_port, overrides_extra=None): model_path=overrides.get("model_path", args.hf_checkpoint), router_ip=router_ip, router_port=router_port, + router_url=router_url, ) engine_offset += num_engines gpu_offset += group_cfg.num_gpus @@ -1140,12 +1165,15 @@ def _make_group(group_cfg, router_ip, router_port, overrides_extra=None): server_groups=server_groups, router_ip=router_ip, router_port=router_port, + router_url=router_url, model_name=model_cfg.name, update_weights=model_cfg.update_weights, ) # Expose per-model router info for custom rollout functions. - args.sglang_model_routers = {name: (srv.router_ip, srv.router_port) for name, srv in servers.items()} + args.sglang_model_routers = { + name: srv.router_url or (srv.router_ip, srv.router_port) for name, srv in servers.items() + } return servers diff --git a/slime/rollout/rm_hub/__init__.py b/slime/rollout/rm_hub/__init__.py index 0991e559e5..fab9a43bc3 100644 --- a/slime/rollout/rm_hub/__init__.py +++ b/slime/rollout/rm_hub/__init__.py @@ -52,6 +52,12 @@ async def remote_rm(args, sample: Sample, max_retries: int = 10): await asyncio.sleep(backoff) +async def constant_reward(args, sample_or_samples, **kwargs): + if isinstance(sample_or_samples, list): + return [1.0 for _ in sample_or_samples] + return 1.0 + + async def async_rm(args, sample: Sample, **kwargs): if args.custom_rm_path is not None: rm_function = load_function(args.custom_rm_path) diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index c7f86b98ca..8e322f3992 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -29,6 +29,11 @@ ) from slime.utils.trace_utils import build_sglang_meta_trace_attrs, trace_function, trace_span from slime.utils.types import Sample +from slime.utils.url_utils import ( + get_default_router_url_from_args, + get_external_engine_base_urls_from_args, + get_model_url_from_args, +) from .rm_hub 
import async_rm, batched_async_rm @@ -61,6 +66,10 @@ def _prepare_prompt_ids(sample: Sample, tokenizer, processor: Any) -> list[int]: return tokenizer.encode(sample.prompt, add_special_tokens=False) +def _default_router_url(args: Namespace, endpoint: str) -> str: + return get_default_router_url_from_args(args, endpoint) + + def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") -> str: """Return the router URL for a named model. @@ -73,11 +82,7 @@ def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") Falls back to the default router if *model_name* is not found or ``sglang_model_routers`` is not set. """ - routers = getattr(args, "sglang_model_routers", None) - if routers and model_name in routers: - ip, port = routers[model_name] - return f"http://{ip}:{port}{endpoint}" - return f"http://{args.sglang_router_ip}:{args.sglang_router_port}{endpoint}" + return get_model_url_from_args(args, model_name, endpoint) class GenerateState(metaclass=SingletonMeta): @@ -155,7 +160,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A assert isinstance(sample.prompt, str) state = GenerateState(args) - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + url = get_model_url(args, "default", "/generate") assert ( sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED @@ -356,11 +361,15 @@ async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: assert not state.aborted state.aborted = True - if parse(sglang_router.__version__) <= parse("0.2.1"): - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") + if getattr(args, "rollout_external", False): + urls = get_external_engine_base_urls_from_args(args) + if not urls: + raise ValueError("--rollout-external requires --rollout-external-engine-addrs for abort/admin requests") + elif parse(sglang_router.__version__) <= parse("0.2.1"): + response = 
await get(_default_router_url(args, "/list_workers")) urls = response["urls"] else: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") + response = await get(_default_router_url(args, "/workers")) urls = [worker["url"] for worker in response["workers"]] logger.info(f"Abort request for {urls}") diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 9b73a78221..314fa4409e 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -516,7 +516,10 @@ def add_rollout_arguments(parser): type=str, default=None, nargs="+", - help="Address and ports of the external engines.", + help=( + "Address and ports of the external engines. Entries can be host:port or " + "http(s)://... base URLs for engines exposed through an external gateway." + ), ) return parser diff --git a/slime/utils/url_utils.py b/slime/utils/url_utils.py new file mode 100644 index 0000000000..b47727568b --- /dev/null +++ b/slime/utils/url_utils.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from urllib.parse import urlsplit, urlunsplit + + +@dataclass(frozen=True) +class ExternalEngineAddress: + base_url: str + host: str + port: int + dist_init_addr: str + is_url: bool + + +def normalize_base_url(base_url: str) -> str: + parsed = urlsplit(base_url) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + raise ValueError(f"Expected an http(s) URL, got {base_url!r}") + path = parsed.path.rstrip("/") + return urlunsplit((parsed.scheme, parsed.netloc, path, "", "")) + + +def join_url(base_url: str, endpoint: str) -> str: + base = normalize_base_url(base_url) + endpoint = endpoint if endpoint.startswith("/") else f"/{endpoint}" + return f"{base}{endpoint}" + + +def make_http_base_url(host: str, port: int) -> str: + return f"http://{host}:{port}" + + +def get_default_router_url_from_args(args, endpoint: str = "/generate") -> str: + if getattr(args, "sglang_router_url", None): + return join_url(args.sglang_router_url, endpoint) 
+ return join_url(make_http_base_url(args.sglang_router_ip, args.sglang_router_port), endpoint) + + +def get_model_url_from_args(args, model_name: str, endpoint: str = "/generate") -> str: + routers = getattr(args, "sglang_model_routers", None) + if routers and model_name in routers: + router = routers[model_name] + if isinstance(router, str): + return join_url(router, endpoint) + ip, port = router + return join_url(make_http_base_url(ip, port), endpoint) + return get_default_router_url_from_args(args, endpoint) + + +def _format_host_for_addr(host: str) -> str: + if ":" in host and not host.startswith("["): + return f"[{host}]" + return host + + +def parse_external_engine_addr(addr: str) -> ExternalEngineAddress: + """Parse either ``host:port`` or an http(s) external engine base URL.""" + if "://" in addr: + parsed = urlsplit(addr) + if parsed.scheme not in {"http", "https"} or not parsed.hostname: + raise ValueError(f"Expected an http(s) external engine URL, got {addr!r}") + default_port = 443 if parsed.scheme == "https" else 80 + port = parsed.port or default_port + host = _format_host_for_addr(parsed.hostname) + return ExternalEngineAddress( + base_url=normalize_base_url(addr), + host=host, + port=port, + dist_init_addr=f"{host}:{port}", + is_url=True, + ) + + parsed = urlsplit(f"//{addr}") + if not parsed.hostname or parsed.port is None: + raise ValueError(f"Expected external engine address as host:port, got {addr!r}") + host = _format_host_for_addr(parsed.hostname) + port = parsed.port + return ExternalEngineAddress( + base_url=make_http_base_url(host, port), + host=host, + port=port, + dist_init_addr=f"{host}:{port}", + is_url=False, + ) + + +def get_external_engine_base_urls_from_args(args) -> list[str]: + addrs = getattr(args, "rollout_external_engine_addrs", None) or [] + return [parse_external_engine_addr(addr).base_url for addr in addrs] diff --git a/tests/utils/test_modal_delta_sync.py b/tests/utils/test_modal_delta_sync.py new file mode 100644 index 
0000000000..9d7b53fd31 --- /dev/null +++ b/tests/utils/test_modal_delta_sync.py @@ -0,0 +1,229 @@ +import asyncio +import os +from argparse import Namespace +from types import SimpleNamespace + +import pytest + +from slime.utils.url_utils import ( + get_external_engine_base_urls_from_args, + get_model_url_from_args, + join_url, + parse_external_engine_addr, +) + + +def _modal_delta_sync(): + pytest.importorskip("modal") + from examples.delta_weight_sync import modal_delta_sync + + return modal_delta_sync + + +def test_parse_external_engine_addr_accepts_https_url(): + addr = parse_external_engine_addr("https://rollout.example.modal.run/base/") + + assert addr.base_url == "https://rollout.example.modal.run/base" + assert addr.host == "rollout.example.modal.run" + assert addr.port == 443 + assert addr.dist_init_addr == "rollout.example.modal.run:443" + assert addr.is_url is True + assert join_url(addr.base_url, "/health_generate") == "https://rollout.example.modal.run/base/health_generate" + + +def test_parse_external_engine_addr_accepts_host_port(): + addr = parse_external_engine_addr("127.0.0.1:8000") + + assert addr.base_url == "http://127.0.0.1:8000" + assert addr.host == "127.0.0.1" + assert addr.port == 8000 + assert addr.dist_init_addr == "127.0.0.1:8000" + assert addr.is_url is False + + +def test_external_engine_base_urls_from_args_normalizes_admin_urls(): + args = Namespace( + rollout_external_engine_addrs=[ + "https://rollout.example.modal.run/base/", + "127.0.0.1:8000", + ] + ) + + assert get_external_engine_base_urls_from_args(args) == [ + "https://rollout.example.modal.run/base", + "http://127.0.0.1:8000", + ] + + +def test_sglang_router_url_takes_generation_precedence(): + args = Namespace( + sglang_router_url="https://rollout.example.modal.run", + sglang_router_ip="10.0.0.1", + sglang_router_port=3000, + ) + + assert get_model_url_from_args(args, "default") == "https://rollout.example.modal.run/generate" + + +def test_named_router_may_be_full_url(): + 
args = Namespace( + sglang_router_ip="10.0.0.1", + sglang_router_port=3000, + sglang_model_routers={"actor": "https://actor.example.modal.run/base"}, + ) + + assert get_model_url_from_args(args, "actor") == "https://actor.example.modal.run/base/generate" + + +def test_validate_delta_update_payload_accepts_version_dir(tmp_path): + modal_delta_sync = _modal_delta_sync() + version_dir = tmp_path / "weight_v000123" + version_dir.mkdir() + + validated = modal_delta_sync.validate_delta_update_payload( + {"load_format": "delta", "model_path": str(version_dir)}, + delta_mount_path=str(tmp_path), + ) + + assert validated == os.path.realpath(version_dir) + + +@pytest.mark.parametrize( + "payload", + [ + {"load_format": "full", "model_path": "/delta/weight_v000123"}, + {"load_format": "delta"}, + {"load_format": "delta", "model_path": "/delta/not_a_version"}, + {"load_format": "delta", "model_path": "/tmp/weight_v000123"}, + ], +) +def test_validate_delta_update_payload_rejects_bad_paths(payload, tmp_path): + modal_delta_sync = _modal_delta_sync() + with pytest.raises(ValueError): + modal_delta_sync.validate_delta_update_payload(payload, delta_mount_path=str(tmp_path)) + + +def test_verify_delta_dir_ready_requires_done_and_safetensors(tmp_path): + modal_delta_sync = _modal_delta_sync() + version_dir = tmp_path / "weight_v000123" + version_dir.mkdir() + + with pytest.raises(FileNotFoundError): + modal_delta_sync.verify_delta_dir_ready(str(version_dir)) + + (version_dir / "DONE").write_text("") + with pytest.raises(FileNotFoundError): + modal_delta_sync.verify_delta_dir_ready(str(version_dir)) + + (version_dir / "rank0000_flush000000.safetensors").write_text("") + modal_delta_sync.verify_delta_dir_ready(str(version_dir)) + + +def test_external_abort_uses_explicit_engine_addrs(monkeypatch): + sglang_rollout = pytest.importorskip("slime.rollout.sglang_rollout") + posted = [] + + class FakeGenerateState: + def __init__(self, args): + self.aborted = False + self.pendings = set() + 
+ async def fake_get(*args, **kwargs): + raise AssertionError("external abort should not query router worker endpoints") + + async def fake_post(url, payload): + posted.append((url, payload)) + return {"ok": True} + + monkeypatch.setattr(sglang_rollout, "GenerateState", FakeGenerateState) + monkeypatch.setattr(sglang_rollout, "get", fake_get) + monkeypatch.setattr(sglang_rollout, "post", fake_post) + + aborted_samples = asyncio.run( + sglang_rollout.abort( + Namespace( + rollout_external=True, + rollout_external_engine_addrs=["https://rollout.example.modal.run/base/"], + partial_rollout=False, + ), + rollout_id=0, + ) + ) + + assert posted == [("https://rollout.example.modal.run/base/abort_request", {"abort_all": True})] + assert aborted_samples == [] + + +def test_external_rollout_placement_group_reserves_only_actor_gpus(monkeypatch): + pytest.importorskip("ray") + pytest.importorskip("numpy") + pytest.importorskip("sglang") + + from slime.ray import placement_group as placement_group_mod + + calls = [] + + def fake_create_placement_group(num_gpus): + calls.append(num_gpus) + return object(), list(range(num_gpus)), list(range(num_gpus)) + + monkeypatch.setattr(placement_group_mod, "_create_placement_group", fake_create_placement_group) + + pgs = placement_group_mod.create_placement_groups( + Namespace( + debug_train_only=False, + debug_rollout_only=False, + colocate=False, + rollout_external=True, + actor_num_nodes=1, + actor_num_gpus_per_node=1, + rollout_num_gpus=1, + use_critic=False, + ) + ) + + assert calls == [1] + assert pgs["rollout"][1] == [] + assert pgs["rollout"][2] == [] + + +def test_external_rollout_engines_are_cpu_only_control_actors(monkeypatch): + pytest.importorskip("ray") + pytest.importorskip("numpy") + pytest.importorskip("sglang") + + from slime.ray import rollout as rollout_mod + + actor_options = [] + + class FakeRemoteActor: + def options(self, **kwargs): + actor_options.append(kwargs) + return self + + def remote(self, *args, **kwargs): 
+ return SimpleNamespace(init=SimpleNamespace(remote=lambda **init_kwargs: init_kwargs)) + + monkeypatch.setattr(rollout_mod.ray, "remote", lambda cls: FakeRemoteActor()) + + group = rollout_mod.ServerGroup( + args=Namespace( + debug_train_only=False, + rollout_external=True, + rollout_external_engine_addrs=["https://rollout.example.modal.run"], + num_gpus_per_node=1, + sglang_dp_size=1, + ), + pg=(object(), [], []), + all_engines=[None], + num_gpus_per_engine=1, + num_new_engines=0, + ) + + handles, port_cursors = group.start_engines() + + assert port_cursors == {} + assert handles[0]["server_url"] == "https://rollout.example.modal.run" + assert actor_options[0]["num_gpus"] == 0 + assert actor_options[0]["num_cpus"] == 0.2 + assert "scheduling_strategy" not in actor_options[0] diff --git a/tests/utils/test_sglang_config.py b/tests/utils/test_sglang_config.py index 5015a74558..d7d56e5c80 100644 --- a/tests/utils/test_sglang_config.py +++ b/tests/utils/test_sglang_config.py @@ -1,6 +1,7 @@ """Unit tests for SglangConfig multi-model parsing with update_weights.""" import tempfile +from argparse import Namespace import pytest import yaml @@ -13,6 +14,10 @@ def _write_yaml(data: dict) -> str: return f.name +def _resolve_args() -> Namespace: + return Namespace(rollout_num_gpus_per_engine=2, hf_checkpoint="/path/to/actor") + + class TestSglangConfigUpdateWeights: def test_update_weights_default_true(self): """Models without explicit update_weights should default to True.""" @@ -29,6 +34,7 @@ def test_update_weights_default_true(self): } ) config = SglangConfig.from_yaml(path) + config.models[0].resolve(_resolve_args()) assert len(config.models) == 1 assert config.models[0].update_weights is True @@ -54,6 +60,8 @@ def test_update_weights_explicit_false(self): } ) config = SglangConfig.from_yaml(path) + for model in config.models: + model.resolve(_resolve_args()) assert len(config.models) == 2 assert config.models[0].name == "actor" assert config.models[0].update_weights 
is True @@ -87,9 +95,7 @@ def test_multi_model_total_gpus(self): class TestGetModelUrl: def test_get_model_url_basic(self): """get_model_url should return the correct URL for a named model.""" - from argparse import Namespace - - from slime.rollout.sglang_rollout import get_model_url + from slime.utils.url_utils import get_model_url_from_args args = Namespace( sglang_router_ip="10.0.0.1", @@ -99,34 +105,52 @@ def test_get_model_url_basic(self): "ref": ("10.0.0.1", 3001), }, ) - assert get_model_url(args, "actor") == "http://10.0.0.1:3000/generate" - assert get_model_url(args, "ref") == "http://10.0.0.1:3001/generate" - assert get_model_url(args, "ref", "/v1/chat/completions") == "http://10.0.0.1:3001/v1/chat/completions" + assert get_model_url_from_args(args, "actor") == "http://10.0.0.1:3000/generate" + assert get_model_url_from_args(args, "ref") == "http://10.0.0.1:3001/generate" + assert get_model_url_from_args(args, "ref", "/v1/chat/completions") == "http://10.0.0.1:3001/v1/chat/completions" def test_get_model_url_fallback(self): """get_model_url should fall back to default router if model not found.""" - from argparse import Namespace - - from slime.rollout.sglang_rollout import get_model_url + from slime.utils.url_utils import get_model_url_from_args args = Namespace( sglang_router_ip="10.0.0.1", sglang_router_port=3000, sglang_model_routers={"actor": ("10.0.0.1", 3000)}, ) - assert get_model_url(args, "unknown") == "http://10.0.0.1:3000/generate" + assert get_model_url_from_args(args, "unknown") == "http://10.0.0.1:3000/generate" def test_get_model_url_no_routers(self): """get_model_url should work when sglang_model_routers is not set.""" - from argparse import Namespace + from slime.utils.url_utils import get_model_url_from_args + + args = Namespace( + sglang_router_ip="10.0.0.1", + sglang_router_port=3000, + ) + assert get_model_url_from_args(args, "anything") == "http://10.0.0.1:3000/generate" + + def test_get_model_url_router_url_precedence(self): + 
"""sglang_router_url should take precedence over router ip/port.""" + from slime.utils.url_utils import get_model_url_from_args + + args = Namespace( + sglang_router_url="https://rollout.example.modal.run", + sglang_router_ip="10.0.0.1", + sglang_router_port=3000, + ) + assert get_model_url_from_args(args, "anything") == "https://rollout.example.modal.run/generate" - from slime.rollout.sglang_rollout import get_model_url + def test_get_model_url_named_url_router(self): + """sglang_model_routers entries may be full URLs.""" + from slime.utils.url_utils import get_model_url_from_args args = Namespace( sglang_router_ip="10.0.0.1", sglang_router_port=3000, + sglang_model_routers={"actor": "https://actor.example.modal.run/base"}, ) - assert get_model_url(args, "anything") == "http://10.0.0.1:3000/generate" + assert get_model_url_from_args(args, "actor") == "https://actor.example.modal.run/base/generate" if __name__ == "__main__": diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000000..bda0207302 --- /dev/null +++ b/uv.lock @@ -0,0 +1,3 @@ +version = 1 +revision = 3 +requires-python = ">=3.13" From 3ac3c89f4158476d29676d44f0232c859272d35b Mon Sep 17 00:00:00 2001 From: Jason Mancuso <7891333+jvmncs@users.noreply.github.com> Date: Wed, 27 May 2026 18:49:01 -0400 Subject: [PATCH 2/2] address modal delta PR review comments --- examples/delta_weight_sync/README.md | 2 +- .../delta_weight_sync/modal_delta_sync.py | 4 +-- slime/backends/megatron_utils/sglang.py | 28 +------------------ slime/backends/sglang_utils/arguments.py | 15 ++-------- slime/backends/sglang_utils/sglang_engine.py | 4 +-- slime/ray/rollout.py | 10 +++---- slime/utils/arguments.py | 9 ++++++ slime/utils/url_utils.py | 4 +-- tests/utils/test_modal_delta_sync.py | 4 +-- tests/utils/test_sglang_config.py | 4 +-- 10 files changed, 29 insertions(+), 55 deletions(-) diff --git a/examples/delta_weight_sync/README.md b/examples/delta_weight_sync/README.md index 109148868d..745e1786ec 100644 --- 
a/examples/delta_weight_sync/README.md +++ b/examples/delta_weight_sync/README.md @@ -47,7 +47,7 @@ For custom trainer invocations, use the deployed URL as both the generation rout ```bash --rollout-external ---sglang-router-url https://your-rollout-url.modal.run +--rollout-router-url https://your-rollout-url.modal.run --rollout-external-engine-addrs https://your-rollout-url.modal.run --update-weight-delta-dir /delta --custom-delta-pre-push-path slime.backends.sglang_utils.modal_volume_hooks.commit_modal_delta_volume diff --git a/examples/delta_weight_sync/modal_delta_sync.py b/examples/delta_weight_sync/modal_delta_sync.py index a49a725c30..4e50683143 100644 --- a/examples/delta_weight_sync/modal_delta_sync.py +++ b/examples/delta_weight_sync/modal_delta_sync.py @@ -50,7 +50,7 @@ def _local_repo_root() -> Path: MEGATRON_COMMIT = "1dcf0dafa884ad52ffb243625717a3471643e087" ROLLOUT_BASE_IMAGE = os.environ.get("SLIME_MODAL_ROLLOUT_BASE_IMAGE", f"slimerl/sglang:{SGLANG_IMAGE_TAG}") -TRAINER_BASE_IMAGE = os.environ.get("SLIME_MODAL_TRAINER_BASE_IMAGE", "slimerl/slime-test:nightly-dev-20260429b") +TRAINER_BASE_IMAGE = os.environ.get("SLIME_MODAL_TRAINER_BASE_IMAGE", "slimerl/slime:nightly-dev-20260527a") AUTOINFERENCE_UTILS_VERSION = os.environ.get("AUTOINFERENCE_UTILS_VERSION", "0.2.0") HF_CACHE_PATH = "/root/.cache/huggingface" @@ -559,7 +559,7 @@ def _build_train_args( "1", "--rollout-num-gpus-per-engine", "1", - "--sglang-router-url", + "--rollout-router-url", rollout_url, "--rollout-external-engine-addrs", rollout_url, diff --git a/slime/backends/megatron_utils/sglang.py b/slime/backends/megatron_utils/sglang.py index 382b5f719e..c42c9b45fe 100644 --- a/slime/backends/megatron_utils/sglang.py +++ b/slime/backends/megatron_utils/sglang.py @@ -1,7 +1,4 @@ # the file to manage all sglang deps in the megatron actor -from dataclasses import dataclass -from enum import Enum - try: from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0, 
transform_scale_ue8m0 from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 @@ -16,30 +13,7 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions -try: - from sglang.srt.managers.io_struct import DeltaEncoding, DeltaParam, DeltaSpec -except ImportError: - - class DeltaEncoding(str, Enum): - INDICES = "indices" - DELTAS = "deltas" - DELTAS_ZSTD = "deltas_zstd" - - @dataclass - class DeltaParam: - name: str - dtype: str - shape: list[int] - pos_start: int - pos_end: int - pos_width: int - val_start: int - val_end: int - - @dataclass - class DeltaSpec: - encoding: DeltaEncoding - params: list[DeltaParam] +from sglang.srt.managers.io_struct import DeltaEncoding, DeltaParam, DeltaSpec from sglang.srt.utils import MultiprocessingSerializer diff --git a/slime/backends/sglang_utils/arguments.py b/slime/backends/sglang_utils/arguments.py index 219e332c8e..e4aa6a76cc 100644 --- a/slime/backends/sglang_utils/arguments.py +++ b/slime/backends/sglang_utils/arguments.py @@ -22,15 +22,6 @@ def add_sglang_router_arguments(parser): default=None, help="Port of the SGLang router", ) - parser.add_argument( - "--sglang-router-url", - type=str, - default=None, - help=( - "Full http(s) base URL for an externally managed SGLang router or single-engine endpoint. " - "When set, rollout generation uses this URL instead of --sglang-router-ip/port." - ), - ) parser.add_argument( "--sglang-router-request-timeout-secs", type=int, @@ -166,9 +157,9 @@ def validate_args(args): if getattr(args, "sglang_router_ip", None): args.sglang_router_ip = _wrap_ipv6(args.sglang_router_ip) - if getattr(args, "sglang_router_url", None): - args.sglang_router_url = normalize_base_url(args.sglang_router_url) - assert getattr(args, "rollout_external", False), "--sglang-router-url requires --rollout-external." 
+ if getattr(args, "rollout_router_url", None): + args.rollout_router_url = normalize_base_url(args.rollout_router_url) + assert getattr(args, "rollout_external", False), "--rollout-router-url requires --rollout-external." if getattr(args, "rollout_external", False): external_engine_addrs = getattr(args, "rollout_external_engine_addrs", None) diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 52125818fc..4857f1a025 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -127,8 +127,8 @@ def init( server_url=None, external_addr_is_url: bool = False, ): - self.router_ip = router_ip if router_ip is not None else getattr(self.args, "sglang_router_ip", None) - self.router_port = router_port if router_port is not None else getattr(self.args, "sglang_router_port", None) + self.router_ip = router_ip if router_ip is not None else self.args.sglang_router_ip + self.router_port = router_port if router_port is not None else self.args.sglang_router_port host = host or get_host_info()[1] diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 25b2b18955..4d531327d7 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -969,8 +969,8 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool ``force_new`` is False, skip launching and return the existing values. When ``force_new`` is True (multi-model), always allocate a fresh port. 
""" - if not force_new and getattr(args, "sglang_router_url", None): - addr = parse_external_engine_addr(args.sglang_router_url) + if not force_new and getattr(args, "rollout_router_url", None): + addr = parse_external_engine_addr(args.rollout_router_url) return addr.host, addr.port if not force_new and args.sglang_router_ip is not None: @@ -1064,7 +1064,7 @@ def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: has_pd = model_cfg.has_pd_disaggregation router_ip, router_port = _start_router(args, has_pd_disaggregation=has_pd, force_new=(model_idx > 0)) - router_url = getattr(args, "sglang_router_url", None) if model_idx == 0 else None + router_url = getattr(args, "rollout_router_url", None) if model_idx == 0 else None # Write back for backward compat (first model only). if model_idx == 0: @@ -1076,7 +1076,7 @@ def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: has_epd = model_cfg.has_encoder_disaggregation - def _make_group(group_cfg, router_ip, router_port, overrides_extra=None): + def _make_group(group_cfg, router_ip, router_port, overrides_extra=None, router_url_value=router_url): nonlocal engine_offset, gpu_offset gpus_per_engine = group_cfg.num_gpus_per_engine num_gpu_per_engine_local = min(gpus_per_engine, args.num_gpus_per_node) @@ -1109,7 +1109,7 @@ def _make_group(group_cfg, router_ip, router_port, overrides_extra=None): model_path=overrides.get("model_path", args.hf_checkpoint), router_ip=router_ip, router_port=router_port, - router_url=router_url, + router_url=router_url_value, ) engine_offset += num_engines gpu_offset += group_cfg.num_gpus diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 314fa4409e..38de7801a3 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -511,6 +511,15 @@ def add_rollout_arguments(parser): default=False, help="Use external SGLang instances instead of launching them inside the framework.", ) + parser.add_argument( + "--rollout-router-url", + type=str, + 
default=None, + help=( + "Full http(s) base URL for an externally managed rollout generation router or " + "single-engine endpoint." + ), + ) parser.add_argument( "--rollout-external-engine-addrs", type=str, diff --git a/slime/utils/url_utils.py b/slime/utils/url_utils.py index b47727568b..2188fff513 100644 --- a/slime/utils/url_utils.py +++ b/slime/utils/url_utils.py @@ -30,8 +30,8 @@ def make_http_base_url(host: str, port: int) -> str: def get_default_router_url_from_args(args, endpoint: str = "/generate") -> str: - if getattr(args, "sglang_router_url", None): - return join_url(args.sglang_router_url, endpoint) + if getattr(args, "rollout_router_url", None): + return join_url(args.rollout_router_url, endpoint) return join_url(make_http_base_url(args.sglang_router_ip, args.sglang_router_port), endpoint) diff --git a/tests/utils/test_modal_delta_sync.py b/tests/utils/test_modal_delta_sync.py index 9d7b53fd31..897c473549 100644 --- a/tests/utils/test_modal_delta_sync.py +++ b/tests/utils/test_modal_delta_sync.py @@ -55,9 +55,9 @@ def test_external_engine_base_urls_from_args_normalizes_admin_urls(): ] -def test_sglang_router_url_takes_generation_precedence(): +def test_rollout_router_url_takes_generation_precedence(): args = Namespace( - sglang_router_url="https://rollout.example.modal.run", + rollout_router_url="https://rollout.example.modal.run", sglang_router_ip="10.0.0.1", sglang_router_port=3000, ) diff --git a/tests/utils/test_sglang_config.py b/tests/utils/test_sglang_config.py index d7d56e5c80..3508572635 100644 --- a/tests/utils/test_sglang_config.py +++ b/tests/utils/test_sglang_config.py @@ -131,11 +131,11 @@ def test_get_model_url_no_routers(self): assert get_model_url_from_args(args, "anything") == "http://10.0.0.1:3000/generate" def test_get_model_url_router_url_precedence(self): - """sglang_router_url should take precedence over router ip/port.""" + """rollout_router_url should take precedence over router ip/port.""" from slime.utils.url_utils import 
get_model_url_from_args args = Namespace( - sglang_router_url="https://rollout.example.modal.run", + rollout_router_url="https://rollout.example.modal.run", sglang_router_ip="10.0.0.1", sglang_router_port=3000, )