Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_rollout_request_hook.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
{'test_file': 'test_rm_deepscaler.py', 'num_gpus': 0},
{'test_file': 'test_sample.py', 'num_gpus': 0},
{'test_file': 'test_rollout_validation.py', 'num_gpus': 0},
{'test_file': 'test_rollout_request_hook.py', 'num_gpus': 0},
{'test_file': 'test_placement_group.py', 'num_gpus': 0},
{'test_file': 'test_external_sglang_engines.py', 'num_gpus': 0},
{'test_file': 'utils/test_hf_checkpoint_saver.py', 'num_gpus': 0},
Expand Down
14 changes: 14 additions & 0 deletions docs/en/get_started/customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Below is a summary of all available customization interfaces and their purposes.
| :--- | :--- |
| [`--rollout-function-path`](#1-rollout-function---rollout-function-path) | Override the entire rollout generation logic. |
| [`--custom-generate-function-path`](#2-custom-generate-function---custom-generate-function-path) | Override only the generation step (e.g., for RAG or tool use). |
| [`--custom-rollout-request-hook-path`](#mutating-the-outgoing-request---custom-rollout-request-hook-path) | Mutate each outgoing `/generate` request (e.g., custom headers). |
| [`--custom-rm-path`](#3-reward-model---custom-rm-path) | Implement custom reward computation logic. |
| [`--dynamic-sampling-filter-path`](#4-dynamic-sampling-filter---dynamic-sampling-filter-path) | Filter samples during dynamic sampling (e.g., DAPO). |
| [`--buffer-filter-path`](#5-buffer-filter---buffer-filter-path) | Filter samples in the rollout buffer before training. |
Expand Down Expand Up @@ -118,6 +119,19 @@ If one full trajectory has a single total reward but is split into `K` training

**Example**: See [examples/search-r1/generate_with_search.py](../../../examples/search-r1/generate_with_search.py)

#### Mutating the outgoing request (`--custom-rollout-request-hook-path`)

When you keep the built-in generate function but need to adjust each `/generate` request just before it is sent, use `--custom-rollout-request-hook-path` instead of replacing the whole generate step. The hook receives a `request` dict describing how the call is sent — `url`, `payload`, `headers`, `max_retries`, `retry_sleep` — plus `args` and `sample`. It either mutates `request` in place (returning `None`) or returns a dict of updates:

```python
def hook(args, sample, request):
request["headers"] = {**(request["headers"] or {}), "Authorization": f"Bearer {get_token()}"}
```

Use it to add custom headers (auth tokens, routing keys), or for weight-version gating against an opaque rollout endpoint — set `request["payload"]["weight_version"]` so the fleet serves only a matching version, and raise `request["max_retries"]`/`request["retry_sleep"]` so slime backs off and waits for the fleet to load it.

The hook may be `async`. It runs for both built-in generate paths (the default buffered one and `sglang_streaming_rollout.generate_streaming`) only when configured — otherwise the request is sent unchanged. Your own custom generate functions that post requests directly are responsible for their own request shaping (call `apply_rollout_request_hook` if you want the same behavior).

---

### 3. Reward Model (`--custom-rm-path`)
Expand Down
7 changes: 5 additions & 2 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,8 +616,11 @@ def update_weights(self) -> None:
) = ray.get(self.rollout_manager.get_updatable_engines_and_lock.remote())

reconnect_rollout_engines = self.args.offload_train and self.args.use_critic and not self.args.colocate
# An opaque HTTP rollout fleet exposes no engine handles; the trainer publishes the delta to
# disk instead of pushing, so it still runs update_weights (and connects once) with no engines.
publish_only = bool(getattr(self.args, "rollout_endpoint_url", None))

if not rollout_engines and not reconnect_rollout_engines:
if not rollout_engines and not reconnect_rollout_engines and not publish_only:
if dist.get_rank() == 0:
logger.info("No updatable SGLang engines are running; skip weight update.")
return
Expand All @@ -627,7 +630,7 @@ def update_weights(self) -> None:
elif self.args.offload_train:
reload_process_groups()

if num_new_engines > 0 or reconnect_rollout_engines:
if num_new_engines > 0 or reconnect_rollout_engines or publish_only:
self.weight_updater.connect_rollout_engines(
rollout_engines,
rollout_engine_lock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(
self.checksum_algorithm = args.update_weight_delta_checksum
self._snapshot: dict[str, np.ndarray] = {}
self._baseline_captured = False
# Opaque HTTP rollout: no engine handles, so publish the version to disk and let the fleet
# pull it, instead of pushing via per-engine RPCs.
self._publish_only = bool(getattr(args, "rollout_endpoint_url", None))
self._commit_hook: Callable | None = None
if args.custom_delta_pre_push_path:
from slime.utils.misc import load_function
Expand Down Expand Up @@ -86,13 +89,16 @@ def update_weights(self) -> None:
return

self.weight_version += 1
if dist.get_rank() == 0:
if dist.get_rank() == 0 and not self._publish_only:
ray.get([engine.pause_generation.remote() for engine in self.rollout_engines])
ray.get([engine.flush_cache.remote() for engine in self.rollout_engines])
dist.barrier(group=get_gloo_group())

self._publish()
self._reload_engines()
if self._publish_only:
self._announce_version()
else:
self._reload_engines()
self._record_metrics()

def _capture_baseline(self) -> None:
Expand Down Expand Up @@ -185,6 +191,16 @@ def _reload_engines(self) -> None:
ray.get([engine.continue_generation.remote() for engine in self.rollout_engines])
dist.barrier(group=get_gloo_group())

def _announce_version(self) -> None:
"""Publish-only: commit the version dir and advance the latest-version pointer, so the
external fleet pulls and applies it on its own. No engine handles, hence no reload RPCs."""
if self._commit_hook is not None:
self._commit_hook(self.args, self._version_dir, []) # opaque fleet: no engine handles
dist.barrier(group=get_gloo_group())
if dist.get_rank() == 0:
_atomic_write(os.path.join(self.delta_dir, "latest"), f"{self.weight_version:06d}".encode())
dist.barrier(group=get_gloo_group())

def _iter_hf_tensors(self):
"""Yield (name, gathered HF tensor) for every param: base-class TP then EP gather passes."""
for chunk_iter in (self._iter_non_expert_chunks(), self._iter_expert_chunks()):
Expand Down
21 changes: 21 additions & 0 deletions slime/backends/sglang_utils/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,24 @@ def start_external_rollout_servers(args, *, start_router) -> tuple[dict[str, Ext
)
}
return servers, init_handles


def normalize_rollout_endpoint_url(url: str) -> str:
"""Normalize an opaque HTTP rollout endpoint base URL (drop trailing slash)."""
url = url.rstrip("/")
parsed = urlparse(url)
if parsed.scheme not in ("http", "https") or parsed.netloc == "":
raise ValueError(f"Invalid --rollout-endpoint-url {url!r}. Use an absolute http:// or https:// URL.")
return url


def uses_rollout_endpoint(args) -> bool:
return bool(getattr(args, "rollout_endpoint_url", None))


def rollout_endpoint_servers(args) -> tuple[dict[str, ExternalRolloutServer], list]:
"""Rollout served by an opaque HTTP endpoint behind one URL. The fleet is elastic, so slime holds
no per-engine handles — hence no engines (weights are published to disk, not pushed) and generation
routes to the URL via get_model_url."""
logger.info("Rollout served by opaque HTTP endpoint: %s", args.rollout_endpoint_url)
return {"default": ExternalRolloutServer(engines=[], engine_gpu_counts=[], engine_gpu_offsets=[])}, []
2 changes: 1 addition & 1 deletion slime/ray/placement_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _get_placement_group_layout(args) -> tuple[int, int]:
if args.debug_train_only:
return actor_num_gpus, 0

if args.rollout_external:
if args.rollout_external or getattr(args, "rollout_endpoint_url", None):
if args.debug_rollout_only:
return 0, 0
return actor_num_gpus, actor_num_gpus
Expand Down
9 changes: 8 additions & 1 deletion slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

from slime.backends.sglang_utils.external import start_external_rollout_servers
from slime.backends.sglang_utils.external import (
rollout_endpoint_servers,
start_external_rollout_servers,
uses_rollout_endpoint,
)
from slime.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig
from slime.backends.sglang_utils.sglang_engine import SGLangEngine
from slime.rollout.base_types import call_rollout_fn
Expand Down Expand Up @@ -1081,6 +1085,9 @@ def start_rollout_servers(args, pg) -> tuple[dict[str, Any], list[Any]]:
Note: ``init_http_client`` should be called separately before this,
as the HTTP client is shared across all servers.
"""
if uses_rollout_endpoint(args):
return rollout_endpoint_servers(args)

if args.rollout_external:
return start_external_rollout_servers(args, start_router=_start_router)

Expand Down
79 changes: 68 additions & 11 deletions slime/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from packaging.version import parse
from tqdm import tqdm

from slime.backends.sglang_utils.external import uses_rollout_endpoint
from slime.backends.sglang_utils.server_control import abort_servers_until_idle
from slime.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput
from slime.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter
Expand Down Expand Up @@ -72,15 +73,49 @@ def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate")
resp = await post(url, json=payload)

Falls back to the default router if *model_name* is not found or
``sglang_model_routers`` is not set.
``sglang_model_routers`` is not set. With ``--rollout-endpoint-url`` set, returns that opaque
endpoint with *endpoint* appended (no router APIs are assumed to exist).
"""
if uses_rollout_endpoint(args):
return f"{args.rollout_endpoint_url}{endpoint}"
routers = getattr(args, "sglang_model_routers", None)
if routers and model_name in routers:
ip, port = routers[model_name]
return f"http://{ip}:{port}{endpoint}"
return f"http://{args.sglang_router_ip}:{args.sglang_router_port}{endpoint}"


async def apply_rollout_request_hook(
args: Namespace,
url: str,
payload: dict[str, Any],
*,
headers: dict | None,
sample: Sample,
) -> dict[str, Any]:
"""Run ``custom_rollout_request_hook_path`` on one outgoing /generate request.

The hook receives ``request = {"url", "payload", "headers", "max_retries", "retry_sleep"}`` along
with ``args`` and ``sample`` (which carries its own context, e.g. ``sample.index``) — everything
about how this one request is sent, and nothing about the rollout itself. It mutates ``request``
in place and returns None, or returns a dict of updates; this returns the resulting request.
Callers invoke this only when a hook is set, so the default path keeps calling ``post`` directly.
"""
request = {"url": url, "payload": payload, "headers": headers, "max_retries": 60, "retry_sleep": 1.0}
hook = load_function(args.custom_rollout_request_hook_path)
result = hook(args, sample, request)
if inspect.isawaitable(result):
result = await result
if result is not None:
if not isinstance(result, dict):
raise TypeError(
f"{args.custom_rollout_request_hook_path} must return None or a dict of request updates, "
f"got {type(result).__name__}"
)
request.update(result)
return request


class GenerateState(metaclass=SingletonMeta):
"""
The global state for the generation process.
Expand Down Expand Up @@ -154,7 +189,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
assert isinstance(sample.prompt, str)

state = GenerateState(args)
url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
url = get_model_url(args, "default", "/generate")

assert (
sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED
Expand Down Expand Up @@ -197,7 +232,17 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
headers = {"X-SMG-Routing-Key": sample.session_id}

with trace_span(sample, "sglang_generate", attrs={"max_new_tokens": sampling_params["max_new_tokens"]}) as span:
output = await post(url, payload, headers=headers)
if getattr(args, "custom_rollout_request_hook_path", None):
request = await apply_rollout_request_hook(args, url, payload, headers=headers, sample=sample)
output = await post(
request["url"],
request["payload"],
headers=request["headers"],
max_retries=request["max_retries"],
retry_sleep=request["retry_sleep"],
)
else:
output = await post(url, payload, headers=headers)
span.update(build_sglang_meta_trace_attrs(output["meta_info"]))

if "output_token_logprobs" in output["meta_info"]:
Expand Down Expand Up @@ -355,14 +400,26 @@ async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]:
assert not state.aborted
state.aborted = True

if parse(sglang_router.__version__) <= parse("0.2.1"):
response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers")
urls = response["urls"]
else:
response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers")
urls = [worker["url"] for worker in response["workers"]]

await abort_servers_until_idle(urls)
if uses_rollout_endpoint(args) and not args.partial_rollout:
# Opaque endpoint, surplus discarded: cancel locally — the client disconnect aborts the
# request on the fleet. No worker API to call, and nothing to collect.
for task in state.pendings:
task.cancel()
await asyncio.gather(*state.pendings, return_exceptions=True)
state.pendings = set()
return aborted_samples

if not uses_rollout_endpoint(args):
# Router: explicitly abort in-flight requests on each worker.
if parse(sglang_router.__version__) <= parse("0.2.1"):
response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers")
urls = response["urls"]
else:
response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers")
urls = [worker["url"] for worker in response["workers"]]
await abort_servers_until_idle(urls)
# Opaque endpoint + partial-rollout: the streaming tasks self-break on state.aborted and return
# their partials below; closing each stream disconnects, which aborts the request on the fleet.

# make sure all the pending tasks are finished
count = 0
Expand Down
Loading