diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 10f1c2d808..69e0cceedf 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -372,7 +372,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] + info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_rollout_request_hook.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "test_megatron_to_hf_router_dtype.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 5621374ad3..45b6068840 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -68,6 +68,7 @@ {'test_file': 'test_metric_report.py', 'num_gpus': 0}, {'test_file': 'test_metric_report_dist.py', 'num_gpus': 0}, {'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0}, + {'test_file': 'test_logprob_response_spans.py', 'num_gpus': 0}, {'test_file': 'test_value_temperature.py', 'num_gpus': 0}, {'test_file': 'test_cispo_loss.py', 'num_gpus': 0}, {'test_file': 'test_rm_f1.py', 'num_gpus': 0}, @@ -77,8 +78,10 @@ {'test_file': 'test_rm_deepscaler.py', 'num_gpus': 0}, {'test_file': 'test_sample.py', 'num_gpus': 0}, {'test_file': 'test_rollout_validation.py', 'num_gpus': 0}, + {'test_file': 'test_rollout_request_hook.py', 'num_gpus': 0}, {'test_file': 'test_placement_group.py', 'num_gpus': 0}, {'test_file': 'test_external_sglang_engines.py', 'num_gpus': 0}, + {'test_file': 'test_megatron_to_hf_router_dtype.py', 'num_gpus': 0}, {'test_file': 'utils/test_hf_checkpoint_saver.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0}, diff --git a/docker/Dockerfile b/docker/Dockerfile index af1f1a3d94..bc5f7de1f1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -115,6 +115,18 @@ RUN if [ "$ENABLE_SGLANG_PATCH" = "1" ]; then \ rm sglang.patch; \ fi +COPY docker/patch/${PATCH_VERSION}/sglang-top_p.patch /sgl-workspace/sglang/ +RUN if [ "$ENABLE_SGLANG_PATCH" = "1" ]; then \ + cd /sgl-workspace/sglang && \ + git update-index --refresh && \ + git apply sglang-top_p.patch --3way && \ + if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi && \ + rm sglang-top_p.patch; \ +fi + # ====================================== Install main package ============================================ ARG SLIME_COMMIT=main diff --git a/docker/patch/latest/sglang-top_p.patch b/docker/patch/latest/sglang-top_p.patch new file mode 100644 index 0000000000..e0cfba6ce4 --- /dev/null +++ b/docker/patch/latest/sglang-top_p.patch @@ -0,0 +1,910 @@ +diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py +index 70265a424f5..a19278e8828 100644 +--- a/python/sglang/srt/disaggregation/decode.py ++++ b/python/sglang/srt/disaggregation/decode.py +@@ -1488,6 +1488,8 @@ class DecodeTransferQueue(DecodeHiCacheTransferMixin): + output_token_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, ++ output_top_p_token_ids_len, ++ output_top_p_token_ids, + output_topk_p, + output_topk_index, + output_hidden_states, +@@ -1580,6 +1582,11 @@ class DecodeTransferQueue(DecodeHiCacheTransferMixin): + : decode_req.req.logprob.top_logprobs_num + ].tolist() + ) ++ top_p_token_ids_len = output_top_p_token_ids_len[0].item() ++ if top_p_token_ids_len > 0: ++ decode_req.req.logprob.output_top_p_token_ids.append( ++ output_top_p_token_ids[:top_p_token_ids_len].tolist() ++ ) + + if is_slime_profiling_enabled(): + apply_prefill_timing_payload( +diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py +index a44685777ea..f21e4dad4a5 100644 +--- a/python/sglang/srt/disaggregation/utils.py ++++ b/python/sglang/srt/disaggregation/utils.py +@@ -1,5 +1,6 @@ + from __future__ import annotations + ++import logging + import os + import random + import time +@@ -42,6 +43,8 @@ PREFILL_TIMING_DEST_ATTRS = ( + ("fwd_transfer_total_mb", float), + ("fwd_prefill_retry_count", int), + ) ++MAX_PD_TOP_P_TOKEN_IDS = 4096 ++logger = logging.getLogger(__name__) + + + class DisaggregationMode(Enum): +@@ -253,6 +256,12 @@ class MetadataBuffers: + self.output_top_logprobs_idx = torch.zeros( + (size, max_top_logprobs_num), dtype=torch.int32, device=device + ) ++ self.output_top_p_token_ids_len = torch.zeros( ++ (size, 16), dtype=torch.int32, device=device ++ ) ++ self.output_top_p_token_ids = torch.zeros( ++ (size, MAX_PD_TOP_P_TOKEN_IDS), dtype=torch.int32, device=device ++ ) + # For PD + spec decode + self.output_topk_p = torch.zeros( + (size, 16), dtype=torch.float32, device=device +@@ -277,6 +286,8 @@ class MetadataBuffers: + ("output_token_logprobs_idx", self.output_token_logprobs_idx), + ("output_top_logprobs_val", self.output_top_logprobs_val), + ("output_top_logprobs_idx", self.output_top_logprobs_idx), ++ ("output_top_p_token_ids_len", self.output_top_p_token_ids_len), ++ ("output_top_p_token_ids", self.output_top_p_token_ids), + ("output_topk_p", self.output_topk_p), + ("output_topk_index", self.output_topk_index), + ("output_hidden_states", self.output_hidden_states), +@@ -301,6 +312,8 @@ class MetadataBuffers: + self.output_token_logprobs_idx[idx].clone(), + self.output_top_logprobs_val[idx].clone(), + self.output_top_logprobs_idx[idx].clone(), ++ self.output_top_p_token_ids_len[idx].clone(), ++ self.output_top_p_token_ids[idx].clone(), + self.output_topk_p[idx].clone(), + self.output_topk_index[idx].clone(), + self.output_hidden_states[idx].clone(), +@@ -318,6 +331,7 @@ class MetadataBuffers: + self.cached_tokens[req.metadata_buffer_index][1] = req.cached_tokens_device + self.cached_tokens[req.metadata_buffer_index][2] = req.cached_tokens_host + self.cached_tokens[req.metadata_buffer_index][3] = req.cached_tokens_storage ++ self.output_top_p_token_ids_len[req.metadata_buffer_index][0] = 0 + if req.return_logprob: + if req.logprob.output_token_logprobs_val: # not none or empty list + self.output_token_logprobs_val[req.metadata_buffer_index][0] = ( +@@ -344,6 +358,28 @@ class MetadataBuffers: + dtype=torch.int32, + device="cpu", + ) ++ if req.logprob.output_top_p_token_ids: # not none or empty list ++ output_top_p_token_ids = req.logprob.output_top_p_token_ids[0] ++ if len(output_top_p_token_ids) > MAX_PD_TOP_P_TOKEN_IDS: ++ logger.warning( ++ "PD top-p token replay payload for the first output token " ++ "has %s ids, exceeding the metadata buffer cap %s. " ++ "Falling back to the sampled token only.", ++ len(output_top_p_token_ids), ++ MAX_PD_TOP_P_TOKEN_IDS, ++ ) ++ output_top_p_token_ids = [int(req.output_ids[0])] ++ top_p_len = len(output_top_p_token_ids) ++ self.output_top_p_token_ids_len[req.metadata_buffer_index][0] = ( ++ top_p_len ++ ) ++ self.output_top_p_token_ids[req.metadata_buffer_index][ ++ :top_p_len ++ ] = torch.tensor( ++ output_top_p_token_ids, ++ dtype=torch.int32, ++ device="cpu", ++ ) + # For PD + spec decode + if req.hidden_states_tensor is not None: + # speculative_eagle_topk should not be greater than 16 currently +diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py +index 15a3d6aac9b..8e681dd4013 100644 +--- a/python/sglang/srt/layers/logits_processor.py ++++ b/python/sglang/srt/layers/logits_processor.py +@@ -110,6 +110,8 @@ class LogitsProcessorOutput: + List[Union[List[float], torch.Tensor]] + ] = None + next_token_token_ids_logprobs_idx: Optional[List] = None ++ # The kept token ids used by rollout top-p replay. One tensor/list per sampled row. ++ next_token_top_p_token_ids: Optional[List[Optional[torch.Tensor]]] = None + + ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor + # The logprobs of input tokens. shape: [#token] +diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py +index 84d57b8e2c0..51f1aebb713 100644 +--- a/python/sglang/srt/layers/sampler.py ++++ b/python/sglang/srt/layers/sampler.py +@@ -12,7 +12,12 @@ from sglang.srt.layers.dp_attention import ( + ) + from sglang.srt.layers.logits_processor import LogitsProcessorOutput + from sglang.srt.layers.utils.hash import murmur_hash32 +-from sglang.srt.layers.utils.logprob import get_token_ids_logprobs, get_top_logprobs ++from sglang.srt.layers.utils.logprob import ( ++ get_token_ids_logprobs, ++ get_top_logprobs, ++ get_top_p_token_ids_from_probs, ++ renorm_logprob_over_top_p, ++) + from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + from sglang.srt.sampling.sampling_params import TOP_K_ALL + from sglang.srt.server_args import get_global_server_args +@@ -124,6 +129,9 @@ class Sampler(nn.Module): + _aiter_greedy_sample(batch_next_token_ids, logits) + else: + batch_next_token_ids = torch.argmax(logits, -1) ++ self._attach_greedy_top_p_token_ids_to_output( ++ logits_output, batch_next_token_ids, sampling_info ++ ) + if return_logprob: + original_logprobs = logprobs = torch.nn.functional.log_softmax( + logits, dim=-1 +@@ -154,6 +162,7 @@ class Sampler(nn.Module): + if self.use_ascend_backend: + # Ascend backend: sample from logits directly. + batch_next_token_ids, logprobs = self._forward_ascend_backend( ++ logits_output, + logits, + sampling_info, + simple_sampling_case, +@@ -181,15 +190,45 @@ class Sampler(nn.Module): + logits[:] = torch.softmax(logits, dim=-1) + probs = logits + ++ self._attach_top_p_token_ids_to_output( ++ logits_output, ++ probs, ++ sampling_info, ++ simple_sampling_case, ++ ) + batch_next_token_ids = self._sample_from_probs( + probs, sampling_info, positions, simple_sampling_case + ) + if return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB: +- logprobs = ( +- logprobs_via_logsoftmax_kernel +- if logprobs_via_logsoftmax_kernel is not None +- else torch.log(probs) +- ) ++ top_p_logprobs = None ++ if sampling_info.need_return_top_p_token_ids: ++ # Force-keep the sampled token in the renorm denominator: ++ # SGLang samples with the flashinfer kernel but ++ # renorm_logprob_over_top_p computes the nucleus with a ++ # torch reimplementation. The two can disagree at the ++ # nucleus boundary, so a sampled token may fall outside ++ # the torch nucleus and get a -inf renormalized logprob ++ # (-> NaN downstream). Force-keeping it makes the ++ # denominator ``nucleus ∪ {sampled}``, matching the ++ # trainer which also force-keeps the target token. ++ top_p_logprobs = renorm_logprob_over_top_p( ++ probs=probs, ++ top_ks=sampling_info.top_ks, ++ top_ps=sampling_info.top_ps, ++ min_ps=sampling_info.min_ps, ++ need_top_p_sampling=sampling_info.need_top_p_sampling, ++ need_min_p_sampling=sampling_info.need_min_p_sampling, ++ request_mask=sampling_info.return_top_p_token_ids, ++ force_keep_token_ids=batch_next_token_ids, ++ ) ++ if top_p_logprobs is not None: ++ logprobs = top_p_logprobs ++ else: ++ logprobs = ( ++ logprobs_via_logsoftmax_kernel ++ if logprobs_via_logsoftmax_kernel is not None ++ else torch.log(probs) ++ ) + del probs + + # Attach logprobs to logits_output (in-place modification) +@@ -318,6 +357,7 @@ class Sampler(nn.Module): + + def _forward_ascend_backend( + self, ++ logits_output: LogitsProcessorOutput, + logits: torch.Tensor, + sampling_info: SamplingBatchInfo, + simple_sampling_case: bool, +@@ -334,6 +374,15 @@ class Sampler(nn.Module): + when return_logprob is False or SGLANG_RETURN_ORIGINAL_LOGPROB is set. + """ + logits.div_(sampling_info.temperatures) ++ if sampling_info.need_return_top_p_token_ids and not simple_sampling_case: ++ probs = torch.softmax(logits, dim=-1) ++ self._attach_top_p_token_ids_to_output( ++ logits_output, ++ probs, ++ sampling_info, ++ simple_sampling_case, ++ ) ++ del probs + batch_next_token_ids = self._sample_from_logits( + logits, sampling_info, simple_sampling_case, positions + ) +@@ -342,6 +391,49 @@ class Sampler(nn.Module): + logprobs = torch.log_softmax(logits, dim=-1) + return batch_next_token_ids, logprobs + ++ def _attach_greedy_top_p_token_ids_to_output( ++ self, ++ logits_output: LogitsProcessorOutput, ++ batch_next_token_ids: torch.Tensor, ++ sampling_info: SamplingBatchInfo, ++ ) -> None: ++ if not sampling_info.need_return_top_p_token_ids: ++ return ++ ++ request_mask = sampling_info.return_top_p_token_ids ++ logits_output.next_token_top_p_token_ids = [ ++ batch_next_token_ids[i : i + 1].to(torch.int32) ++ if bool(request_mask[i].item()) ++ else None ++ for i in range(len(batch_next_token_ids)) ++ ] ++ ++ def _attach_top_p_token_ids_to_output( ++ self, ++ logits_output: LogitsProcessorOutput, ++ probs: torch.Tensor, ++ sampling_info: SamplingBatchInfo, ++ simple_sampling_case: bool, ++ ) -> None: ++ if ( ++ not sampling_info.need_return_top_p_token_ids ++ or sampling_info.return_top_p_token_ids is None ++ or simple_sampling_case ++ ): ++ return ++ ++ top_p_token_ids = get_top_p_token_ids_from_probs( ++ probs=probs, ++ top_ks=sampling_info.top_ks, ++ top_ps=sampling_info.top_ps, ++ min_ps=sampling_info.min_ps, ++ need_top_p_sampling=sampling_info.need_top_p_sampling, ++ need_min_p_sampling=sampling_info.need_min_p_sampling, ++ request_mask=sampling_info.return_top_p_token_ids, ++ ) ++ if top_p_token_ids is not None: ++ logits_output.next_token_top_p_token_ids = top_p_token_ids ++ + def _attach_logprobs_to_output( + self, + logits_output: LogitsProcessorOutput, +diff --git a/python/sglang/srt/layers/utils/logprob.py b/python/sglang/srt/layers/utils/logprob.py +index 2ce41792a88..986f168d7b4 100644 +--- a/python/sglang/srt/layers/utils/logprob.py ++++ b/python/sglang/srt/layers/utils/logprob.py +@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Union + import torch + + from sglang.srt.environ import envs ++from sglang.srt.sampling.sampling_params import TOP_K_ALL + + if TYPE_CHECKING: + from sglang.srt.layers.logits_processor import LogitsMetadata, LogitsProcessorOutput +@@ -88,6 +89,110 @@ def get_top_logprobs( + ) + + ++def _top_p_filter_rows( ++ top_ks: torch.Tensor, ++ top_ps: torch.Tensor, ++ min_ps: torch.Tensor, ++ need_top_p_sampling: bool, ++ need_min_p_sampling: bool, ++ request_mask: torch.Tensor, ++) -> torch.Tensor: ++ """Rows that were requested AND actually have a top-k/top-p/min-p filter.""" ++ row_has_filter = top_ks != TOP_K_ALL ++ if need_top_p_sampling: ++ row_has_filter = row_has_filter | (top_ps != 1.0) ++ if need_min_p_sampling: ++ row_has_filter = row_has_filter | (min_ps > 0) ++ return request_mask & row_has_filter ++ ++ ++def _top_p_keep_mask_sorted( ++ probs: torch.Tensor, ++ top_ks: torch.Tensor, ++ top_ps: torch.Tensor, ++ min_ps: torch.Tensor, ++ need_top_p_sampling: bool, ++ need_min_p_sampling: bool, ++) -> tuple[torch.Tensor, torch.Tensor]: ++ """Boolean nucleus keep-mask in descending-prob order, plus the sort indices. ++ ++ Reproduces SGLang's sampler truncation (rank < top_k, cumulative prob within ++ top_p, prob >= top1 * min_p) so replay sees the exact set the sampler keeps. ++ """ ++ probs_sort, probs_idx = probs.sort(dim=-1, descending=True) ++ ranks = torch.arange(probs_sort.shape[-1], device=probs_sort.device).view(1, -1) ++ keep = ranks < top_ks.view(-1, 1) ++ if need_top_p_sampling: ++ keep &= (torch.cumsum(probs_sort, dim=-1) - probs_sort) <= top_ps.view(-1, 1) ++ if need_min_p_sampling: ++ keep &= probs_sort >= (probs_sort[:, 0] * min_ps).view(-1, 1) ++ return keep, probs_idx ++ ++ ++def renorm_logprob_over_top_p( ++ probs: torch.Tensor, ++ top_ks: torch.Tensor, ++ top_ps: torch.Tensor, ++ min_ps: torch.Tensor, ++ need_top_p_sampling: bool, ++ need_min_p_sampling: bool, ++ request_mask: torch.Tensor, ++ force_keep_token_ids: Optional[torch.Tensor] = None, ++) -> Optional[torch.Tensor]: ++ rows = _top_p_filter_rows( ++ top_ks, top_ps, min_ps, need_top_p_sampling, need_min_p_sampling, request_mask ++ ) ++ if not bool(rows.any().item()): ++ return None ++ ++ keep, probs_idx = _top_p_keep_mask_sorted( ++ probs, top_ks, top_ps, min_ps, need_top_p_sampling, need_min_p_sampling ++ ) ++ # Scatter the keep-mask back to vocab order so we renormalize directly over ++ # vocab ids (and can force-keep specific token ids). ++ keep_vocab = torch.empty_like(keep) ++ keep_vocab.scatter_(-1, probs_idx, keep) ++ ++ if force_keep_token_ids is not None: ++ # Force-keep the sampled/accepted token so its renormalized logprob is ++ # finite even when SGLang's sampling kernel (e.g. flashinfer) keeps a ++ # boundary token that this torch nucleus drops. This matches the trainer, ++ # which also force-keeps the target token before renormalizing, so the ++ # rollout and training denominators are both ``nucleus ∪ {token}``. ++ # Non-filter rows are overwritten by the ``torch.where`` below, so ++ # force-keeping every row is harmless and avoids a row gather. ++ row_idx = torch.arange(keep_vocab.shape[0], device=keep_vocab.device) ++ keep_vocab[row_idx, force_keep_token_ids] = True ++ ++ kept_probs = probs * keep_vocab ++ kept_probs = kept_probs / kept_probs.sum(dim=-1, keepdim=True).clamp_min(1e-12) ++ return torch.where(rows.view(-1, 1), torch.log(kept_probs), torch.log(probs)) ++ ++ ++def get_top_p_token_ids_from_probs( ++ probs: torch.Tensor, ++ top_ks: torch.Tensor, ++ top_ps: torch.Tensor, ++ min_ps: torch.Tensor, ++ need_top_p_sampling: bool, ++ need_min_p_sampling: bool, ++ request_mask: torch.Tensor, ++) -> Optional[List[Optional[torch.Tensor]]]: ++ rows = _top_p_filter_rows( ++ top_ks, top_ps, min_ps, need_top_p_sampling, need_min_p_sampling, request_mask ++ ) ++ if not bool(rows.any().item()): ++ return None ++ ++ keep, probs_idx = _top_p_keep_mask_sorted( ++ probs, top_ks, top_ps, min_ps, need_top_p_sampling, need_min_p_sampling ++ ) ++ return [ ++ probs_idx[i][keep[i]].to(torch.int32) if bool(rows[i].item()) else None ++ for i in range(probs.shape[0]) ++ ] ++ ++ + def get_token_ids_logprobs_raw( + logprobs: torch.Tensor, + token_ids_logprobs_list: List[Optional[List[int]]], +@@ -403,6 +508,12 @@ def add_output_logprobs_for_spec_v1( + req.logprob.output_token_ids_logprobs_idx.append( + token_ids_logprobs_idx[pt] + ) ++ if logits_output.next_token_top_p_token_ids: ++ row_top_p_token_ids = logits_output.next_token_top_p_token_ids[pt] ++ if row_top_p_token_ids is not None: ++ if torch.is_tensor(row_top_p_token_ids): ++ row_top_p_token_ids = row_top_p_token_ids.tolist() ++ req.logprob.output_top_p_token_ids.append(row_top_p_token_ids) + pt += 1 + + +@@ -411,13 +522,14 @@ def compute_spec_v2_logprobs( + logits_output, + predict: torch.Tensor, + accept_index: torch.Tensor, ++ accept_lens: torch.Tensor, + speculative_num_steps: int, + ): + """Compute logprobs for accepted tokens after spec v2 verify sampling. + + Gathers logits at accepted positions, applies log_softmax (temperature-scaled +- if not greedy), and populates logits_output.next_token_logprobs (plus optional +- top-k / token-ids logprobs) so they flow through copy_to_cpu(). ++ if not greedy), and populates logits_output.next_token_logprobs plus optional ++ top-k / token-ids / top-p replay metadata so they flow through copy_to_cpu(). + """ + bs = len(batch.seq_lens) + max_accept = speculative_num_steps + 1 +@@ -425,6 +537,7 @@ def compute_spec_v2_logprobs( + + flat_accept_idx = accept_index.long().reshape(-1) + gathered_logits = logits_output.next_token_logits[flat_accept_idx] ++ temperatures = None + + if batch.sampling_info.is_all_greedy or envs.SGLANG_RETURN_ORIGINAL_LOGPROB.get(): + gathered_logprobs = torch.nn.functional.log_softmax(gathered_logits, dim=-1) +@@ -446,6 +559,80 @@ def compute_spec_v2_logprobs( + ] + logits_output.next_token_logprobs = token_logprobs.reshape(bs, max_accept) + ++ if batch.sampling_info.need_return_top_p_token_ids: ++ valid_accept_mask = ( ++ torch.arange(max_accept, device=device).view(1, -1) ++ < accept_lens.view(-1, 1) ++ ).reshape(-1) ++ request_mask = ( ++ torch.repeat_interleave( ++ batch.sampling_info.return_top_p_token_ids, max_accept ++ ) ++ & valid_accept_mask ++ ) ++ ++ if batch.sampling_info.is_all_greedy: ++ logits_output.next_token_top_p_token_ids = [ ++ accepted_token_ids[i : i + 1].to(torch.int32) ++ if bool(request_mask[i].item()) ++ else None ++ for i in range(bs * max_accept) ++ ] ++ else: ++ if temperatures is None: ++ temperatures = torch.repeat_interleave( ++ batch.sampling_info.temperatures, ++ max_accept, ++ dim=0, ++ ) ++ probs = torch.softmax(gathered_logits / temperatures, dim=-1) ++ expanded_top_ks = torch.repeat_interleave( ++ batch.sampling_info.top_ks, max_accept ++ ) ++ expanded_top_ps = torch.repeat_interleave( ++ batch.sampling_info.top_ps, max_accept ++ ) ++ expanded_min_ps = torch.repeat_interleave( ++ batch.sampling_info.min_ps, max_accept ++ ) ++ top_p_token_ids = get_top_p_token_ids_from_probs( ++ probs=probs, ++ top_ks=expanded_top_ks, ++ top_ps=expanded_top_ps, ++ min_ps=expanded_min_ps, ++ need_top_p_sampling=batch.sampling_info.need_top_p_sampling, ++ need_min_p_sampling=False, ++ request_mask=request_mask, ++ ) ++ if top_p_token_ids is not None: ++ logits_output.next_token_top_p_token_ids = top_p_token_ids ++ ++ renorm_logprobs = renorm_logprob_over_top_p( ++ probs=probs, ++ top_ks=expanded_top_ks, ++ top_ps=expanded_top_ps, ++ min_ps=expanded_min_ps, ++ need_top_p_sampling=batch.sampling_info.need_top_p_sampling, ++ need_min_p_sampling=False, ++ request_mask=request_mask, ++ # Force-keep the accepted token: a small fraction of ++ # speculatively accepted tokens land outside their own top-p ++ # nucleus, which would give a -inf renormalized logprob. ++ # Force-keeping makes the denominator ``nucleus ∪ {accepted}``, ++ # matching the trainer which also force-keeps the target token, ++ # so these tokens stay finite and on-policy. ++ force_keep_token_ids=accepted_token_ids.long(), ++ ) ++ if renorm_logprobs is not None: ++ idx = torch.arange(bs * max_accept, device=device) ++ renorm_token_logprobs = renorm_logprobs[idx, accepted_token_ids.long()] ++ renorm_token_logprobs.clamp_( ++ min=torch.finfo(renorm_token_logprobs.dtype).min ++ ) ++ logits_output.next_token_logprobs = renorm_token_logprobs.reshape( ++ bs, max_accept ++ ) ++ + if batch.top_logprobs_nums and any(x > 0 for x in batch.top_logprobs_nums): + top_logprobs_nums_expanded = [ + num for num in batch.top_logprobs_nums for _ in range(max_accept) +diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py +index 065546088a4..5daa5724781 100644 +--- a/python/sglang/srt/managers/detokenizer_manager.py ++++ b/python/sglang/srt/managers/detokenizer_manager.py +@@ -431,6 +431,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx, + output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val, + output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, ++ output_top_p_token_ids=recv_obj.output_top_p_token_ids, + output_token_entropy_val=recv_obj.output_token_entropy_val, + output_hidden_states=recv_obj.output_hidden_states, + routed_experts=routed_experts, +diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py +index e098565729b..17f2d5e96a0 100644 +--- a/python/sglang/srt/managers/io_struct.py ++++ b/python/sglang/srt/managers/io_struct.py +@@ -1138,6 +1138,7 @@ class BatchTokenIDOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): + input_token_ids_logprobs_idx: List[List] + output_token_ids_logprobs_val: List[List] + output_token_ids_logprobs_idx: List[List] ++ output_top_p_token_ids: List[List[List[int]]] + output_token_entropy_val: List[float] + + # Hidden states +@@ -1204,6 +1205,7 @@ class BatchStrOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): + input_token_ids_logprobs_idx: List[List] + output_token_ids_logprobs_val: List[List] + output_token_ids_logprobs_idx: List[List] ++ output_top_p_token_ids: List[List[List[int]]] + output_token_entropy_val: List[float] + + # Hidden states +diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py +index d675c124558..01fa4610ebe 100644 +--- a/python/sglang/srt/managers/multi_tokenizer_mixin.py ++++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py +@@ -199,6 +199,9 @@ def _handle_output_by_index(output, i): + output_token_ids_logprobs_idx=_extract_field_by_index( + output, "output_token_ids_logprobs_idx", i, check_length=False + ), ++ output_top_p_token_ids=_extract_field_by_index( ++ output, "output_top_p_token_ids", i, check_length=False ++ ), + output_token_entropy_val=_extract_field_by_index( + output, "output_token_entropy_val", i, check_length=False + ), +@@ -289,6 +292,9 @@ def _handle_output_by_index(output, i): + output_token_ids_logprobs_idx=_extract_field_by_index( + output, "output_token_ids_logprobs_idx", i, check_length=False + ), ++ output_top_p_token_ids=_extract_field_by_index( ++ output, "output_top_p_token_ids", i, check_length=False ++ ), + output_token_entropy_val=_extract_field_by_index( + output, "output_token_entropy_val", i, check_length=False + ), +diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py +index c369b070b57..a1b18a3ae9a 100755 +--- a/python/sglang/srt/managers/schedule_batch.py ++++ b/python/sglang/srt/managers/schedule_batch.py +@@ -638,6 +638,7 @@ class ReqLogprob: + None + ) + output_token_ids_logprobs_idx: Optional[list] = None ++ output_top_p_token_ids: Optional[list] = None + + + class Req(ReqDllmMixin): +@@ -883,6 +884,7 @@ class Req(ReqDllmMixin): + # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring) + self.logprob.output_token_ids_logprobs_val = [] + self.logprob.output_token_ids_logprobs_idx = [] ++ self.logprob.output_top_p_token_ids = [] + self.hidden_states: List[List[float]] = [] + self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP + self.output_topk_p = None +diff --git a/python/sglang/srt/managers/scheduler_components/batch_result_processor.py b/python/sglang/srt/managers/scheduler_components/batch_result_processor.py +index a2823395871..5377e7ed8aa 100644 +--- a/python/sglang/srt/managers/scheduler_components/batch_result_processor.py ++++ b/python/sglang/srt/managers/scheduler_components/batch_result_processor.py +@@ -382,6 +382,11 @@ class SchedulerBatchResultProcessor: + logits_output.next_token_token_ids_logprobs_val = [ + v.tolist() for v in logits_output.next_token_token_ids_logprobs_val + ] ++ if logits_output.next_token_top_p_token_ids: ++ logits_output.next_token_top_p_token_ids = [ ++ v.tolist() if torch.is_tensor(v) else v ++ for v in logits_output.next_token_top_p_token_ids ++ ] + + def _apply_prefill_logprobs( + self, +@@ -733,6 +738,11 @@ class SchedulerBatchResultProcessor: + v.tolist() + for v in logits_output.next_token_token_ids_logprobs_val + ] ++ if logits_output.next_token_top_p_token_ids: ++ logits_output.next_token_top_p_token_ids = [ ++ v.tolist() if torch.is_tensor(v) else v ++ for v in logits_output.next_token_top_p_token_ids ++ ] + # else: Spec V1 — output_ids, update_finish_state, grammar, and reasoning tokens + # are already handled in the verify phase (eagle_info.py / ngram_info.py). + return next_token_ids, next_token_logprobs +@@ -777,6 +787,14 @@ class SchedulerBatchResultProcessor: + req.logprob.output_token_ids_logprobs_idx.append( + logits_output.next_token_token_ids_logprobs_idx[flat_idx] + ) ++ if logits_output.next_token_top_p_token_ids: ++ row_top_p_token_ids = logits_output.next_token_top_p_token_ids[ ++ i * max_accept + j ++ ] ++ if row_top_p_token_ids is not None: ++ if torch.is_tensor(row_top_p_token_ids): ++ row_top_p_token_ids = row_top_p_token_ids.tolist() ++ req.logprob.output_top_p_token_ids.append(row_top_p_token_ids) + + def _apply_decode_grammar( + self, +diff --git a/python/sglang/srt/managers/scheduler_components/logprob_result_processor.py b/python/sglang/srt/managers/scheduler_components/logprob_result_processor.py +index d97d9ae801e..a9f9372ef73 100644 +--- a/python/sglang/srt/managers/scheduler_components/logprob_result_processor.py ++++ b/python/sglang/srt/managers/scheduler_components/logprob_result_processor.py +@@ -314,6 +314,13 @@ class SchedulerLogprobResultProcessor: + output.next_token_token_ids_logprobs_idx[i] + ) + ++ if output.next_token_top_p_token_ids: ++ row_top_p_token_ids = output.next_token_top_p_token_ids[i] ++ if row_top_p_token_ids is not None: ++ if torch.is_tensor(row_top_p_token_ids): ++ row_top_p_token_ids = row_top_p_token_ids.tolist() ++ req.logprob.output_top_p_token_ids.append(row_top_p_token_ids) ++ + return num_input_logprobs + + def _initialize_empty_logprob_containers(self, req: Req) -> None: +diff --git a/python/sglang/srt/managers/scheduler_components/output_streamer.py b/python/sglang/srt/managers/scheduler_components/output_streamer.py +index 2574fcfb55c..a2cfea4509f 100644 +--- a/python/sglang/srt/managers/scheduler_components/output_streamer.py ++++ b/python/sglang/srt/managers/scheduler_components/output_streamer.py +@@ -291,6 +291,7 @@ class _GenerationStreamAccumulator: + input_token_ids_logprobs_idx: Optional[list] = None + output_token_ids_logprobs_val: Optional[list] = None + output_token_ids_logprobs_idx: Optional[list] = None ++ output_top_p_token_ids: Optional[list] = None + + def __post_init__(self) -> None: + if self.return_hidden_states: +@@ -313,6 +314,7 @@ class _GenerationStreamAccumulator: + self.input_token_ids_logprobs_idx = [] + self.output_token_ids_logprobs_val = [] + self.output_token_ids_logprobs_idx = [] ++ self.output_top_p_token_ids = [] + + def accept(self, *, req: Req) -> None: + if req.finished(): +@@ -450,6 +452,11 @@ class _GenerationStreamAccumulator: + send_output_token_logprobs_offset:logprob_end + ] + ) ++ self.output_top_p_token_ids.append( ++ req.logprob.output_top_p_token_ids[ ++ send_output_token_logprobs_offset:logprob_end ++ ] ++ ) + req.send_output_token_logprobs_offset = logprob_end + else: + self.output_token_logprobs_val.append([]) +@@ -458,6 +465,7 @@ class _GenerationStreamAccumulator: + self.output_top_logprobs_idx.append([]) + self.output_token_ids_logprobs_val.append([]) + self.output_token_ids_logprobs_idx.append([]) ++ self.output_top_p_token_ids.append([]) + + if self.return_hidden_states: + self.output_hidden_states.append( +@@ -516,6 +524,7 @@ class _GenerationStreamAccumulator: + input_token_ids_logprobs_idx=self.input_token_ids_logprobs_idx, + output_token_ids_logprobs_val=self.output_token_ids_logprobs_val, + output_token_ids_logprobs_idx=self.output_token_ids_logprobs_idx, ++ output_top_p_token_ids=self.output_top_p_token_ids, + output_token_entropy_val=None, + output_hidden_states=self.output_hidden_states, + routed_experts=self.routed_experts, +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index 1f6dc90e471..a8c03bb0d65 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -205,6 +205,7 @@ class ReqState: + output_top_logprobs: List[Any] = dataclasses.field(default_factory=list) + input_token_ids_logprobs: List[Any] = dataclasses.field(default_factory=list) + output_token_ids_logprobs: List[Any] = dataclasses.field(default_factory=list) ++ output_top_p_token_ids: List[List[int]] = dataclasses.field(default_factory=list) + customized_info_accumulated: Dict[str, List[Any]] = dataclasses.field( + default_factory=dict + ) +@@ -226,6 +227,21 @@ def _slice_streaming_output_meta_info( + meta_info[key] = meta_info[key][last_output_offset:] + + ++def _b64_encode_int32(values: List[int]) -> str: ++ int32_values = array("i", values) ++ assert int32_values.itemsize == 4 ++ return pybase64.b64encode(int32_values.tobytes()).decode("utf-8") ++ ++ ++def _encode_top_p_token_ids(rows: List[List[int]]) -> Tuple[str, str]: ++ token_ids = [] ++ offsets = [0] ++ for row in rows: ++ token_ids.extend(int(token_id) for token_id in row) ++ offsets.append(len(token_ids)) ++ return _b64_encode_int32(token_ids), _b64_encode_int32(offsets) ++ ++ + class InputFormat(Enum): + """Input format types for tokenization handling.""" + +@@ -2158,6 +2174,11 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): + meta_info["input_token_ids_logprobs"] = state.input_token_ids_logprobs + meta_info["output_token_ids_logprobs"] = state.output_token_ids_logprobs + ++ if state.output_top_p_token_ids: ++ token_ids, offsets = _encode_top_p_token_ids(state.output_top_p_token_ids) ++ meta_info["top_p_token_ids"] = token_ids ++ meta_info["top_p_token_offsets"] = offsets ++ + def convert_logprob_style( + self, + meta_info: dict, +@@ -2218,6 +2239,13 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): + recv_obj.output_token_ids_logprobs_idx[recv_obj_index] + ) + ++ output_top_p_token_ids = getattr(recv_obj, "output_top_p_token_ids", None) ++ if ( ++ output_top_p_token_ids is not None ++ and len(output_top_p_token_ids) > recv_obj_index ++ ): ++ state.output_top_p_token_ids.extend(output_top_p_token_ids[recv_obj_index]) ++ + self.add_logprob_to_meta_info( + meta_info, + state, +diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py +index e9ede57ad50..48cf8898fe0 100644 +--- a/python/sglang/srt/managers/utils.py ++++ b/python/sglang/srt/managers/utils.py +@@ -101,6 +101,11 @@ class GenerationBatchResult: + v.to("cpu", non_blocking=True) if torch.is_tensor(v) else v + for v in self.logits_output.next_token_token_ids_logprobs_val + ] ++ if self.logits_output.next_token_top_p_token_ids is not None: ++ self.logits_output.next_token_top_p_token_ids = [ ++ v.to("cpu", non_blocking=True) if torch.is_tensor(v) else v ++ for v in self.logits_output.next_token_top_p_token_ids ++ ] + if return_hidden_states and self.logits_output.hidden_states is not None: + self.logits_output.hidden_states = self.logits_output.hidden_states.to( + "cpu", non_blocking=True +@@ -185,6 +190,7 @@ def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict: + "next_token_top_logprobs_idx": result.logits_output.next_token_top_logprobs_idx, + "next_token_token_ids_logprobs_val": result.logits_output.next_token_token_ids_logprobs_val, + "next_token_token_ids_logprobs_idx": result.logits_output.next_token_token_ids_logprobs_idx, ++ "next_token_top_p_token_ids": result.logits_output.next_token_top_p_token_ids, + "input_token_logprobs": result.logits_output.input_token_logprobs, + "input_top_logprobs_val": result.logits_output.input_top_logprobs_val, + "input_top_logprobs_idx": result.logits_output.input_top_logprobs_idx, +@@ -209,6 +215,9 @@ def get_logprob_from_pp_outputs( + next_token_token_ids_logprobs_idx=next_pp_outputs[ + "next_token_token_ids_logprobs_idx" + ], ++ next_token_top_p_token_ids=next_pp_outputs.tensors.get( ++ "next_token_top_p_token_ids", None ++ ), + input_token_logprobs=next_pp_outputs["input_token_logprobs"], + input_top_logprobs_val=next_pp_outputs["input_top_logprobs_val"], + input_top_logprobs_idx=next_pp_outputs["input_top_logprobs_idx"], +diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py +index f0defd28308..7da6029649e 100644 +--- a/python/sglang/srt/sampling/sampling_batch_info.py ++++ b/python/sglang/srt/sampling/sampling_batch_info.py +@@ -41,6 +41,11 @@ class SamplingBatchInfo: + + # Masking tensors for grammar-guided structured outputs + vocab_size: int ++ ++ # Whether to return kept token ids from top-p replay sampling. ++ return_top_p_token_ids: Optional[torch.Tensor] = None ++ need_return_top_p_token_ids: bool = False ++ + grammars: Optional[List] = None + rids_int: Optional[torch.Tensor] = None + bootstrap_room_ids_int: Optional[torch.Tensor] = None +@@ -93,6 +98,16 @@ class SamplingBatchInfo: + min_ps = torch.tensor( + [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device + ) ++ return_top_p_token_ids_cpu = [ ++ bool( ++ isinstance(r.sampling_params.custom_params, dict) ++ and r.sampling_params.custom_params.get("return_top_p_token_ids", False) ++ ) ++ for r in reqs ++ ] ++ return_top_p_token_ids = torch.tensor( ++ return_top_p_token_ids_cpu, dtype=torch.bool, device=device ++ ) + sampling_seed = ( + torch.tensor( + [ +@@ -179,6 +194,8 @@ class SamplingBatchInfo: + need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs), + need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs), + need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs), ++ return_top_p_token_ids=return_top_p_token_ids, ++ need_return_top_p_token_ids=any(return_top_p_token_ids_cpu), + vocab_size=vocab_size, + penalizer_orchestrator=penalizer_orchestrator, + has_custom_logit_processor=has_custom_logit_processor, +@@ -281,6 +298,7 @@ class SamplingBatchInfo: + "top_ps", + "top_ks", + "min_ps", ++ "return_top_p_token_ids", + "sampling_seed", + ]: + value = getattr(self, item, None) +@@ -389,6 +407,7 @@ class SamplingBatchInfo: + "top_ps", + "top_ks", + "min_ps", ++ "return_top_p_token_ids", + "sampling_seed", + ]: + self_val = getattr(self, item, None) +@@ -400,6 +419,7 @@ class SamplingBatchInfo: + self.need_top_p_sampling |= other.need_top_p_sampling + self.need_top_k_sampling |= other.need_top_k_sampling + self.need_min_p_sampling |= other.need_min_p_sampling ++ self.need_return_top_p_token_ids |= other.need_return_top_p_token_ids + + self.adjusted_merge_batch(other) + +diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py +index 70de75f20be..6ea3b243177 100644 +--- a/python/sglang/srt/speculative/eagle_worker_v2.py ++++ b/python/sglang/srt/speculative/eagle_worker_v2.py +@@ -1313,7 +1313,12 @@ class EAGLEWorkerV2(BaseSpecWorker): + + if batch.return_logprob and not batch.forward_mode.is_idle(): + compute_spec_v2_logprobs( +- batch, logits_output, predict, accept_index, self.speculative_num_steps ++ batch, ++ logits_output, ++ predict, ++ accept_index, ++ accept_lens, ++ self.speculative_num_steps, + ) + + if not batch.forward_mode.is_idle() and self.topk > 1: +diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py +index 9aaf6b30673..6fb34c62f91 100644 +--- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py ++++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py +@@ -828,7 +828,12 @@ class MultiLayerEagleWorkerV2(BaseSpecWorker): + + if batch.return_logprob and not batch.forward_mode.is_idle(): + compute_spec_v2_logprobs( +- batch, logits_output, predict, accept_index, self.speculative_num_steps ++ batch, ++ logits_output, ++ predict, ++ accept_index, ++ accept_lens, ++ self.speculative_num_steps, + ) + + next_draft_input = EagleDraftInput(bonus_tokens=bonus_tokens) diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index 191c20ad4a..f418ac4739 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -1,5 +1,5 @@ diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py -index a7bf9904a20..b0cb56aaece 100644 +index a7bf990..b0cb56a 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -32,6 +32,7 @@ class KVArgs: @@ -11,7 +11,7 @@ index a7bf9904a20..b0cb56aaece 100644 aux_data_lens: List[int] aux_item_lens: List[int] diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py -index e9efdcdd9ee..70265a424f5 100644 +index e9efdcd..70265a4 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server @@ -183,7 +183,7 @@ index e9efdcdd9ee..70265a424f5 100644 return diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py -index b21aee9f7c2..87f0a6fa668 100644 +index b21aee9..87f0a6f 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -39,7 +39,10 @@ from sglang.srt.disaggregation.common.utils import ( @@ -322,7 +322,7 @@ index b21aee9f7c2..87f0a6fa668 100644 # Only the last chunk we need to send the aux data ret = self.send_aux( diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py -index ce1afdac3ad..de8fd054f70 100644 +index ce1afda..de8fd05 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -21,6 +21,8 @@ from __future__ import annotations @@ -440,7 +440,7 @@ index ce1afdac3ad..de8fd054f70 100644 release_kv_cache(req, self.tree_cache) # unlock the tree req.finished_reason = FINISH_LENGTH(length=0) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py -index e1d7d9c8db3..a44685777ea 100644 +index e1d7d9c..a446857 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations @@ -646,7 +646,7 @@ index e1d7d9c8db3..a44685777ea 100644 ######################### diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py -index 88bf1947684..4ede8eb9078 100644 +index 88bf194..4ede8eb 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -71,6 +71,7 @@ from sglang.srt.managers.io_struct import ( @@ -679,7 +679,7 @@ index 88bf1947684..4ede8eb9078 100644 """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index d7368383d89..2c881d95bd5 100644 +index d736838..2c881d9 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -127,6 +127,7 @@ from sglang.srt.managers.io_struct import ( @@ -747,7 +747,7 @@ index d7368383d89..2c881d95bd5 100644 @auth_level(AuthLevel.ADMIN_OPTIONAL) async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py -index 435c30a5cfd..864a0f567a6 100644 +index 435c30a..864a0f5 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -299,6 +299,7 @@ class Envs: @@ -759,7 +759,7 @@ index 435c30a5cfd..864a0f567a6 100644 SGLANG_DISAGGREGATION_FORCE_QUERY_PREFILL_DP_RANK = EnvBool(False) # Extra slots in req_to_token_pool for decode workers (only effective when diff --git a/python/sglang/srt/layers/attention/dsa/dsa_indexer.py b/python/sglang/srt/layers/attention/dsa/dsa_indexer.py -index 85fcd4b9ec7..a49161f6154 100644 +index 85fcd4b..a49161f 100644 --- a/python/sglang/srt/layers/attention/dsa/dsa_indexer.py +++ b/python/sglang/srt/layers/attention/dsa/dsa_indexer.py @@ -2,6 +2,7 @@ from __future__ import annotations @@ -860,7 +860,7 @@ index 85fcd4b9ec7..a49161f6154 100644 if enable_dual_stream: current_stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py -index 59ca3f9cce6..9c2d00fcd7c 100644 +index 59ca3f9..9c2d00f 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -793,6 +793,7 @@ class FusedMoE(torch.nn.Module): @@ -880,7 +880,7 @@ index 59ca3f9cce6..9c2d00fcd7c 100644 else loaded_weight ) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py -index 28a9d567a5e..e60a0bcfde0 100644 +index 28a9d56..e60a0bc 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -927,6 +927,10 @@ class CompressedTensorsLinearMethod(LinearMethodBase): @@ -906,7 +906,7 @@ index 28a9d567a5e..e60a0bcfde0 100644 self, layer: torch.nn.Module, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py -index 58562bb23db..c3dc1ceb0d2 100644 +index 58562bb..c3dc1ce 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py @@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( @@ -1020,69 +1020,10 @@ index 58562bb23db..c3dc1ceb0d2 100644 is_k_full=self.is_k_full, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py -index 987ec512122..e098565729b 100644 +index 987ec51..55a51e0 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py -@@ -1442,6 +1442,8 @@ class PauseContinueBroadcast: - class UpdateWeightFromDiskReqInput(BaseReq): - # The model path with the new weights - model_path: str -+ # Required iff ``load_format == "delta"``: basenames under ``model_path`` to apply. -+ files: Optional[List[str]] = None - # The format to load the weights - load_format: Optional[str] = None - # Whether to abort all requests before updating weights -@@ -1472,6 +1474,40 @@ class UpdateWeightFromDiskReqOutput(BaseReq): - num_paused_requests: Optional[int] = 0 - - -+class DeltaEncoding(str, Enum): -+ """Position encoding for delta weight updates.""" -+ -+ # int32 absolute nonzero offsets. -+ INDICES = "indices" -+ # uint16 gap-deltas between consecutive sorted positions; uint32 per-param fallback. -+ DELTAS = "deltas" -+ # ``deltas`` wrapped in zstd L1. -+ DELTAS_ZSTD = "deltas_zstd" -+ -+ -+@dataclass -+class DeltaParam: -+ """Per-param slice into the shared (positions, values) bucket.""" -+ -+ name: str -+ dtype: str -+ shape: List[int] -+ pos_start: int -+ pos_end: int -+ pos_width: int -+ val_start: int -+ val_end: int -+ -+ -+@dataclass -+class DeltaSpec: -+ """Decoding manifest for one delta bucket. ``checksum`` is verified on apply.""" -+ -+ encoding: DeltaEncoding -+ params: List[DeltaParam] -+ checksum: int = 0 -+ -+ - @dataclass - class UpdateWeightsFromDistributedReqInput(BaseReq): - names: List[str] -@@ -1487,6 +1523,8 @@ class UpdateWeightsFromDistributedReqInput(BaseReq): - weight_version: Optional[str] = None - # Optional format specification for loading - load_format: Optional[str] = None -+ # JSON-encoded DeltaSpec; required iff load_format == "delta". -+ delta: Optional[str] = None - # Whether to call torch.cuda.empty_cache() during flush - torch_empty_cache: bool = False - -@@ -1673,6 +1711,18 @@ class ResumeMemoryOccupationReqOutput(BaseReq): +@@ -1673,6 +1673,18 @@ class ResumeMemoryOccupationReqOutput(BaseReq): pass @@ -1101,7 +1042,7 @@ index 987ec512122..e098565729b 100644 @dataclass class CheckWeightsReqInput(BaseReq): action: str = "checksum" -@@ -2058,7 +2108,7 @@ class GetLoadsReqInput(BaseReq): +@@ -2058,7 +2070,7 @@ class GetLoadsReqInput(BaseReq): """Request for /v1/loads endpoint.""" VALID_SECTIONS = frozenset( @@ -1110,7 +1051,7 @@ index 987ec512122..e098565729b 100644 ) include: List[str] = field(default_factory=lambda: ["all"]) -@@ -2128,6 +2178,9 @@ class GetLoadsReqOutput(BaseReq): +@@ -2128,6 +2140,9 @@ class GetLoadsReqOutput(BaseReq): lora: Optional[LoRAMetrics] = None disaggregation: Optional[DisaggregationMetrics] = None queues: Optional[QueueMetrics] = None @@ -1121,7 +1062,7 @@ index 987ec512122..e098565729b 100644 @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py -index 42ea8431091..c369b070b57 100755 +index 42ea843..c369b07 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -943,6 +943,7 @@ class Req(ReqDllmMixin): @@ -1159,7 +1100,7 @@ index 42ea8431091..c369b070b57 100755 ): # Even the last remaining request cannot fit in memory. diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index 8e32640fc6a..98966842506 100644 +index 8e32640..9896684 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -124,6 +124,7 @@ from sglang.srt.managers.io_struct import ( @@ -1206,7 +1147,7 @@ index 8e32640fc6a..98966842506 100644 def _pause_engine(self) -> Tuple[List[Req], int]: diff --git a/python/sglang/srt/managers/scheduler_components/load_inquirer.py b/python/sglang/srt/managers/scheduler_components/load_inquirer.py -index 3f10d7edaff..712322a95af 100644 +index 3f10d7e..712322a 100644 --- a/python/sglang/srt/managers/scheduler_components/load_inquirer.py +++ b/python/sglang/srt/managers/scheduler_components/load_inquirer.py @@ -202,6 +202,88 @@ class SchedulerLoadInquirer: @@ -1305,7 +1246,7 @@ index 3f10d7edaff..712322a95af 100644 + inflight=inflight, ) diff --git a/python/sglang/srt/managers/scheduler_components/output_streamer.py b/python/sglang/srt/managers/scheduler_components/output_streamer.py -index cac80715856..2574fcfb55c 100644 +index cac8071..2574fcf 100644 --- a/python/sglang/srt/managers/scheduler_components/output_streamer.py +++ b/python/sglang/srt/managers/scheduler_components/output_streamer.py @@ -481,7 +481,7 @@ class _GenerationStreamAccumulator: @@ -1318,7 +1259,7 @@ index cac80715856..2574fcfb55c 100644 dp_ranks = [dp_rank] * len(self.rids) if self.rids else None return BatchTokenIDOutput( diff --git a/python/sglang/srt/managers/scheduler_components/profiler_manager.py b/python/sglang/srt/managers/scheduler_components/profiler_manager.py -index 31df519f9e8..cdcf41cd8bc 100644 +index 31df519..cdcf41c 100644 --- a/python/sglang/srt/managers/scheduler_components/profiler_manager.py +++ b/python/sglang/srt/managers/scheduler_components/profiler_manager.py @@ -377,7 +377,7 @@ class SchedulerProfilerManager: @@ -1331,7 +1272,7 @@ index 31df519f9e8..cdcf41cd8bc 100644 if self.profile_in_progress: # force trace flush diff --git a/python/sglang/srt/managers/scheduler_components/weight_updater.py b/python/sglang/srt/managers/scheduler_components/weight_updater.py -index 77bf823b081..9ab3abe5618 100644 +index 77bf823..9ab3abe 100644 --- a/python/sglang/srt/managers/scheduler_components/weight_updater.py +++ b/python/sglang/srt/managers/scheduler_components/weight_updater.py @@ -16,6 +16,7 @@ from sglang.srt.constants import ( @@ -1421,7 +1362,7 @@ index 77bf823b081..9ab3abe5618 100644 return ResumeMemoryOccupationReqOutput() diff --git a/python/sglang/srt/managers/tokenizer_control_mixin.py b/python/sglang/srt/managers/tokenizer_control_mixin.py -index c9939a1fc93..ee25e5e70e0 100644 +index c9939a1..ee25e5e 100644 --- a/python/sglang/srt/managers/tokenizer_control_mixin.py +++ b/python/sglang/srt/managers/tokenizer_control_mixin.py @@ -48,6 +48,8 @@ from sglang.srt.managers.io_struct import ( @@ -1459,7 +1400,7 @@ index c9939a1fc93..ee25e5e70e0 100644 self: TokenizerManager, obj: CheckWeightsReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 357e3c4675a..1f6dc90e471 100644 +index 357e3c4..71319d7 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1641,7 +1641,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): @@ -1480,24 +1421,6 @@ index 357e3c4675a..1f6dc90e471 100644 self.is_pause_cond.notify_all() async def update_weights_from_disk( -@@ -1704,7 +1704,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): - self.model_update_result = asyncio.Future() - if self.server_args.dp_size == 1: - result = await self.model_update_result -- if result.success: -+ if result.success and obj.load_format != "delta": - self._update_model_path_info(obj.model_path, obj.load_format) - return result.success, result.message, result.num_paused_requests - else: # self.server_args.dp_size > 1 -@@ -1712,7 +1712,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): - result = await self.model_update_result - - all_success = all([r.success for r in result]) -- if all_success is True: -+ if all_success is True and obj.load_format != "delta": - self._update_model_path_info(obj.model_path, obj.load_format) - all_message = [r.message for r in result] - all_message = " | ".join(all_message) @@ -2343,25 +2343,23 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): priority = getattr(state.obj, "priority", None) if priority is not None: @@ -1532,7 +1455,7 @@ index 357e3c4675a..1f6dc90e471 100644 if state.finished: # Get detailed cache breakdown if available diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py -index bd9184408ed..71bbe8f400f 100644 +index bd91844..b7cb6d4 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,6 +29,7 @@ from sglang.srt.managers.io_struct import ( @@ -1543,31 +1466,22 @@ index bd9184408ed..71bbe8f400f 100644 SendWeightsToRemoteInstanceReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, -@@ -98,6 +99,7 @@ class BaseTpWorker(ABC): - recv_req.model_path, - recv_req.load_format, - recapture_cuda_graph=recv_req.recapture_cuda_graph, -+ files=recv_req.files, +@@ -155,6 +156,13 @@ class BaseTpWorker(ABC): ) return success, message -@@ -152,6 +154,14 @@ class BaseTpWorker(ABC): - recv_req.shapes, - recv_req.group_name, - recv_req.load_format, -+ recv_req.delta, -+ ) -+ return success, message -+ + def post_process_weights(self, recv_req: PostProcessWeightsReqInput): + success, message = self.model_runner.post_process_weights( + restore_weights_before_load=recv_req.restore_weights_before_load, + post_process_quantization=recv_req.post_process_quantization, - ) - return success, message ++ ) ++ return success, message ++ + def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + monkey_patch_torch_reductions() diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py -index 353a02ee0be..7e3e3f58cb9 100644 +index 353a02e..7e3e3f5 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -1009,9 +1009,7 @@ class HiRadixCache(RadixCache): @@ -1590,7 +1504,7 @@ index 353a02ee0be..7e3e3f58cb9 100644 def _evict_regular(self, node: TreeNode): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py -index 8efe9aae94e..79e9885c92f 100644 +index 8efe9aa..79e9885 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -2244,9 +2244,12 @@ class DSATokenToKVPool(MLATokenToKVPool): @@ -1610,7 +1524,7 @@ index 8efe9aae94e..79e9885c92f 100644 self.index_k_with_scale_buffer = [ torch.zeros( diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py -index bd6adb6e398..5ea935f76e9 100644 +index bd6adb6..5ea935f 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -467,6 +467,9 @@ class RadixCache(KVCacheEventMixin, BasePrefixCache): @@ -1635,37 +1549,18 @@ index bd6adb6e398..5ea935f76e9 100644 return DecLockRefResult(delta=delta) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 3b30eb0e1f7..d715bc6893d 100644 +index 3b30eb0..0e33834 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py -@@ -20,7 +20,9 @@ import datetime +@@ -20,6 +20,7 @@ import datetime import gc import hashlib import inspect +import json import logging -+import math import os import socket - import threading -@@ -28,7 +30,7 @@ import time - from collections import defaultdict - from dataclasses import dataclass, replace - from pathlib import Path --from typing import Any, Callable, List, Optional, Tuple, Union -+from typing import Any, Callable, Dict, List, Optional, Tuple, Union - - import torch - import torch.distributed as dist -@@ -137,6 +139,7 @@ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model - from sglang.srt.layers.utils.cp_utils import is_mla_prefill_cp_enabled - from sglang.srt.lora.lora_manager import LoRAManager - from sglang.srt.lora.lora_registry import LoRARef -+from sglang.srt.managers.io_struct import DeltaEncoding, DeltaParam, DeltaSpec - from sglang.srt.managers.schedule_batch import sanity_check_mm_pad_shift_value - from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator - from sglang.srt.mem_cache.memory_pool import ReqToTokenPool -@@ -548,7 +551,10 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -548,7 +549,10 @@ class ModelRunner(ModelRunnerKVCacheMixin): self.forward_stream = torch.get_device_module(self.device).Stream() # CPU offload @@ -1677,7 +1572,7 @@ index 3b30eb0e1f7..d715bc6893d 100644 self._weight_checker = WeightChecker(model_runner=self) -@@ -796,7 +802,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -796,7 +800,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): self.maybe_init_ngram_embedding() # Init routed experts capturer @@ -1687,202 +1582,16 @@ index 3b30eb0e1f7..d715bc6893d 100644 self.init_indexer_capturer() -@@ -1657,8 +1664,14 @@ class ModelRunner(ModelRunnerKVCacheMixin): - load_format: str, +@@ -1658,7 +1663,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): weight_name_filter: Optional[Callable[[str], bool]] = None, recapture_cuda_graph: bool = False, -+ files: Optional[List[str]] = None, ) -> tuple[bool, str]: - """Update engine weights in-place from the disk.""" + """Update engine weights in-place from disk.""" -+ if load_format == "delta": -+ if not files: -+ return False, "load_format='delta' requires non-empty `files`" -+ return self._apply_delta([os.path.join(model_path, f) for f in files]) -+ logger.info( f"Update engine weights online from disk begin. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id, empty_cache=False):.2f} GB" -@@ -1888,6 +1901,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): - shapes, - group_name, - load_format: Optional[str] = None, -+ delta: Optional[str] = None, - ): - """ - Update specific parameter in the model weights online -@@ -1908,6 +1922,18 @@ class ModelRunner(ModelRunnerKVCacheMixin): - return self._update_bucketed_weights_from_distributed( - names, dtypes, shapes, group_name - ) -+ if load_format == "delta": -+ if delta is None: -+ return False, "load_format='delta' requires a DeltaSpec in the request" -+ spec_dict = json.loads(delta) -+ spec = DeltaSpec( -+ encoding=DeltaEncoding(spec_dict["encoding"]), -+ params=[DeltaParam(**p) for p in spec_dict["params"]], -+ checksum=int(spec_dict["checksum"]), -+ ) -+ return self._apply_delta_from_distributed( -+ names, dtypes, shapes, group_name, spec -+ ) - try: - weights = [] - handles = [] -@@ -1971,6 +1997,151 @@ class ModelRunner(ModelRunnerKVCacheMixin): - logger.error(error_msg) - return False, error_msg - -+ def _decode_delta_one_param( -+ self, -+ encoding: DeltaEncoding, -+ positions: torch.Tensor, -+ values: torch.Tensor, -+ p: DeltaParam, -+ ) -> torch.Tensor: -+ """Decode one param's sparse delta into a NaN-masked full tensor.""" -+ numel = math.prod(p.shape) -+ param_dtype = _resolve_torch_dtype(p.dtype) -+ flat = torch.full((numel,), float("nan"), dtype=param_dtype, device=self.device) -+ val_slice = values[p.val_start : p.val_end] -+ if val_slice.numel() == 0: -+ return flat.view(tuple(p.shape)) -+ -+ pos_bytes = positions[p.pos_start : p.pos_end] -+ if encoding is DeltaEncoding.INDICES: -+ width = 4 -+ elif encoding in (DeltaEncoding.DELTAS, DeltaEncoding.DELTAS_ZSTD): -+ width = p.pos_width -+ else: -+ raise ValueError(f"unsupported delta encoding: {encoding!r}") -+ -+ n_elems = pos_bytes.numel() // width -+ b = pos_bytes.view(n_elems, width).to(torch.int64) -+ if width == 2: -+ unpacked = b[:, 0] | (b[:, 1] << 8) -+ else: -+ unpacked = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16) | (b[:, 3] << 24) -+ -+ if encoding is DeltaEncoding.INDICES: -+ idx = unpacked -+ else: -+ idx = (unpacked + 1).cumsum(dim=0) - 1 -+ -+ flat.index_copy_(0, idx, val_slice.to(param_dtype)) -+ return flat.view(tuple(p.shape)) -+ -+ def _apply_delta_payload( -+ self, -+ encoding: DeltaEncoding, -+ params: List[DeltaParam], -+ positions: torch.Tensor, -+ values: torch.Tensor, -+ expected_checksum: int, -+ ) -> None: -+ actual_checksum = _delta_checksum(positions, values) -+ if actual_checksum != expected_checksum: -+ raise RuntimeError( -+ f"delta checksum mismatch: expected={expected_checksum} got={actual_checksum}" -+ ) -+ -+ chunk_byte_cap = self.server_args.update_weight_delta_chunk_bytes -+ with _delta_apply_context(self.model): -+ chunk: List[Tuple[str, torch.Tensor]] = [] -+ chunk_bytes = 0 -+ for p in params: -+ t = self._decode_delta_one_param(encoding, positions, values, p) -+ tensor_bytes = t.numel() * t.element_size() -+ if chunk_bytes + tensor_bytes > chunk_byte_cap and chunk: -+ self.model.load_weights(chunk) -+ chunk = [] -+ chunk_bytes = 0 -+ chunk.append((p.name, t)) -+ chunk_bytes += tensor_bytes -+ if chunk: -+ self.model.load_weights(chunk) -+ -+ def _decode_and_apply_blob(self, blob: bytes) -> None: -+ from safetensors.torch import load as st_load -+ -+ hdr_len = int.from_bytes(blob[:8], "little") -+ meta = json.loads(blob[8 : 8 + hdr_len]).get("__metadata__", {}) -+ encoding = DeltaEncoding(meta["encoding"]) -+ params = [DeltaParam(**p) for p in json.loads(meta["params"])] -+ expected_checksum = int(meta["checksum"]) -+ -+ tensors = st_load(blob) -+ positions = tensors["__positions__"].to(self.device, non_blocking=True) -+ values = tensors["__values__"].to(self.device, non_blocking=True) -+ self._apply_delta_payload( -+ encoding, params, positions, values, expected_checksum -+ ) -+ -+ def _apply_delta_from_distributed( -+ self, -+ names: List[str], -+ dtypes: List[str], -+ shapes: List[List[int]], -+ group_name: str, -+ delta: DeltaSpec, -+ ) -> tuple[bool, str]: -+ try: -+ recv: Dict[str, torch.Tensor] = {} -+ handles = [] -+ for name, dtype, shape in zip(names, dtypes, shapes): -+ target_dtype = _resolve_torch_dtype(dtype) -+ t = torch.empty(shape, dtype=target_dtype, device=self.device) -+ handles.append( -+ torch.distributed.broadcast( -+ t, -+ src=0, -+ group=self._model_update_group[group_name], -+ async_op=True, -+ ) -+ ) -+ recv[name] = t -+ for handle in handles: -+ handle.wait() -+ -+ self._apply_delta_payload( -+ delta.encoding, -+ delta.params, -+ recv["__positions__"], -+ recv["__values__"], -+ delta.checksum, -+ ) -+ return True, "ok" -+ except Exception as e: -+ error_msg = f"Failed to apply delta from distributed: {e}." -+ logger.error(error_msg) -+ return False, error_msg -+ -+ def _apply_delta(self, paths: List[str]) -> tuple[bool, str]: -+ import concurrent.futures -+ -+ n_files = len(paths) -+ workers = min(n_files, self.server_args.update_weight_delta_read_workers) -+ -+ def _read_and_decompress(path: str) -> bytes: -+ with open(path, "rb") as fh: -+ return _maybe_zstd_decompress(fh.read()) -+ -+ try: -+ for i in range(0, n_files, workers): -+ with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool: -+ batch = list(pool.map(_read_and_decompress, paths[i : i + workers])) -+ for blob in batch: -+ self._decode_and_apply_blob(blob) -+ return True, f"Applied {n_files} delta file(s)" -+ except Exception as e: -+ error_msg = f"Failed to apply delta update from disk: {e}." -+ logger.error(error_msg) -+ return False, error_msg -+ - def update_weights_from_tensor( - self, - named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]], -@@ -3468,11 +3639,17 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3468,11 +3473,17 @@ class ModelRunner(ModelRunnerKVCacheMixin): output.expert_distribution_metrics = recorder_outputs.get("metrics") no_copy_to_cpu = not self.server_args.disable_overlap_schedule @@ -1901,7 +1610,7 @@ index 3b30eb0e1f7..d715bc6893d 100644 no_copy_to_cpu=no_copy_to_cpu, ) -@@ -3480,7 +3657,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3480,7 +3491,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): output.indexer_topk_output = indexer_capturer.on_forward_end( forward_batch=forward_batch, can_run_graph=output.can_run_graph, @@ -1910,7 +1619,7 @@ index 3b30eb0e1f7..d715bc6893d 100644 no_copy_to_cpu=no_copy_to_cpu, ) -@@ -3718,6 +3895,39 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3718,6 +3729,39 @@ class ModelRunner(ModelRunnerKVCacheMixin): logger.error(f"IPC weight update failed: {e}") return False, str(e) @@ -1950,120 +1659,10 @@ index 3b30eb0e1f7..d715bc6893d 100644 def prealloc_symmetric_memory_pool(self): # PyTorch mempools never de-fragment memory in OOM scenarios, so we need to pre-allocate a large chunk of memory to limit fragmentation. if ( -@@ -3767,6 +3977,123 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3767,6 +3811,13 @@ class ModelRunner(ModelRunnerKVCacheMixin): return output -+def _param_storage_index(model): -+ import bisect -+ -+ starts: List[int] = [] -+ ends: List[int] = [] -+ owners: List[torch.Tensor] = [] -+ seen: set = set() -+ for tensors in (model.named_parameters(), model.named_buffers()): -+ for _, t in tensors: -+ if t.is_meta: -+ continue -+ try: -+ ptr = t.data_ptr() -+ except RuntimeError: -+ continue -+ if ptr == 0 or ptr in seen: -+ continue -+ seen.add(ptr) -+ sz = t.numel() * t.element_size() -+ starts.append(ptr) -+ ends.append(ptr + sz) -+ owners.append(t) -+ -+ order = sorted(range(len(starts)), key=lambda i: starts[i]) -+ starts = [starts[i] for i in order] -+ ends = [ends[i] for i in order] -+ owners = [owners[i] for i in order] -+ -+ def find_parent(dst): -+ try: -+ ptr = dst.data_ptr() -+ except RuntimeError: -+ return None -+ idx = bisect.bisect_right(starts, ptr) - 1 -+ if 0 <= idx < len(starts) and starts[idx] <= ptr < ends[idx]: -+ return owners[idx] -+ return None -+ -+ return find_parent -+ -+ -+@contextlib.contextmanager -+def _delta_apply_context(model): -+ is_param_target = _param_storage_index(model) -+ original_copy_ = torch.Tensor.copy_ -+ original_fill_ = torch.Tensor.fill_ -+ -+ def patched_copy_(self, src, *args, **kwargs): -+ if is_param_target(self) is not None: -+ src_aligned = ( -+ src.to(device=self.device, dtype=self.dtype) -+ if src.dtype != self.dtype -+ else src -+ ) -+ mask = ~torch.isnan(src_aligned) -+ self[mask] = src_aligned[mask] -+ return self -+ return original_copy_(self, src, *args, **kwargs) -+ -+ def patched_fill_(self, value): -+ if is_param_target(self) is not None: -+ try: -+ if math.isnan(value): -+ return self -+ except TypeError: -+ pass -+ return original_fill_(self, value) -+ return original_fill_(self, value) -+ -+ original_post_load = getattr(model, "post_load_weights", None) -+ if original_post_load is not None: -+ -+ def wrapped_post_load(*args, **kwargs): -+ current_copy = torch.Tensor.copy_ -+ current_fill = torch.Tensor.fill_ -+ torch.Tensor.copy_ = original_copy_ -+ torch.Tensor.fill_ = original_fill_ -+ try: -+ return original_post_load(*args, **kwargs) -+ finally: -+ torch.Tensor.copy_ = current_copy -+ torch.Tensor.fill_ = current_fill -+ -+ model.post_load_weights = wrapped_post_load -+ -+ torch.Tensor.copy_ = patched_copy_ -+ torch.Tensor.fill_ = patched_fill_ -+ try: -+ yield -+ finally: -+ torch.Tensor.copy_ = original_copy_ -+ torch.Tensor.fill_ = original_fill_ -+ if original_post_load is not None: -+ model.post_load_weights = original_post_load -+ -+ -+def _delta_checksum(positions: torch.Tensor, values: torch.Tensor) -> int: -+ p = int(torch.hash_tensor(positions).item()) if positions.numel() else 0 -+ v = int(torch.hash_tensor(values).item()) if values.numel() else 0 -+ return p ^ (v << 1) -+ -+ -+def _maybe_zstd_decompress(blob: bytes) -> bytes: -+ if blob.startswith(b"\x28\xb5\x2f\xfd"): -+ import zstandard -+ -+ return zstandard.ZstdDecompressor().decompress(blob) -+ return blob -+ -+ +def _resolve_torch_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype @@ -2075,7 +1674,7 @@ index 3b30eb0e1f7..d715bc6893d 100644 params_dict = dict(model.named_parameters()) for name, tensor in named_tensors: diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py -index 2f0074924db..8d62df83c74 100644 +index 2f00749..8d62df8 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -45,6 +45,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): @@ -2186,7 +1785,7 @@ index 2f0074924db..8d62df83c74 100644 continue diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py -index 3ffe4dde7fd..9869f11623e 100644 +index 3ffe4dd..9869f11 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -1034,9 +1034,14 @@ class Qwen3LLMModel(Qwen3Model): @@ -2208,7 +1807,7 @@ index 3ffe4dde7fd..9869f11623e 100644 positions, hidden_states, diff --git a/python/sglang/srt/multimodal/processors/glm4v.py b/python/sglang/srt/multimodal/processors/glm4v.py -index db684259d2f..17d2cb6958a 100644 +index db68425..17d2cb6 100644 --- a/python/sglang/srt/multimodal/processors/glm4v.py +++ b/python/sglang/srt/multimodal/processors/glm4v.py @@ -1,7 +1,13 @@ @@ -2276,7 +1875,7 @@ index db684259d2f..17d2cb6958a 100644 image_grid_thw = None video_grid_thw = None diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py -index b8774ebade5..fa01537b201 100644 +index b8774eb..fa01537 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -678,7 +678,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor): @@ -2289,7 +1888,7 @@ index b8774ebade5..fa01537b201 100644 image_data=image_data, video_data=request_obj.video_data, diff --git a/python/sglang/srt/observability/req_time_stats.py b/python/sglang/srt/observability/req_time_stats.py -index 2de10730c94..d3ce2c62d21 100644 +index 2de1073..d3ce2c6 100644 --- a/python/sglang/srt/observability/req_time_stats.py +++ b/python/sglang/srt/observability/req_time_stats.py @@ -23,7 +23,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -2552,40 +2151,8 @@ index 2de10730c94..d3ce2c62d21 100644 return meta_data def format_duration(self, duration: float) -> str: -diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py -index 6c77ff64f92..706161ab928 100644 ---- a/python/sglang/srt/server_args.py -+++ b/python/sglang/srt/server_args.py -@@ -854,6 +854,8 @@ class ServerArgs: - weight_loader_prefetch_checkpoints: bool = False - weight_loader_prefetch_num_threads: int = 4 - weight_loader_drop_cache_after_load: bool = False -+ update_weight_delta_chunk_bytes: int = 512 * 1024 * 1024 -+ update_weight_delta_read_workers: int = 4 - remote_instance_weight_loader_seed_instance_ip: Optional[str] = None - remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None - remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None -@@ -7119,6 +7121,18 @@ class ServerArgs: - action="store_true", - help="Call posix_fadvise(DONTNEED) on each safetensors shard after loading it.", - ) -+ parser.add_argument( -+ "--update-weight-delta-chunk-bytes", -+ type=int, -+ default=ServerArgs.update_weight_delta_chunk_bytes, -+ help="Maximum bytes per delta weight chunk when applying delta updates.", -+ ) -+ parser.add_argument( -+ "--update-weight-delta-read-workers", -+ type=int, -+ default=ServerArgs.update_weight_delta_read_workers, -+ help="Number of worker threads used to read delta weight files.", -+ ) - parser.add_argument( - "--remote-instance-weight-loader-seed-instance-ip", - type=str, diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py -index 96c7286af76..9e3e2bd7142 100644 +index 96c7286..9e3e2bd 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -458,8 +458,12 @@ class EAGLEDraftCudaGraphRunner: @@ -2604,7 +2171,7 @@ index 96c7286af76..9e3e2bd7142 100644 buffers.hidden_states is not None and forward_batch.spec_info.hidden_states is not None diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py -index 6bf5d6182af..70de75f20be 100644 +index 6bf5d61..61b9603 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -530,6 +530,21 @@ class EagleDraftWorker(BaseDraftWorker): @@ -2640,28 +2207,8 @@ index 6bf5d6182af..70de75f20be 100644 # Organize the results if ( self.topk == 1 -@@ -1480,6 +1499,7 @@ class EAGLEWorkerV2(BaseSpecWorker): - recv_req.model_path, - recv_req.load_format, - recapture_cuda_graph=recv_req.recapture_cuda_graph, -+ files=recv_req.files, - ) - if not success: - return success, message -diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py -index 04b3841a23d..9aaf6b30673 100644 ---- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py -+++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py -@@ -856,6 +856,7 @@ class MultiLayerEagleWorkerV2(BaseSpecWorker): - recv_req.model_path, - recv_req.load_format, - recapture_cuda_graph=recv_req.recapture_cuda_graph, -+ files=recv_req.files, - ) - if not success: - return success, message diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py -index 4556d06b16f..9c28114f85d 100644 +index 4556d06..9c28114 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2399,6 +2399,7 @@ class SafeUnpickler(pickle.Unpickler): diff --git a/docker/version.txt b/docker/version.txt index 6adf74a6e0..522939d93c 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -1 +1 @@ -nightly-dev-20260614a +nightly-dev-20260618a diff --git a/docs/en/advanced/delta-weight-sync.md b/docs/en/advanced/delta-weight-sync.md index a1e670f81d..d421597297 100644 --- a/docs/en/advanced/delta-weight-sync.md +++ b/docs/en/advanced/delta-weight-sync.md @@ -1,111 +1,86 @@ # Delta Weight Sync -- [Why](#why) -- [Quick Start](#quick-start) -- [Mode vs Transport](#mode-vs-transport) -- [How It Works](#how-it-works) -- [Encoding Choice](#encoding-choice) -- [Why Not Colocated](#why-not-colocated) +Delta weight sync keeps non-colocated rollout engines up to date by shipping only the bytes +that changed between two syncs, instead of a full checkpoint each time. It targets large-model +training/inference disaggregation across clusters or datacenters, where writing the whole actor +every sync is the dominant cost. -## Why +It is **disk-transport only** and reloads through the **ordinary** `update_weights_from_disk` +endpoint, so the inference engine needs no delta-specific support. -Slime's default sync broadcasts every parameter every step. The cost scales linearly with model size and dominates the sync phase, even though only a few percent of weights change between consecutive RL steps. Delta sync keeps a pinned-CPU snapshot of the last broadcast and ships only the positions whose bytes differ. - -The motivating use case is **training/inference disaggregation** — running the trainer and the rollout engines in *different datacenters* over a shared filesystem with bandwidth on the order of 100s of MB/s, where a full broadcast is infeasible but a sparse delta (~3% density, ~5 GB for a 355B model) is. The same delta machinery also runs over NCCL inside a single datacenter, where it serves as the validation baseline that proves the wire encoding and apply logic are correct. - -Prior art: selective overwrite is inspired by [arXiv:2509.19128](https://arxiv.org/abs/2509.19128); the cross-DC disaggregation motivation is from [Fireworks AI — Frontier RL Is Cheaper Than You Think](https://fireworks.ai/blog/frontier-rl-is-cheaper-than-you-think). Another public production-shaped reference is the [Composer 2 technical report by the Cursor Research Team](https://arxiv.org/html/2603.24477v2), which describes Cursor partnering with Fireworks AI for RL inference and syncing every training-step update through shared S3, delta compression, and cross-region inference-cluster reconstruction. - -## Quick Start - -Disk transport (training/inference disaggregation — the main use case): +## Configuration ```bash --update-weight-mode delta --update-weight-transport disk ---update-weight-encoding deltas_zstd # best for ≤ 300 MB/s shared FS --update-weight-disk-dir /shared/fs/delta-updates +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt +--update-weight-delta-encoding xor # or: overwrite +--update-weight-delta-checksum xxh3-128 # or: blake3, adler32 ``` -NCCL transport (intra-datacenter validation baseline): - -```bash ---update-weight-mode delta ---update-weight-transport nccl ---update-weight-encoding indices # lowest compute, no compression -``` - -Full-checkpoint disk transport (simple external-engine fallback): - -```bash ---update-weight-mode full ---update-weight-transport disk ---update-weight-disk-dir /shared/fs/full-updates -``` - -This writes a complete HF checkpoint under `weight_v{N:06d}/` for every sync, -then asks each SGLang engine to reload it with `update_weights_from_disk`. It is -useful when the trainer cannot form an NCCL group with pre-launched rollout -engines, but it is much heavier than delta sync for large models. - -Receiver-side delta tuning (applies to delta NCCL and delta disk): - -```bash ---sglang-update-weight-delta-chunk-bytes $((2 * 1024 * 1024 * 1024)) # byte cap per load_weights call ---sglang-update-weight-delta-read-workers 4 # parallel I/O threads (disk only) -``` - -See [examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh](../../../examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh) for a complete launcher. - -## Mode vs Transport - -`--update-weight-mode` decides **what** gets sent; `--update-weight-transport` -decides **how** it reaches SGLang. - -| mode | transport | behavior | -|---|---|---| -| `full` | `nccl` | default path: broadcast every HF weight chunk over a trainer-engine NCCL group | -| `full` | `disk` | write a complete HF checkpoint under `--update-weight-disk-dir`, then call `update_weights_from_disk` | -| `delta` | `nccl` | broadcast sparse changed positions + values over NCCL | -| `delta` | `disk` | write sparse safetensors under `--update-weight-disk-dir`, then call `update_weights_from_disk(load_format="delta")` | - -`--update-weight-delta-dir` is kept only as a backward-compatible alias for -`--update-weight-disk-dir`; new launchers should use the transport-level name. - -## How It Works - -Delta NCCL and delta disk share one sender pipeline, one wire layout, and one receiver-side decoder; only the per-flush carrier differs. - -**Sender (per sync, PP-source rank only):** - -1. **Diff** the current weights against the pinned-CPU snapshot via bytewise compare (`current.view(int_dtype) != snapshot.view(int_dtype)`) — lossless, dtype-agnostic, no arithmetic. -2. **Encode** changed (position, value) pairs into a packed `__positions__` byte blob + `__values__` tensor + per-param decoding manifest. The encoding (`indices`, `deltas`, `deltas_zstd`) governs only how positions are packed; values are sent verbatim in the param's dtype. -3. **Bucket** per-chunk encodes up to `--update-weight-buffer-size` bytes, then flush: - - NCCL: broadcast `(__positions__, __values__)` to the rollout engines with a `DeltaSpec` (encoding + per-param manifest) carried in the Ray RPC. - - Disk: write one safetensors file per flush under `weight_v{N:06d}/`. Async background thread does the I/O + optional zstd compression off the critical path. -4. **Snapshot the just-sent values** via a D2H copy on a side stream so it overlaps with the next chunk's encode. - -**End-of-sync (disk only):** write a `DONE` marker, then rank 0 fires one HTTP push per engine and removes the directory after every engine acknowledges. - -**Receiver:** - -For both transports, the receiver ends up calling the same `_apply_delta_payload(encoding, params, positions, values)` helper. It decodes each param's slice into a full-shape tensor with NaN at unchanged positions, then routes it through `model.load_weights(...)` under a `_delta_apply_context` that patches `Tensor.copy_` / `Tensor.fill_` to perform NaN-masked overwrite. Auxiliary writes (scratch buffers, fp8 scales, MoE biases via `post_load_weights`) keep their normal semantics. - -Selective overwrite has no arithmetic — the receiver writes the trainer's exact bytes at changed positions — so it's lossless by construction and there's no notion of drift to fight with periodic base re-syncs. - -## Encoding Choice - -`--update-weight-encoding` picks how positions are packed. All three share the same on-wire layout (`__positions__` uint8 blob + `__values__` tensor + per-param manifest); decoder dispatches on the metadata. - -| value | positions | when to pick | -|---|---|---| -| `indices` | int32 absolute positions (4 bytes / nnz) | NCCL or fast intra-cluster FS (≥ ~600 MB/s) | -| `deltas` | uint16 gap-deltas with uint32 fallback (~2 bytes / nnz at 2% density) | medium FS bandwidth (~300-500 MB/s) | -| `deltas_zstd` | `deltas` wrapped in zstd L1 on disk | cross-DC / cross-region shared FS (≤ ~300 MB/s) | - -**Why gap-encoded positions are smaller**: positions come out of `mask.nonzero()` already sorted ascending. At density `p`, the expected gap between consecutive nonzero positions is `1/p`, and `P(gap > 65535) ≈ exp(-p · 65535)`. At p = 2% that's effectively zero, so uint16 fits with a uint32 per-param fallback for pathological inputs. Half the position bytes of `indices`, lossless. - -**Break-even with `indices`** at our density (~2%): `deltas` halves the positions blob (which dominates the wire); `zstd` shaves another ~35-40% on top by compressing the gap byte stream, at the cost of ~250ms/file compress + ~150ms/file decompress. The crossover with `indices` is where compress/decompress compute exceeds the bandwidth savings — empirically around 500 MB/s for `deltas` and 300 MB/s for `deltas_zstd`. - -## Why Not Colocated - -Colocated weight sync uses CUDA IPC: only a memory handle (~64 B) crosses processes. Delta encoding's "bytes saved on the wire" benefit is zero, while the bookkeeping (snapshot + diff + sparse encode) is pure overhead. Slime rejects `--update-weight-mode delta --colocate` at argparse time. +| Flag | Role | +|---|---| +| `--update-weight-disk-dir` | Shared filesystem directory the trainer publishes deltas to and the rollout hosts read from. | +| `--update-weight-local-checkpoint-dir` | Host-local (e.g. NVMe) full HF checkpoint that the delta is applied into in place. Each host materializes it from `--hf-checkpoint` at engine start. | +| `--update-weight-delta-encoding` | On-disk delta encoding: `xor` (default) or `overwrite`. | +| `--update-weight-delta-checksum` | Per-tensor integrity checksum: `xxh3-128` (default), `blake3`, or `adler32`. | + +Deltas are always zstd-compressed (level 1); profiling showed it dominates lz4 / gzip / snappy / brotli on both wire size and decompress speed for this data, so it is not a knob. + +## How it works + +1. **Seed.** On the first sync the trainer captures a CPU snapshot of every parameter — seeded + from `--hf-checkpoint`, which is exactly what each rollout host materializes its local + checkpoint from. Nothing is published; this snapshot is the base the next sync diffs against. +2. **Publish.** On every later sync the trainer diffs each gathered HF tensor against the + snapshot, encodes and compresses the change, and writes a new version directory + `weight_v{N:06d}/` under `--update-weight-disk-dir`. The directory is a canonical HF + checkpoint — `model-NNNNN.safetensors` files holding the compressed diff tensors plus a + `model.safetensors.index.json` (tensor name → file) carrying the apply metadata — so the + artifact is portable, not tied to the trainer's parallelism layout. The snapshot is then + advanced to the new values for the next diff. +3. **Apply.** Each rollout host applies the new version's delta into its local checkpoint in + place. The apply is parallelized across tensors and verified per-tensor (see Integrity). +4. **Reload.** The engines reload the patched local checkpoint through the vanilla + `update_weights_from_disk` path — they never see the delta format. + +Because the snapshot is seeded from `--hf-checkpoint` (the engine's actual base) rather than +from the current GPU weights, the scheme is correct for any model even where the Megatron→HF +round-trip is not byte-exact (e.g. trimmed vocab-padding rows in the embedding / LM head). + +## Encodings + +Both encodings are byte-level and dtype-blind, so the same path works for quantized checkpoints. +The engine reads the choice from each version's index metadata. + +- **`xor`** (default): writes `new ^ old`. Smallest wire and fastest to apply (sequential, + cache-friendly; the unchanged bytes are zeros the compressor crushes). It is an involution, + so it must be applied **exactly once** against the correct base — applying it twice reverts. +- **`overwrite`**: writes the changed positions and their new absolute values. Larger on the + wire and a less cache-friendly scattered apply, but **idempotent**: re-applying it (or + finishing a partially-applied delta) converges to the same state regardless of how many times + it runs. Use it when re-applicability matters more than wire size. + +## Integrity + +The trainer stores a per-tensor checksum of each tensor's new state in the version. After +applying, every host recomputes the checksum and **raises on any mismatch**, so a corrupt delta +or a wrong base fails loud instead of serving bad weights. The apply also refuses to run out of +order: a version only applies on top of its declared base version. + +`--update-weight-delta-checksum` selects the algorithm. The checksum is not the apply bottleneck +(the apply is decompress + XOR bound), so this is a digest-property choice, not a speed one: +`xxh3-128` (default) is the widest fast non-cryptographic digest; `blake3` is cryptographic, for +untrusted storage; `adler32` is for interop with systems that expect it. + +## Shared-filesystem visibility hooks + +On a POSIX shared filesystem (NFS, Lustre, …) no extra step is needed. Object-store-backed +volumes that need an explicit commit/refresh to make writes visible across hosts can supply two +optional hooks, loaded by import path — no vendor-specific code lives in slime: + +- `--custom-delta-pre-push-path`: called after a version's files are written, before the engines + are told to read it (e.g. commit the volume). Signature: `hook(args, version_dir, rollout_engines)`. +- `--custom-delta-pre-read-path`: called on each rollout host before it reads the delta directory + (e.g. refresh the volume). Signature: `hook(delta_dir, target_version)`. diff --git a/docs/en/advanced/external-rollout-engines.md b/docs/en/advanced/external-rollout-engines.md index 498afa0c5f..6ff6f6da72 100644 --- a/docs/en/advanced/external-rollout-engines.md +++ b/docs/en/advanced/external-rollout-engines.md @@ -79,28 +79,16 @@ This keeps the full-checkpoint directories after engines acknowledge the load. ## Update With Delta -Delta update targets large-model training/inference disaggregation across clusters or datacenters. Instead of writing a full checkpoint, the trainer keeps a pinned-CPU snapshot of the previous sync, detects byte-level changes, and sends only changed positions and values. - -Recommended for cross-cluster / shared-filesystem deployments: +Delta update targets large-model training/inference disaggregation across clusters or datacenters. Instead of writing a full checkpoint every sync, the trainer keeps a CPU snapshot of the previous sync, diffs each parameter against it, and publishes only the changed bytes; every rollout host applies the delta into its local checkpoint and reloads via the vanilla `update_weights_from_disk` endpoint. ```bash --update-weight-mode delta --update-weight-transport disk ---update-weight-encoding deltas_zstd --update-weight-disk-dir /shared/fs/delta-updates +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt ``` -With disk transport, each sync writes sparse safetensors under `weight_v{N:06d}/`, then calls `update_weights_from_disk(load_format="delta")`. SGLang overwrites only changed positions in the current weights; unchanged positions stay in place. - -For intra-datacenter validation or bandwidth-rich environments, NCCL transport is also available: - -```bash ---update-weight-mode delta ---update-weight-transport nccl ---update-weight-encoding indices -``` - -For encoding choices, wire layout, receiver-side selective overwrite, and tuning parameters, see [Delta Weight Sync](delta-weight-sync.md). +See [Delta Weight Sync](delta-weight-sync.md) for the mechanism, encodings, integrity checks, and shared-filesystem visibility hooks. ## Deployment Checklist diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index 77f5cd5e34..2878537c91 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -10,6 +10,7 @@ Below is a summary of all available customization interfaces and their purposes. | :--- | :--- | | [`--rollout-function-path`](#1-rollout-function---rollout-function-path) | Override the entire rollout generation logic. | | [`--custom-generate-function-path`](#2-custom-generate-function---custom-generate-function-path) | Override only the generation step (e.g., for RAG or tool use). | +| [`--custom-rollout-request-hook-path`](#mutating-the-outgoing-request---custom-rollout-request-hook-path) | Mutate each outgoing `/generate` request (e.g., custom headers). | | [`--custom-rm-path`](#3-reward-model---custom-rm-path) | Implement custom reward computation logic. | | [`--dynamic-sampling-filter-path`](#4-dynamic-sampling-filter---dynamic-sampling-filter-path) | Filter samples during dynamic sampling (e.g., DAPO). | | [`--buffer-filter-path`](#5-buffer-filter---buffer-filter-path) | Filter samples in the rollout buffer before training. | @@ -118,6 +119,19 @@ If one full trajectory has a single total reward but is split into `K` training **Example**: See [examples/search-r1/generate_with_search.py](../../../examples/search-r1/generate_with_search.py) +#### Mutating the outgoing request (`--custom-rollout-request-hook-path`) + +When you keep the built-in generate function but need to adjust each `/generate` request just before it is sent, use `--custom-rollout-request-hook-path` instead of replacing the whole generate step. The hook receives a `request` dict describing how the call is sent — `url`, `payload`, `headers`, `max_retries`, `retry_sleep` — plus `args` and `sample`. It either mutates `request` in place (returning `None`) or returns a dict of updates: + +```python +def hook(args, sample, request): + request["headers"] = {**(request["headers"] or {}), "Authorization": f"Bearer {get_token()}"} +``` + +Use it to add custom headers (auth tokens, routing keys), or for weight-version gating against an opaque rollout endpoint — set `request["payload"]["weight_version"]` so the fleet serves only a matching version, and raise `request["max_retries"]`/`request["retry_sleep"]` so slime backs off and waits for the fleet to load it. + +The hook may be `async`. It runs for both built-in generate paths (the default buffered one and `sglang_streaming_rollout.generate_streaming`) only when configured — otherwise the request is sent unchanged. Your own custom generate functions that post requests directly are responsible for their own request shaping (call `apply_rollout_request_hook` if you want the same behavior). + --- ### 3. Reward Model (`--custom-rm-path`) @@ -298,8 +312,6 @@ def get_pg_loss_reducer( - Dr.GRPO: Divide by a constant instead of effective token count - Custom loss normalization strategies -**Example**: `examples/DrGRPO/custom_reducer.py:get_pg_loss_reducer` - --- ### 12. Reward Post-Processing (`--custom-reward-post-process-path`) diff --git a/docs/en/index.rst b/docs/en/index.rst index f3401ce3a7..83230af1fa 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -38,9 +38,9 @@ Start by Use Case - Build agentic RL workflows: :doc:`get_started/agent` - Configure production SGLang rollout topology: :doc:`advanced/sglang-config` - Connect external rollout engines: :doc:`advanced/external-rollout-engines` +- Sync weights as byte-level deltas: :doc:`advanced/delta-weight-sync` - Use PD disaggregation: :doc:`advanced/pd-disaggregation` - Use BF16 training with FP8 rollout or FP8 KV cache: :doc:`advanced/low-precision` -- Use delta weight sync: :doc:`advanced/delta-weight-sync` - Understand CI and reliability coverage: :doc:`developer_guide/ci` - Debug, trace, and profile long-running jobs: :doc:`developer_guide/debug`, :doc:`developer_guide/trace`, :doc:`developer_guide/profiling` diff --git a/docs/zh/advanced/delta-weight-sync.md b/docs/zh/advanced/delta-weight-sync.md index f009dc954a..0aa5434472 100644 --- a/docs/zh/advanced/delta-weight-sync.md +++ b/docs/zh/advanced/delta-weight-sync.md @@ -1,107 +1,54 @@ # Delta 权重同步 -- [背景](#背景) -- [快速开始](#快速开始) -- [同步模式与传输方式](#同步模式与传输方式) -- [工作原理](#工作原理) -- [编码选择](#编码选择) -- [为何不支持 colocated](#为何不支持-colocated) +Delta 权重同步只发送两次同步之间发生变化的字节,而不是每次都写一份完整 checkpoint,以此让非 colocate 的 rollout engine 保持最新。它面向大模型、跨集群或跨数据中心的训推解耦场景——这种场景下每次都写整份 actor 权重是主要开销。 -## 背景 +它**只支持 disk transport**,并且通过**原生**的 `update_weights_from_disk` 端点 reload,因此推理引擎不需要任何 delta 相关的支持。 -slime 默认的权重同步会在每一步广播全部参数,开销随模型规模线性增长,即使每步真正变化的权重只有几个百分点。Delta 同步在内存中保留上一次同步后的参数快照(pinned CPU),只发送字节发生变化的位置。 - -最主要的应用场景是 **训练 / 推理跨数据中心解耦** —— 训练器和推理引擎运行在不同数据中心,通过共享文件系统通信(带宽通常在百 MB/s 级别)。在这种环境下,全量广播不可行,而 ~3% 密度的稀疏 delta(355B 模型约 5 GB)是可行的。同一套 delta 机制在数据中心内部跑 NCCL,作为验证基线,确认 wire 编码和 apply 逻辑正确。 - -参考资料:选择性覆写借鉴自 [arXiv:2509.19128](https://arxiv.org/abs/2509.19128),跨数据中心的动机来自 [Fireworks AI — Frontier RL Is Cheaper Than You Think](https://fireworks.ai/blog/frontier-rl-is-cheaper-than-you-think)。另一个接近生产形态的公开参考是 [Cursor Research Team 的 Composer 2 技术报告](https://arxiv.org/html/2603.24477v2):其中描述了 Cursor 与 Fireworks AI 合作运行 RL inference,并通过共享 S3、delta compression 和跨区域 inference 集群重建来同步每步训练权重。 - -## 快速开始 - -磁盘传输(跨数据中心训推解耦,主要场景): +## 配置 ```bash --update-weight-mode delta --update-weight-transport disk ---update-weight-encoding deltas_zstd # ≤ 300 MB/s 共享 FS 推荐 --update-weight-disk-dir /shared/fs/delta-updates +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt +--update-weight-delta-encoding xor # 或: overwrite +--update-weight-delta-checksum xxh3-128 # 或: blake3, adler32 ``` -NCCL 传输(数据中心内部验证基线): - -```bash ---update-weight-mode delta ---update-weight-transport nccl ---update-weight-encoding indices # 计算最少,无压缩 -``` - -全量 checkpoint 磁盘传输(外部引擎的简单兜底路径): - -```bash ---update-weight-mode full ---update-weight-transport disk ---update-weight-disk-dir /shared/fs/full-updates -``` - -这会在每次同步时写一个完整 HF checkpoint 到 `weight_v{N:06d}/`,然后让每个 -SGLang engine 通过 `update_weights_from_disk` 重新加载。它适用于训练器无法和预启动 -rollout engine 建 NCCL group 的场景,但对大模型来说比 delta 同步重很多。 - -接收端 delta 调优(适用于 delta NCCL 和 delta 磁盘): - -```bash ---sglang-update-weight-delta-chunk-bytes $((2 * 1024 * 1024 * 1024)) # 每次 load_weights 字节上限 ---sglang-update-weight-delta-read-workers 4 # 并行 I/O 线程数(仅磁盘传输) -``` - -完整启动脚本见 [examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh](../../../examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh)。 - -## 同步模式与传输方式 - -`--update-weight-mode` 决定**发送什么**,`--update-weight-transport` 决定**如何送到 SGLang**。 +| 参数 | 作用 | +|---|---| +| `--update-weight-disk-dir` | 训练端发布 delta、rollout host 读取 delta 的共享文件系统目录。 | +| `--update-weight-local-checkpoint-dir` | host 本地(如 NVMe)的完整 HF checkpoint,delta 原地 apply 到这里。每个 host 在 engine 启动时由 `--hf-checkpoint` 物化。 | +| `--update-weight-delta-encoding` | 磁盘上的 delta 编码:`xor`(默认)或 `overwrite`。 | +| `--update-weight-delta-checksum` | 逐 tensor 完整性 checksum:`xxh3-128`(默认)、`blake3` 或 `adler32`。 | -| 同步模式 (`mode`) | 传输方式 (`transport`) | 行为 | -|---|---|---| -| `full` | `nccl` | 默认路径:通过训练器和 engine 之间的 NCCL group 广播所有 HF 权重 chunk | -| `full` | `disk` | 在 `--update-weight-disk-dir` 下写完整 HF checkpoint,然后调用 `update_weights_from_disk` | -| `delta` | `nccl` | 通过 NCCL 广播稀疏变化位置和值 | -| `delta` | `disk` | 在 `--update-weight-disk-dir` 下写稀疏 safetensors,然后调用 `update_weights_from_disk(load_format="delta")` | - -`--update-weight-delta-dir` 只保留为 `--update-weight-disk-dir` 的向后兼容 alias; -新启动脚本应该使用传输方式级别的目录参数。 +delta 始终用 zstd(level 1)压缩;profiling 显示对这类数据它在 wire 大小和解压速度上都优于 lz4 / gzip / snappy / brotli,所以不做成可配置项。 ## 工作原理 -Delta NCCL 和 delta 磁盘共用同一条发送管线、同一种 wire 布局以及同一套接收端解码器;只有每个 bucket 的承载层不同。 - -**发送端(每次同步,仅 PP 源 rank):** - -1. **求差**:通过逐字节比较 `current.view(int_dtype) != snapshot.view(int_dtype)` 检测变化。无算术、无损、与 dtype 无关。 -2. **编码**:将变化的 (位置, 值) 对打包成 `__positions__` 字节块 + `__values__` 张量 + per-param 解码 manifest。编码方式(`indices` / `deltas` / `deltas_zstd`)只影响位置如何打包,值始终按参数本身的 dtype 原样发送。 -3. **打包并发送**:每个 chunk 编码后累积至 `--update-weight-buffer-size` 字节再 flush: - - NCCL:广播 `(__positions__, __values__)`,Ray RPC 同时携带 `DeltaSpec`(编码 + per-param manifest)。 - - 磁盘:每个 flush 写一个 safetensors 文件到 `weight_v{N:06d}/` 目录,后台线程负责 I/O 和可选的 zstd 压缩,不阻塞关键路径。 -4. **更新快照**:刚发送的值在 side stream 上 D2H 拷贝,与下一个 chunk 的编码重叠。 +1. **Seed。** 第一次同步时,训练端为每个参数捕获一份 CPU snapshot——从 `--hf-checkpoint` seed,而这正是每个 rollout host 物化本地 checkpoint 的来源。此次不发布任何东西;这份 snapshot 就是下一次同步 diff 的基准。 +2. **Publish。** 之后每次同步,训练端把每个 gather 出的 HF tensor 与 snapshot 做 diff,编码、压缩,写到 `--update-weight-disk-dir` 下的新版本目录 `weight_v{N:06d}/`。该目录是一份 canonical HF checkpoint——`model-NNNNN.safetensors` 文件装着压缩后的 diff tensor,外加 `model.safetensors.index.json`(tensor 名 → 文件)承载 apply 元数据——所以这个产物是可移植的,不绑定训练端的并行 layout。随后 snapshot 推进到新值,供下次 diff。 +3. **Apply。** 每个 rollout host 把新版本的 delta 原地 apply 进它的本地 checkpoint。apply 在 tensor 之间并行,并逐 tensor 校验(见“完整性”)。 +4. **Reload。** engine 通过原生 `update_weights_from_disk` 路径 reload 打过补丁的本地 checkpoint——它从不接触 delta 格式。 -**同步结束(仅磁盘):** 写 `DONE` 标记,rank 0 对每个引擎触发一次 HTTP push,所有引擎确认后清理目录。 +由于 snapshot 是从 `--hf-checkpoint`(engine 真正的 base)seed,而不是从当前 GPU 权重 seed,即使 Megatron→HF 往返不是逐字节相等(例如 embedding / LM head 中被裁掉的 vocab padding 行),该方案对任意模型也都正确。 -**接收端:** 两种传输最终都进入同一个 `_apply_delta_payload(encoding, params, positions, values)` 帮助函数。它把每个参数的切片解码成全形状张量,未变化位置填 NaN,然后通过 `model.load_weights(...)` 应用;过程中 `_delta_apply_context` 替换 `Tensor.copy_` / `Tensor.fill_`,对参数存储执行 NaN 掩码覆写。辅助写入(scratch buffer、fp8 scale、MoE bias 等通过 `post_load_weights` 写入的派生张量)保留正常语义。 +## 编码 -选择性覆写没有任何算术运算 —— 接收端在变化位置直接写入训练端的精确字节 —— 因此天然无损,也不存在数值漂移问题,无需周期性 base 同步。 +两种编码都是字节级、与 dtype 无关的,所以量化 checkpoint 也走同一条路径。engine 从每个版本的 index 元数据读取所用编码。 -## 编码选择 +- **`xor`**(默认):写 `new ^ old`。wire 最小、apply 最快(顺序访问、对 cache 友好;未变化的字节是 0,被压缩器压到极小)。它是一个对合(involution),所以必须**恰好对正确的 base apply 一次**——apply 两次会还原。 +- **`overwrite`**:写变化的位置及其新的绝对值。wire 更大、apply 是对 cache 不友好的分散写,但**幂等**:重复 apply(或把部分 apply 的 delta 补完)无论执行多少次都收敛到同一状态。当“可重复 apply”比 wire 大小更重要时用它。 -`--update-weight-encoding` 决定位置如何打包。三种编码共用同一种 wire 布局(`__positions__` uint8 块 + `__values__` 张量 + per-param manifest),解码端根据 metadata 分派。 +## 完整性 -| 取值 | 位置编码 | 推荐场景 | -|---|---|---| -| `indices` | int32 绝对位置(4 字节 / nnz) | NCCL 或高速集群内 FS(≥ ~600 MB/s) | -| `deltas` | uint16 增量(异常时 uint32 兜底,2% 密度下约 2 字节 / nnz) | 中等带宽 FS(~300-500 MB/s) | -| `deltas_zstd` | `deltas` 文件再用 zstd L1 压缩 | 跨数据中心 / 跨区共享 FS(≤ ~300 MB/s) | +训练端把每个 tensor 新状态的逐 tensor checksum 存进版本里。apply 之后每个 host 重新计算 checksum,**任何不匹配都会 raise**,所以损坏的 delta 或错误的 base 会直接报错失败,而不会把坏权重提供出去。apply 还拒绝乱序执行:一个版本只会在它声明的 base 版本之上 apply。 -**为何 gap 编码更省**:`mask.nonzero()` 返回的位置已经升序排列。密度 `p` 时连续非零位置的期望间隔为 `1/p`,且 `P(gap > 65535) ≈ exp(-p · 65535)`,p = 2% 时这个概率实际上为零,所以 uint16 完全够用,uint32 仅作 per-param 兜底。位置开销比 `indices` 减半,且无损。 +`--update-weight-delta-checksum` 选择算法。checksum 不是 apply 的瓶颈(apply 受解压 + XOR 限制),所以这是一个 digest 属性的选择,而非速度选择:`xxh3-128`(默认)是最宽的快速非加密 digest;`blake3` 是加密 digest,用于不可信存储;`adler32` 用于与期望它的系统互操作。 -**`deltas_zstd` 的额外收益**:在 gap 字节流上做 zstd L1 还能再减少 ~35-40%,代价是每文件约 250ms 压缩 + 150ms 解压。当共享 FS 带宽 ≤ 300 MB/s 时,带宽节省超过额外计算开销。 +## 共享文件系统可见性 hook -## 为何不支持 colocated +在 POSIX 共享文件系统(NFS、Lustre……)上不需要额外步骤。对于需要显式 commit/refresh 才能让写入跨 host 可见的对象存储卷,可以提供两个可选 hook(通过 import 路径加载——slime 里不存在任何厂商特定代码): -Colocated 同步通过 CUDA IPC:进程间传递的只是一个内存句柄(~64 B)。Delta 编码的"wire 节省"在此为零,而其簿记开销(快照 + 求差 + 稀疏编码)反而是纯损失。slime 在参数校验阶段拒绝 `--update-weight-mode delta --colocate`。 +- `--custom-delta-pre-push-path`:在一个版本的文件写完之后、通知 engine 读取之前调用(例如 commit volume)。签名:`hook(args, version_dir, rollout_engines)`。 +- `--custom-delta-pre-read-path`:在每个 rollout host 读取 delta 目录之前调用(例如 refresh volume)。签名:`hook(delta_dir, target_version)`。 diff --git a/docs/zh/advanced/external-rollout-engines.md b/docs/zh/advanced/external-rollout-engines.md index 9aae0ef5ec..0007bcff69 100644 --- a/docs/zh/advanced/external-rollout-engines.md +++ b/docs/zh/advanced/external-rollout-engines.md @@ -79,28 +79,16 @@ full checkpoint update from disk 是 external 场景最简单的兜底路径: ## Update With Delta -delta update 面向大模型、跨集群或跨数据中心训推解耦。它不写完整 checkpoint,而是在训练端保留上一次同步后的 pinned CPU snapshot,逐字节检测变化,只发送变化位置和值。 - -跨集群 / 共享文件系统推荐: +delta update 面向大模型、跨集群或跨数据中心训推解耦。它不每次都写完整 checkpoint,而是在训练端保留上一次同步的 CPU snapshot,逐参数比对,只发布变化的字节;每个 rollout host 把 delta apply 进自己的本地 checkpoint,再通过原生 `update_weights_from_disk` 端点 reload。 ```bash --update-weight-mode delta --update-weight-transport disk ---update-weight-encoding deltas_zstd --update-weight-disk-dir /shared/fs/delta-updates +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt ``` -在 disk transport 下,每次同步会写一组稀疏 safetensors 到 `weight_v{N:06d}/`,然后调用 `update_weights_from_disk(load_format="delta")`。SGLang 侧只把变化位置覆写到当前权重上,不变位置保持原值。 - -在同一数据中心内做实现验证或带宽不紧张时,也可以用 NCCL transport: - -```bash ---update-weight-mode delta ---update-weight-transport nccl ---update-weight-encoding indices -``` - -编码如何选择、delta wire layout、接收端 selective overwrite 以及调优参数见 [Delta 权重同步](delta-weight-sync.md)。 +机制、编码、完整性校验以及共享文件系统可见性 hook 详见 [Delta 权重同步](delta-weight-sync.md)。 ## 部署检查清单 diff --git a/docs/zh/get_started/customization.md b/docs/zh/get_started/customization.md index 5b95f05463..fd067c04c9 100644 --- a/docs/zh/get_started/customization.md +++ b/docs/zh/get_started/customization.md @@ -298,8 +298,6 @@ def get_pg_loss_reducer( - Dr.GRPO:除以常数而非有效 token 数 - 自定义损失归一化策略 -**示例**: `examples/DrGRPO/custom_reducer.py:get_pg_loss_reducer` - --- ### 12. 奖励后处理 (`--custom-reward-post-process-path`) diff --git a/docs/zh/index.rst b/docs/zh/index.rst index 7075a28b84..747deddf63 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -38,9 +38,9 @@ slime 的设计目标,是让这两大能力彼此强化,同时避免把系 - 构建 agentic RL workflow::doc:`get_started/agent` - 配置生产级 SGLang rollout topology::doc:`advanced/sglang-config` - 接入 external rollout engines::doc:`advanced/external-rollout-engines` +- 以字节级 delta 同步权重::doc:`advanced/delta-weight-sync` - 使用 PD disaggregation::doc:`advanced/pd-disaggregation` - 使用 BF16 训练 + FP8 rollout 或 FP8 KV cache::doc:`advanced/low-precision` -- 使用 delta weight sync::doc:`advanced/delta-weight-sync` - 了解 CI 和可靠性覆盖::doc:`developer_guide/ci` - 调试、trace 和 profiling 长时间任务::doc:`developer_guide/debug`、:doc:`developer_guide/trace`、:doc:`developer_guide/profiling` diff --git a/examples/README.md b/examples/README.md index 128b1562d4..4618f6414c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -11,7 +11,7 @@ These examples provide concrete examples to leverage slime in your own RL workfl - **[low_precision](./low_precision)**: Examples of FP8 training and inference for improved throughput and stability. - **[multi_agent](./multi_agent)**: Example of running multi-agent RL with `slime`. - **[on_policy_distillation](./on_policy_distillation)**: Example implementation for on-policy distillation, extending the reinforcement learning pipeline to support teacher–student distillation directly within on-policy training. -- **[delta_weight_sync](./delta_weight_sync)**: Non-colocated weight sync that ships only changed positions + values over disk (training/inference disaggregation) or NCCL. +- **[delta_weight_sync](./delta_weight_sync)**: Non-colocated weight sync that ships only the changed bytes over a shared filesystem (training/inference disaggregation), reloading via the vanilla `update_weights_from_disk` path. - **[reproducibility](./reproducibility)**: Guides on achieving bitwise experiment reproduction using deterministic modes. - **[retool](./retool)**: Demonstrates the retool functionality for tool-enabled language model generation. - **[search-r1](./search-r1)**: A minimal reproduction of Search-R1, featuring multi-turn conversation and tool-calling. diff --git a/examples/delta_weight_sync/README.md b/examples/delta_weight_sync/README.md index b2c7521578..0879ba9fcb 100644 --- a/examples/delta_weight_sync/README.md +++ b/examples/delta_weight_sync/README.md @@ -1,67 +1,40 @@ # Delta Weight Sync -Non-colocated weight sync that ships only changed positions + values instead of every parameter. Two transports over one wire format and one receiver-side decoder: +Non-colocated weight sync that ships only the **changed bytes** between two syncs instead of a +full checkpoint, for training/inference disaggregation across clusters or datacenters. The +trainer publishes per-tensor deltas to a shared filesystem as a canonical HF checkpoint +directory; each rollout host applies them into a host-local checkpoint and the engines reload +through the ordinary `update_weights_from_disk` path — the inference engine needs no +delta-specific support. -- **Disk** (the point) — write per-flush safetensors to a shared filesystem; one HTTP push per sync. Designed for **training/inference disaggregation** across datacenters where bandwidth between trainer and rollout is on the order of 100s of MB/s. -- **NCCL** (the baseline) — broadcast each per-flush bucket directly. Used intra-datacenter to validate that the wire encoding and apply logic are correct, separate from any shared-FS variable. +See [Delta Weight Sync](../../docs/en/advanced/delta-weight-sync.md) for the full mechanism, +encodings, integrity checks, and shared-filesystem visibility hooks. -Both modes are lossless by construction (selective overwrite via NaN sentinel; no arithmetic). +## Try it -## Files +`run-glm4.7-30B-A3B-delta.sh` runs the disk delta path on GLM-4.7-Flash, non-colocated across a +2-node (16-GPU) Ray cluster. See its header for prerequisites. -- `run-glm4.7-355B-A32B-delta.sh`: 16-node (8 actor + 8 rollout) GLM-4.7-355B-A32B launcher. Disk transport active by default; NCCL block commented below it. +## Minimal flags -## Usage +Add to a non-colocated training run (the trainer and engines only need to share the filesystem +at `--update-weight-disk-dir`): ```bash -bash examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh +--update-weight-mode delta \ +--update-weight-transport disk \ +--update-weight-disk-dir /shared/fs/delta-updates \ +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt \ +--update-weight-delta-encoding xor \ +--update-weight-delta-checksum xxh3-128 ``` -**Disk (default):** +- `--update-weight-disk-dir` — shared directory the trainer writes deltas to and the hosts read. +- `--update-weight-local-checkpoint-dir` — host-local full HF checkpoint the delta patches in + place; materialized from `--hf-checkpoint` at engine start. +- `--update-weight-delta-encoding` — `xor` (smallest/fastest) or `overwrite` (idempotent). +- `--update-weight-delta-checksum` — `xxh3-128` (default), `blake3`, or `adler32`. -```bash -DELTA_ARGS=( - --update-weight-mode delta - --update-weight-transport disk - --update-weight-encoding deltas_zstd - --update-weight-disk-dir /shared/fs/delta-updates -) -``` - -**NCCL (baseline):** - -```bash -DELTA_ARGS=( - --update-weight-mode delta - --update-weight-transport nccl - --update-weight-encoding indices -) -``` - -Receiver-side byte cap (both transports): - -```bash ---sglang-update-weight-delta-chunk-bytes $((2 * 1024 * 1024 * 1024)) -``` - -See [docs/en/advanced/delta-weight-sync.md](../../docs/en/advanced/delta-weight-sync.md) for the wire protocol, encoding choice, and design. - -## Results - -W&B traces comparing delta sync against the full-sync baseline on GLM-4.7-355B-A32B / DAPO-Math-17k. - -![Raw reward](./raw_reward.png) - -![Train/rollout logprob abs diff](./train_rollout_logprob_abs_diff.png) - -![Update weights time](./update_weights_time.png) - -> **Note on the small curve-to-curve gap.** RL training is inherently non-deterministic (cuBLAS reductions, FlashAttention split-K, NCCL all-reduce ordering, dynamic-batch token assignment). Two identically-configured *full*-sync runs would diverge the same way. Delta sync's selective overwrite is bit-exact with full sync per step (no arithmetic, no drift); the trajectory matches, the bits don't. - -![Update weights density](./update_weights_density.png) - -*Per-sync change density (`perf/update_weights_density`) — fraction of weight positions that moved between consecutive syncs. Sync 0 is omitted: it's the snapshot-seeding pass with density = 1.0, which would compress the y-axis.* - -## Why these encoding defaults - -Per-sync change density during RL fine-tuning at conservative LRs sits around **2-3%** ([arXiv:2602.03839](https://arxiv.org/pdf/2602.03839) reports ~1% on a related setup; we measured ~2-3% on this run). Below the 3.125% break-even point, gap-encoded positions are smaller than absolute indices — the disk default `deltas_zstd` adds zstd L1 on top to squeeze the gap byte stream further (~35-40%), which is the right tradeoff when shared-FS bandwidth is ≤ 300 MB/s. Intra-datacenter NCCL has no bandwidth pressure, so `indices` (lowest compute, biggest payload) is the cleaner default there. +For object-store-backed volumes that need an explicit commit/refresh to make writes visible +across hosts, supply `--custom-delta-pre-push-path` / `--custom-delta-pre-read-path` (no +vendor-specific code lives in slime; see the doc). diff --git a/examples/delta_weight_sync/run-glm4.7-30B-A3B-delta.sh b/examples/delta_weight_sync/run-glm4.7-30B-A3B-delta.sh new file mode 100644 index 0000000000..a399b20bbe --- /dev/null +++ b/examples/delta_weight_sync/run-glm4.7-30B-A3B-delta.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# Disk delta weight-sync demo on GLM-4.7-Flash (30B-A3B), non-colocated, 2 nodes x 8 GPU. +# The trainer publishes per-tensor deltas to --update-weight-disk-dir as a canonical HF directory; +# each rollout host applies them into --update-weight-local-checkpoint-dir and reloads via the +# vanilla update_weights_from_disk path. +# +# Prerequisites: +# - A 2-node (16-GPU) Ray cluster, this script run on the head node. +# - GLM-4.7-Flash HF checkpoint + its torch_dist conversion (tools/convert_hf_to_torch_dist.py). +# - dapo-math-17k.jsonl. +# - --update-weight-disk-dir on a filesystem both nodes share. On an object-store-backed volume +# that needs an explicit commit/refresh to surface writes across hosts, also pass +# --custom-delta-pre-push-path / --custom-delta-pre-read-path (see the doc). + +set -ex +export PYTHONUNBUFFERED=1 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/glm4.7-30B-A3B.sh" + +MODEL_DIR=${MODEL_DIR:-/root/models/GLM-4.7-Flash} +DATA_PATH=${DATA_PATH:-/root/datasets/dapo-math-17k/dapo-math-17k.jsonl} + +CKPT_ARGS=( + --hf-checkpoint "${MODEL_DIR}" + --ref-load "${MODEL_DIR}_torch_dist" +) + +ROLLOUT_ARGS=( + --prompt-data "${DATA_PATH}" + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3 + --rollout-batch-size 32 + --n-samples-per-prompt 4 + --rollout-max-response-len 8192 + --global-batch-size 128 +) + +# Disk delta weight sync (the point of this example). +WEIGHT_SYNC_ARGS=( + --update-weight-mode delta + --update-weight-transport disk + --update-weight-disk-dir /shared/fs/glm47-delta-updates + --update-weight-local-checkpoint-dir /local/nvme/glm47-rollout-ckpt + --update-weight-delta-encoding xor + --update-weight-delta-checksum xxh3-128 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --pipeline-model-parallel-size 2 + --context-parallel-size 2 + --expert-model-parallel-size 8 + --expert-tensor-parallel-size 1 + --sequence-parallel + --use-dynamic-batch-size + --max-tokens-per-gpu 32768 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.0 + --kl-loss-type low_var_kl +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 8 + --sglang-mem-fraction-static 0.8 + --sglang-enable-dp-attention + --sglang-dp-size 8 +) + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" + } +}" + +# Non-colocated: 16 actor GPUs (2 x 8) train while a 16-GPU rollout pool generates (delta mode +# requires non-colocation). +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 2 \ + --actor-num-gpus-per-node 8 \ + --rollout-num-gpus 16 \ + ${MODEL_ARGS[@]} \ + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${WEIGHT_SYNC_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" diff --git a/examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh b/examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh deleted file mode 100755 index 9df77c4eff..0000000000 --- a/examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh +++ /dev/null @@ -1,192 +0,0 @@ -#!/bin/bash - -# Non-colocated GLM-4.7-355B-A32B with delta weight sync. -# 8 actor nodes (TP=8, PP=4, EP=16) + 64 rollout GPUs (8 H100 nodes worth), 16 nodes total. -# Disk transport is active by default; the NCCL block below it is commented out. - -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -ex - -export PYTHONUNBUFFERED=1 -unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -source "/root/slime/scripts/models/glm4.5-355B-A32B.sh" - -CKPT_ARGS=( - --hf-checkpoint /root/GLM-4.7-355B-A32B - --ref-load /root/GLM-4.7-355B-A32B_torch_dist/ -) - -ROLLOUT_ARGS=( - --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 64 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 1 - - --num-steps-per-rollout 4 - --balance-data - --rollout-stop-token-ids 151329 151336 151338 -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 8 - --eval-max-response-len 8192 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 8 - --sequence-parallel - --pipeline-model-parallel-size 4 - --context-parallel-size 2 - --expert-model-parallel-size 16 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 16384 -) - -GRPO_ARGS=( - --advantage-estimator gspo - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --kl-coef 0.00 - --entropy-coef 0.00 - --eps-clip 1e-4 - --eps-clip-high 2e-4 - --use-tis -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 - - --optimizer-cpu-offload - --overlap-cpu-optimizer-d2h-h2d - --use-precision-aware-optimizer -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project slime-delta - # --wandb-group glm4.7-355B-delta -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 32 - --sglang-mem-fraction-static 0.7 - --sglang-enable-dp-attention - --sglang-dp-size 4 - --sglang-ep-size 32 - --sglang-enable-dp-lm-head - --sglang-moe-dense-tp-size 1 - - # Receiver batches up to this many bytes per model.load_weights call. Bigger - # amortizes per-call cost (name resolution, MoE expert remap) but raises peak HBM. - --sglang-update-weight-delta-chunk-bytes $((2 * 1024 * 1024 * 1024)) - - # Max parallel I/O threads for reading delta files from disk (disk transport only). - --sglang-update-weight-delta-read-workers 4 - - # mtp - --sglang-speculative-algorithm EAGLE - --sglang-speculative-num-steps 3 - --sglang-speculative-eagle-topk 1 - --sglang-speculative-num-draft-tokens 4 -) - -# Delta weight sync. Pick one of the two blocks below. - -# ── Disk (default) — for training/inference disaggregation across datacenters ──── -# `deltas_zstd` is the right pick when shared-FS bandwidth is ≤ ~300 MB/s. -DELTA_ARGS=( - --update-weight-mode delta - --update-weight-transport disk - --update-weight-encoding deltas_zstd - --update-weight-disk-dir /shared/fs/delta-updates -) - -# ── NCCL (baseline) — intra-datacenter, no shared FS ──────────────────────────── -# DELTA_ARGS=( -# --update-weight-mode delta -# --update-weight-transport nccl -# --update-weight-encoding indices -# ) - -MISC_ARGS=( - --attention-dropout 0.0 - --hidden-dropout 0.0 - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - --attention-backend flash - --moe-token-dispatcher-type flex - --moe-enable-deepep - --update-weight-buffer-size $((2 * 1024 * 1024 * 1024)) -) - -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -RUNTIME_ENV_JSON=$(cat <=0.2.3 tensorboard transformers wandb +xxhash # disk delta weight sync (checksum + codec) +zstandard diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 941659f1fe..702d1570af 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -142,12 +142,14 @@ def init( ), "--update-weight-mode=delta is not supported with --colocate" update_weight_cls = UpdateWeightFromTensor elif self.args.update_weight_mode == "delta": - # Lazy import: the delta module pulls DeltaEncoding/DeltaParam/DeltaSpec from - # sglang, which only exist on newer images. Importing eagerly would break old - # images even when delta mode is unused. - from .update_weight.update_weight_from_distributed_delta import UpdateWeightFromDistributedDelta + # Delta sync is disk-transport only: each host applies the published deltas into + # its local checkpoint and the engines reload via vanilla update_weights_from_disk. + assert ( + self.args.update_weight_transport == "disk" + ), "--update-weight-mode=delta requires --update-weight-transport=disk" + from .update_weight.update_weight_from_disk_delta import UpdateWeightFromDiskDelta - update_weight_cls = UpdateWeightFromDistributedDelta + update_weight_cls = UpdateWeightFromDiskDelta else: assert self.args.update_weight_mode == "full" if self.args.update_weight_transport == "disk": @@ -255,38 +257,20 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: for mm_dict in rollout_data["multimodal_train_inputs"] ] - if self.args.qkv_format == "bshd": - # TODO: micro-batch wise dynamic, possibly move to @data.py:get_data_iterator - max_seq_len = max(rollout_data["total_lengths"]) - - # pad to reduce memory fragmentation and maybe make the computation faster - pad_size = mpu.get_tensor_model_parallel_world_size() * self.args.data_pad_size_multiplier - max_seq_len = (max_seq_len + pad_size - 1) // pad_size * pad_size - - rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"]) - for key in ["rollout_log_probs", "teacher_log_probs"]: if key not in rollout_data: continue rollout_data[key] = [ - slice_log_prob_with_cp( - log_prob, - total_length, - response_length, - self.args.qkv_format, - rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, - ).to( + slice_log_prob_with_cp(log_prob, total_length, response_length).to( device=device, dtype=torch.float32, non_blocking=True, ) - for i, (log_prob, total_length, response_length) in enumerate( - zip( - rollout_data[key], - rollout_data["total_lengths"], - rollout_data["response_lengths"], - strict=False, - ) + for log_prob, total_length, response_length in zip( + rollout_data[key], + rollout_data["total_lengths"], + rollout_data["response_lengths"], + strict=False, ) ] return rollout_data @@ -390,6 +374,7 @@ def compute_log_prob( data_iterator, num_microbatches, store_prefix=store_prefix, + use_rollout_top_p_replay=True, ) def train(self, rollout_id: int, rollout_data_ref: Box, external_data=None): @@ -604,13 +589,21 @@ def update_weights(self) -> None: ray.get(self.rollout_manager.recover_updatable_engines.remote()) dist.barrier(group=get_gloo_group()) - rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get( - self.rollout_manager.get_updatable_engines_and_lock.remote() - ) + ( + rollout_engines, + rollout_engine_lock, + num_new_engines, + engine_gpu_counts, + engine_gpu_offsets, + all_engine_actors, + ) = ray.get(self.rollout_manager.get_updatable_engines_and_lock.remote()) reconnect_rollout_engines = self.args.offload_train and self.args.use_critic and not self.args.colocate + # An opaque HTTP rollout fleet exposes no engine handles; the trainer publishes the delta to + # disk instead of pushing, so it still runs update_weights (and connects once) with no engines. + publish_only = bool(getattr(self.args, "rollout_endpoint_url", None)) - if not rollout_engines and not reconnect_rollout_engines: + if not rollout_engines and not reconnect_rollout_engines and not publish_only: if dist.get_rank() == 0: logger.info("No updatable SGLang engines are running; skip weight update.") return @@ -620,12 +613,13 @@ def update_weights(self) -> None: elif self.args.offload_train: reload_process_groups() - if num_new_engines > 0 or reconnect_rollout_engines: + if num_new_engines > 0 or reconnect_rollout_engines or publish_only: self.weight_updater.connect_rollout_engines( rollout_engines, rollout_engine_lock, engine_gpu_counts=engine_gpu_counts, engine_gpu_offsets=engine_gpu_offsets, + all_engine_actors=all_engine_actors, ) dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: diff --git a/slime/backends/megatron_utils/cp_utils.py b/slime/backends/megatron_utils/cp_utils.py index 448c154c6c..a97c45cc42 100644 --- a/slime/backends/megatron_utils/cp_utils.py +++ b/slime/backends/megatron_utils/cp_utils.py @@ -9,8 +9,6 @@ def get_logits_and_tokens_offset_with_cp( total_length: int, response_length: int, - qkv_format: str = "thd", - max_seq_len: int | None = None, ): """ All offsets start from the begining of the prompt. @@ -20,11 +18,7 @@ def get_logits_and_tokens_offset_with_cp( assert cp_size > 1 prompt_length = total_length - response_length - if qkv_format == "thd": - chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size) - else: - assert max_seq_len is not None, "max_seq_len must be provided for qkv_format=bshd" - chunk_size = (max_seq_len + 2 * cp_size - 1) // (2 * cp_size) + chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size) # the offset of 2 chunks chunk_0 = (cp_rank * chunk_size, (cp_rank + 1) * chunk_size) @@ -56,8 +50,6 @@ def get_sum_of_sample_mean( loss_masks: list[torch.Tensor], sample_denoms: list[torch.Tensor] | torch.Tensor | None = None, calculate_per_token_loss: bool = False, - qkv_format: str = "thd", - max_seq_lens: list[int] | None = None, ) -> Callable[[torch.Tensor], torch.Tensor]: """ Calculate correct sample mean for CP. @@ -100,18 +92,14 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: cp_chunk_lengths: list[int] = [] chunked_loss_masks: list[torch.Tensor] = [] - for i, (total_length, response_length, loss_mask) in enumerate( - zip(total_lengths, response_lengths, loss_masks, strict=False) - ): - max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None + for total_length, response_length, loss_mask in zip(total_lengths, response_lengths, loss_masks, strict=False): prompt_length = total_length - response_length - _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len - ) + _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(total_length, response_length) loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] - chunked_loss_masks.append(torch.cat([loss_mask_0, loss_mask_1], dim=0)) - cp_chunk_lengths.append(chunked_loss_masks[i].size(0)) + chunked_loss_mask = torch.cat([loss_mask_0, loss_mask_1], dim=0) + chunked_loss_masks.append(chunked_loss_mask) + cp_chunk_lengths.append(chunked_loss_mask.size(0)) def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: return sum( @@ -299,15 +287,10 @@ def zero(len: int) -> torch.Tensor: def slice_with_cp( tokens: torch.Tensor, pad_value: tuple[int, float, Callable], - qkv_format: str = "thd", - max_seq_len: int | None = None, ) -> torch.Tensor: cp_rank = mpu.get_context_parallel_rank() cp_size = mpu.get_context_parallel_world_size() - if qkv_format == "bshd": - assert max_seq_len is not None - def pad_tokens(tokens, pad): if isinstance(pad_value, Callable): pad_func = pad_value @@ -319,16 +302,10 @@ def pad_tokens(tokens, pad): return tokens if cp_size == 1: - if qkv_format == "bshd": - pad = max_seq_len - tokens.size(0) - tokens = pad_tokens(tokens, pad) return tokens token_len = len(tokens) - if qkv_format == "thd": - chunk_size = (token_len + 2 * cp_size - 1) // (2 * cp_size) - else: - chunk_size = (max_seq_len + 2 * cp_size - 1) // (2 * cp_size) + chunk_size = (token_len + 2 * cp_size - 1) // (2 * cp_size) # pad pad = 2 * cp_size * chunk_size - token_len @@ -344,8 +321,6 @@ def slice_log_prob_with_cp( log_prob: list[float] | torch.Tensor, total_length: int, response_length: int, - qkv_format: str = "thd", - max_token_len: int | None = None, ) -> list[float] | torch.Tensor: assert len(log_prob) == response_length, ( f"log_prob length mismatch: len(log_prob)={len(log_prob)}, " @@ -358,9 +333,7 @@ def slice_log_prob_with_cp( return log_prob prompt_length = total_length - response_length - _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_token_len - ) + _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp(total_length, response_length) chunk_1 = log_prob[logits_offset[0][0] - (prompt_length - 1) : logits_offset[0][1] - (prompt_length - 1)] chunk_2 = log_prob[logits_offset[1][0] - (prompt_length - 1) : logits_offset[1][1] - (prompt_length - 1)] diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 7c12a5a778..00f319928f 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -29,7 +29,6 @@ def get_batch( data_iterator: "DataIterator", keys: Sequence[str], pad_multiplier: int = 128, - qkv_format: str = "thd", allgather_cp: bool = False, ) -> dict[str, torch.Tensor | PackedSeqParams | list[torch.Tensor] | None]: """ @@ -67,63 +66,53 @@ def get_batch( cp_size = mpu.get_context_parallel_world_size() cp_rank = mpu.get_context_parallel_rank() - if qkv_format == "bshd": - max_seqlen = batch["max_seq_lens"][0] - assert max([t.size(0) for t in tokens]) <= max_seqlen - tokens = [slice_with_cp(t, pad_token_id, qkv_format, max_seqlen) for t in tokens] - tokens = torch.stack(tokens) - packed_seq_params = None + if allgather_cp: + # DSA mode: concatenate all sequences first, then slice once with CP. + # We also pad the *global* concatenated stream to make per-rank chunks equal. + cu_seqlens_list: list[int] = [0] + for t in tokens: + cu_seqlens_list.append(cu_seqlens_list[-1] + t.size(0)) - elif qkv_format == "thd": - if allgather_cp: - # DSA mode: concatenate all sequences first, then slice once with CP. - # We also pad the *global* concatenated stream to make per-rank chunks equal. - cu_seqlens_list: list[int] = [0] - for t in tokens: - cu_seqlens_list.append(cu_seqlens_list[-1] + t.size(0)) - - tokens = torch.cat(tokens, dim=0) - - # Pad global stream so (1) divisible by cp_size (equal chunks), - # (2) divisible by pad_size (reduce fragmentation). - global_pad_size = cp_size * pad_size - pad = (global_pad_size - tokens.size(0) % global_pad_size) % global_pad_size - if pad != 0: - tokens = F.pad(tokens, (0, pad), value=pad_token_id) - cu_seqlens_list.append(cu_seqlens_list[-1] + pad) - - cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int, device=torch.cuda.current_device()) - tokens = tokens.chunk(cp_size, dim=0)[cp_rank] - else: - tokens = [slice_with_cp(t, pad_token_id, qkv_format) for t in tokens] - - cu_seqlens = [0] - for t in tokens: - cu_seqlens.append(cu_seqlens[-1] + t.size(0)) - - tokens = torch.cat(tokens) - - # Always pad to reduce memory fragmentation and maybe make the computation faster - pad = (pad_size - tokens.size(0) % pad_size) % pad_size - if pad != 0: - tokens = F.pad(tokens, (0, pad), value=pad_token_id) - cu_seqlens.append(cu_seqlens[-1] + pad) - - # thd requires the cu_seqlens to be of the origin length - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format="thd", - ) - - tokens = tokens.unsqueeze(0) + tokens = torch.cat(tokens, dim=0) + + # Pad global stream so (1) divisible by cp_size (equal chunks), + # (2) divisible by pad_size (reduce fragmentation). + global_pad_size = cp_size * pad_size + pad = (global_pad_size - tokens.size(0) % global_pad_size) % global_pad_size + if pad != 0: + tokens = F.pad(tokens, (0, pad), value=pad_token_id) + cu_seqlens_list.append(cu_seqlens_list[-1] + pad) + + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int, device=torch.cuda.current_device()) + tokens = tokens.chunk(cp_size, dim=0)[cp_rank] else: - raise ValueError(f"Unsupported qkv_format: {qkv_format}") + tokens = [slice_with_cp(t, pad_token_id) for t in tokens] + + cu_seqlens = [0] + for t in tokens: + cu_seqlens.append(cu_seqlens[-1] + t.size(0)) + + tokens = torch.cat(tokens) + + # Always pad to reduce memory fragmentation and maybe make the computation faster + pad = (pad_size - tokens.size(0) % pad_size) % pad_size + if pad != 0: + tokens = F.pad(tokens, (0, pad), value=pad_token_id) + cu_seqlens.append(cu_seqlens[-1] + pad) + + # thd requires the cu_seqlens to be of the origin length + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + qkv_format="thd", + ) + + tokens = tokens.unsqueeze(0) batch["tokens"] = tokens batch["packed_seq_params"] = packed_seq_params @@ -142,18 +131,16 @@ def get_batch( if allgather_cp: loss_masks.append(loss_mask) continue - loss_mask = slice_with_cp(loss_mask, 0, qkv_format, max_seqlen) + loss_mask = slice_with_cp(loss_mask, 0) loss_masks.append(loss_mask) - if qkv_format == "bshd": - loss_masks = torch.stack(loss_masks) - elif qkv_format == "thd" and allgather_cp: + if allgather_cp: # DSA: concatenate first (same as tokens), pad globally (same pad as above), then slice once. loss_masks = torch.cat(loss_masks, dim=0) if pad != 0: loss_masks = F.pad(loss_masks, (0, pad), value=0) loss_masks = loss_masks.chunk(cp_size, dim=0)[cp_rank].unsqueeze(0) - elif qkv_format == "thd": + else: loss_masks = torch.cat(loss_masks) loss_masks = F.pad(loss_masks, (0, pad), value=0).unsqueeze(0) @@ -278,7 +265,6 @@ def log_rollout_data( response_lengths = rollout_data["response_lengths"] loss_masks = rollout_data["loss_masks"] total_lengths = rollout_data["total_lengths"] - max_seq_lens = rollout_data.get("max_seq_lens", None) # Same per-rollout denominators the training loss uses, so reported # log_probs / returns / advantages / etc. live in the same per-rollout # mean space (rather than per-sample) as the gradient signal. @@ -298,8 +284,9 @@ def log_rollout_data( "sample_indices", "rollout_ids", "rollout_mask_sums", + "rollout_top_p_token_ids", + "rollout_top_p_token_offsets", "rollout_routed_experts", - "max_seq_lens", "global_batch_sizes", "num_microbatches", "micro_batch_indices", @@ -329,8 +316,6 @@ def log_rollout_data( response_lengths, loss_masks, rollout_mask_sums, - qkv_format=args.qkv_format, - max_seq_lens=max_seq_lens, ) # Compute (sum, count) via the shared helper so this # path and the unit tests stay in sync. diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index a456939e74..72afdfa66c 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -31,6 +31,25 @@ slice_log_prob_with_cp, ) +ROLLOUT_TOP_P_TOKEN_KEYS = ( + "rollout_top_p_token_ids", + "rollout_top_p_token_offsets", +) + + +def get_rollout_top_p_logprob_kwargs(args: Namespace, batch: dict[str, Any]) -> dict[str, Any]: + if args.rollout_top_p == 1.0: + return {} + + top_p_token_ids = batch.get("rollout_top_p_token_ids") + top_p_token_offsets = batch.get("rollout_top_p_token_offsets") + if top_p_token_ids is None or top_p_token_offsets is None: + raise ValueError("rollout_top_p != 1.0 requires rollout_top_p_token_ids and rollout_top_p_token_offsets.") + return { + "top_p_token_ids": top_p_token_ids, + "top_p_token_offsets": top_p_token_offsets, + } + def get_responses( logits: torch.Tensor, @@ -39,7 +58,6 @@ def get_responses( unconcat_tokens: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], - max_seq_lens: list[int] | None = None, apply_temperature: bool = True, ) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: """Yield response-aligned `(logits_chunk, tokens_chunk)` pairs per sample. @@ -64,17 +82,10 @@ def get_responses( `[R, V]` (policy) or `[R, 1]` (value) and `tokens_chunk` is shape `[R]` (1D int64), both aligned to response tokens for one sample. """ - qkv_format = args.qkv_format - assert logits.dtype == torch.float32, f"{logits.dtype}" assert len(logits.shape) == 3, f"{logits.shape}" - - if qkv_format == "thd": - assert logits.size(0) == 1, f"{logits.shape}" - logits = logits.squeeze(0) - else: - assert max_seq_lens is not None - logits = logits.view(-1, logits.size(-1)) + assert logits.size(0) == 1, f"{logits.shape}" + logits = logits.squeeze(0) if apply_temperature and args.rollout_temperature != 1.0: logits = logits.div(args.rollout_temperature) @@ -82,18 +93,10 @@ def get_responses( cp_size = mpu.get_context_parallel_world_size() end = 0 seq_start = 0 - for i, (tokens, total_length, response_length) in enumerate( - zip(unconcat_tokens, total_lengths, response_lengths, strict=False) - ): - max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None - + for tokens, total_length, response_length in zip(unconcat_tokens, total_lengths, response_lengths, strict=False): if cp_size == 1: - if qkv_format == "bshd": - end = max_seq_len * i + total_length - start = end - response_length - else: - end += total_length - start = end - response_length + end += total_length + start = end - response_length logits_chunk = logits[start - 1 : end - 1] tokens_chunk = tokens[-response_length:] elif args.allgather_cp: @@ -122,7 +125,7 @@ def get_responses( else: # TODO: this is super ugly... do better abstraction. chunk_size, chunks_offset, logits_offset, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len + total_length, response_length ) logits_0, logits_1 = logits[end : end + chunk_size], logits[end + chunk_size : end + 2 * chunk_size] @@ -149,10 +152,8 @@ def _allgather_cp_redistribute( res: dict[str, list[torch.Tensor]], *, logits_local_len: int, - args: Namespace, total_lengths: list[int], response_lengths: list[int], - max_seq_lens: list[int] | None = None, ) -> None: """Redistribute response tensors from allgather-CP layout to zigzag ring-attn layout. @@ -166,10 +167,8 @@ def _allgather_cp_redistribute( Args: res: Dict mapping metric names to lists of per-sample tensors. logits_local_len: Local sequence length on this rank. - args: Configuration (needs ``qkv_format``). total_lengths: Total sequence lengths (prompt + response) per sample. response_lengths: Response segment lengths per sample. - max_seq_lens: Optional padded max sequence lengths per sample. """ cp_group = mpu.get_context_parallel_group() cp_rank = mpu.get_context_parallel_rank() @@ -220,13 +219,10 @@ def _allgather_cp_redistribute( # Re-slice each sample into zigzag CP pattern new_values = [] - for idx, (full_resp, total_length, response_length) in enumerate( - zip(all_cat.split(response_lengths, dim=0), total_lengths, response_lengths, strict=False) + for full_resp, total_length, response_length in zip( + all_cat.split(response_lengths, dim=0), total_lengths, response_lengths, strict=False ): - max_seq_len = max_seq_lens[idx] if max_seq_lens is not None else None - new_values.append( - slice_log_prob_with_cp(full_resp, total_length, response_length, args.qkv_format, max_seq_len) - ) + new_values.append(slice_log_prob_with_cp(full_resp, total_length, response_length)) res[key] = new_values @@ -237,8 +233,6 @@ def _build_shifted_tokens( unconcat_tokens: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], - qkv_format: str, - max_seq_lens: list[int] | None, allgather_cp: bool, ) -> torch.Tensor: """Build shifted target tokens for the full packed/padded logits.""" @@ -248,12 +242,11 @@ def _build_shifted_tokens( if cp_size > 1 and not allgather_cp: full_tokens = torch.zeros(T, dtype=torch.long, device=device) end = 0 - for i, (tokens, total_length, response_length) in enumerate( - zip(unconcat_tokens, total_lengths, response_lengths, strict=False) + for tokens, total_length, response_length in zip( + unconcat_tokens, total_lengths, response_lengths, strict=False ): - max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None chunk_size_cp, chunks_offset, logits_offset, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len + total_length, response_length ) for half, base in ((0, end), (1, end + chunk_size_cp)): lo = logits_offset[half][0] - chunks_offset[half][0] @@ -266,15 +259,10 @@ def _build_shifted_tokens( T_global = sum(total_lengths) if allgather_cp else T full_tokens = torch.zeros(T_global, dtype=torch.long, device=device) - if qkv_format == "thd" or allgather_cp: - offset = 0 - for tokens, total_length in zip(unconcat_tokens, total_lengths, strict=False): - full_tokens[offset : offset + total_length - 1] = tokens[1:total_length] - offset += total_length - else: # bshd, cp1 - for i, (tokens, total_length) in enumerate(zip(unconcat_tokens, total_lengths, strict=False)): - seq_start = max_seq_lens[i] * i - full_tokens[seq_start : seq_start + total_length - 1] = tokens[1:total_length] + offset = 0 + for tokens, total_length in zip(unconcat_tokens, total_lengths, strict=False): + full_tokens[offset : offset + total_length - 1] = tokens[1:total_length] + offset += total_length # allgather-CP: slice to local chunk if allgather_cp: @@ -292,13 +280,117 @@ def _build_shifted_tokens( return full_tokens +def _fill_topp_mask_rows( + keep: torch.Tensor, + ids: list[int], + offsets: list[int], + response_start: int, + local_start: int, + length: int, + vocab_start: int, + vocab_end: int, +) -> None: + end = min(response_start + length, max(len(offsets) - 1, 0)) + for response_idx in range(response_start, end): + local_ids = [ + token_id - vocab_start + for token_id in ids[offsets[response_idx] : offsets[response_idx + 1]] + if vocab_start <= token_id < vocab_end + ] + row = local_start + response_idx - response_start + keep[row].fill_(False) + if local_ids: + keep[row, torch.tensor(local_ids, device=keep.device, dtype=torch.long)] = True + + +def _build_topp_keep_mask( + T: int, + vocab_local: int, + device: torch.device, + top_p_token_ids: list[list[int]], + top_p_token_offsets: list[list[int]], + total_lengths: list[int], + response_lengths: list[int], + allgather_cp: bool, +) -> torch.Tensor: + """Build a ``[T, vocab_local]`` boolean keep-mask aligned to local logits. + + For response token ``r`` of a sample, the rollout top-p nucleus is + ``ids[offsets[r]:offsets[r + 1]]``. Rows without a recorded nucleus stay + all-True, so only response rows with replay data are masked. + """ + cp_size = mpu.get_context_parallel_world_size() + tp_rank = mpu.get_tensor_model_parallel_rank() + vocab_start = tp_rank * vocab_local + vocab_end = vocab_start + vocab_local + + # Normalize ragged payloads (may arrive as CPU int32 tensors) to python lists. + top_p_token_ids = [t.tolist() if torch.is_tensor(t) else list(t) for t in top_p_token_ids] + top_p_token_offsets = [t.tolist() if torch.is_tensor(t) else list(t) for t in top_p_token_offsets] + + keep = torch.ones((T, vocab_local), dtype=torch.bool, device=device) + + if cp_size > 1 and not allgather_cp: + local_base = 0 + for ids, offsets, total_length, response_length in zip( + top_p_token_ids, top_p_token_offsets, total_lengths, response_lengths, strict=False + ): + prompt_length = total_length - response_length + chunk_size_cp, chunks_offset, logits_offset, tokens_offset = get_logits_and_tokens_offset_with_cp( + total_length, response_length + ) + for half, base in ((0, local_base), (1, local_base + chunk_size_cp)): + local_start = base + logits_offset[half][0] - chunks_offset[half][0] + length = logits_offset[half][1] - logits_offset[half][0] + response_start = tokens_offset[half][0] - prompt_length + _fill_topp_mask_rows(keep, ids, offsets, response_start, local_start, length, vocab_start, vocab_end) + local_base += 2 * chunk_size_cp + return keep + + if allgather_cp: + cp_rank = mpu.get_context_parallel_rank() + chunk_start = cp_rank * T + chunk_end = chunk_start + T + seq_start = 0 + for ids, offsets, total_length, response_length in zip( + top_p_token_ids, top_p_token_offsets, total_lengths, response_lengths, strict=False + ): + prompt_length = total_length - response_length + logit_global_start = seq_start + prompt_length - 1 + logit_global_end = seq_start + total_length - 1 + s = max(logit_global_start, chunk_start) + e = min(logit_global_end, chunk_end) + if e > s: + _fill_topp_mask_rows( + keep, + ids, + offsets, + s - logit_global_start, + s - chunk_start, + e - s, + vocab_start, + vocab_end, + ) + seq_start += total_length + return keep + + offset = 0 + for ids, offsets, total_length, response_length in zip( + top_p_token_ids, top_p_token_offsets, total_lengths, response_lengths, strict=False + ): + end = offset + total_length + start = end - response_length + _fill_topp_mask_rows(keep, ids, offsets, 0, start - 1, response_length, vocab_start, vocab_end) + offset += total_length + + return keep + + def _extract_per_sample( log_prob_full: torch.Tensor, entropy_full: torch.Tensor | None, total_lengths: list[int], response_lengths: list[int], - qkv_format: str, - max_seq_lens: list[int] | None, allgather_cp: bool, ) -> tuple[list[torch.Tensor], list[torch.Tensor | None]]: """Slice per-sample response log-probs/entropy from full-length 1-D tensors.""" @@ -309,10 +401,9 @@ def _extract_per_sample( if cp_size > 1 and not allgather_cp: # zigzag CP pos = 0 - for i, (total_length, response_length) in enumerate(zip(total_lengths, response_lengths, strict=False)): - max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None + for total_length, response_length in zip(total_lengths, response_lengths, strict=False): chunk_size_cp, chunks_offset, logits_offset, _tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len + total_length, response_length ) lo0 = logits_offset[0][0] - chunks_offset[0][0] hi0 = logits_offset[0][1] - chunks_offset[0][0] @@ -364,22 +455,14 @@ def _extract_per_sample( else: # cp1 - if qkv_format == "thd": - offset = 0 - for total_length, response_length in zip(total_lengths, response_lengths, strict=False): - end = offset + total_length - start = end - response_length - log_probs_list.append(log_prob_full[start - 1 : end - 1]) - if entropy_full is not None: - entropy_list.append(entropy_full[start - 1 : end - 1]) - offset += total_length - else: # bshd - for i, (total_length, response_length) in enumerate(zip(total_lengths, response_lengths, strict=False)): - end = max_seq_lens[i] * i + total_length - start = end - response_length - log_probs_list.append(log_prob_full[start - 1 : end - 1]) - if entropy_full is not None: - entropy_list.append(entropy_full[start - 1 : end - 1]) + offset = 0 + for total_length, response_length in zip(total_lengths, response_lengths, strict=False): + end = offset + total_length + start = end - response_length + log_probs_list.append(log_prob_full[start - 1 : end - 1]) + if entropy_full is not None: + entropy_list.append(entropy_full[start - 1 : end - 1]) + offset += total_length return log_probs_list, entropy_list @@ -393,7 +476,8 @@ def get_log_probs_and_entropy( response_lengths: list[int], with_entropy: bool = False, non_loss_data: bool = True, - max_seq_lens: list[int] | None = None, + top_p_token_ids: list[list[int]] | None = None, + top_p_token_offsets: list[list[int]] | None = None, ) -> dict[str, list[torch.Tensor]]: """Compute per-token log-probabilities (and optionally entropy) on responses. @@ -405,17 +489,10 @@ def get_log_probs_and_entropy( to avoid retaining the computation graph and to skip cloning. """ assert non_loss_data - qkv_format = args.qkv_format - assert logits.dtype == torch.float32, f"{logits.dtype}" assert len(logits.shape) == 3, f"{logits.shape}" - - if qkv_format == "thd": - assert logits.size(0) == 1, f"{logits.shape}" - logits = logits.squeeze(0) - else: - assert max_seq_lens is not None - logits = logits.view(-1, logits.size(-1)) + assert logits.size(0) == 1, f"{logits.shape}" + logits = logits.squeeze(0) # Apply rollout temperature scaling to logits to match rollout-time log-probs. rollout_temperature = getattr(args, "rollout_temperature", 1.0) @@ -428,9 +505,21 @@ def get_log_probs_and_entropy( chunk_size = args.log_probs_chunk_size # --- build full shifted-token target tensor --- - full_tokens = _build_shifted_tokens( - T, device, unconcat_tokens, total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp - ) + full_tokens = _build_shifted_tokens(T, device, unconcat_tokens, total_lengths, response_lengths, args.allgather_cp) + + # --- build top-p nucleus keep-mask (logprob only; entropy stays unmasked) --- + top_p_keep_mask = None + if top_p_token_ids is not None and top_p_token_offsets is not None: + top_p_keep_mask = _build_topp_keep_mask( + T, + logits.size(-1), + device, + top_p_token_ids, + top_p_token_offsets, + total_lengths, + response_lengths, + args.allgather_cp, + ) # --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy --- log_prob_full, entropy_full = calculate_log_probs_and_entropy( @@ -439,6 +528,7 @@ def get_log_probs_and_entropy( tp_group, with_entropy=with_entropy, chunk_size=chunk_size, + log_prob_keep_mask=top_p_keep_mask, ) log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T] @@ -448,8 +538,6 @@ def get_log_probs_and_entropy( entropy_full, total_lengths, response_lengths, - qkv_format, - max_seq_lens, args.allgather_cp, ) @@ -462,10 +550,8 @@ def get_log_probs_and_entropy( _allgather_cp_redistribute( res, logits_local_len=T, - args=args, total_lengths=total_lengths, response_lengths=response_lengths, - max_seq_lens=max_seq_lens, ) return torch.empty((0,), device=device), res @@ -480,7 +566,6 @@ def get_values( response_lengths: list[int], with_entropy: bool = False, non_loss_data: bool = True, - max_seq_lens: list[int] | None = None, ) -> dict[str, list[torch.Tensor]]: """Extract per-token value predictions over response tokens. @@ -508,7 +593,6 @@ def get_values( unconcat_tokens=unconcat_tokens, total_lengths=total_lengths, response_lengths=response_lengths, - max_seq_lens=max_seq_lens, apply_temperature=False, ): assert logits_chunk.size(-1) == 1, f"{logits_chunk.shape}" @@ -522,10 +606,8 @@ def get_values( _allgather_cp_redistribute( res, logits_local_len=logits.size(1), - args=args, total_lengths=total_lengths, response_lengths=response_lengths, - max_seq_lens=max_seq_lens, ) return torch.empty((0,), device=logits.device), res @@ -607,8 +689,6 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) response_lengths: list[int] = rollout_data.get("response_lengths") loss_masks: list[torch.Tensor] = rollout_data.get("loss_masks") total_lengths: list[int] = rollout_data.get("total_lengths") - max_seq_lens: list[int] | None = rollout_data.get("max_seq_lens", None) - # return when not the last pp stage. if not mpu.is_pipeline_last_stage(): return @@ -700,11 +780,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) total_len = total_lengths[i] response_len = response_lengths[i] prompt_len = total_len - response_len - max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None - - _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp( - total_len, response_len, args.qkv_format, max_seq_len - ) + _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp(total_len, response_len) # Convert global offsets to response-space offsets s0, e0 = token_offsets[0] @@ -833,7 +909,6 @@ def policy_loss_function( response_lengths = batch["response_lengths"] total_lengths = batch["total_lengths"] - max_seq_lens = batch.get("max_seq_lens", None) _, log_probs_and_entropy = get_log_probs_and_entropy( logits, @@ -842,7 +917,7 @@ def policy_loss_function( total_lengths=total_lengths, response_lengths=response_lengths, with_entropy=True, - max_seq_lens=max_seq_lens, + **get_rollout_top_p_logprob_kwargs(args, batch), ) log_probs = log_probs_and_entropy["log_probs"] @@ -947,8 +1022,6 @@ def policy_loss_function( modified_response_masks, batch["rollout_mask_sums"], args.calculate_per_token_loss, - args.qkv_format, - max_seq_lens, ) # Determine pg_loss reducer: use custom if specified, otherwise default @@ -1063,7 +1136,6 @@ def value_loss_function( unconcat_tokens=batch["unconcat_tokens"], total_lengths=batch["total_lengths"], response_lengths=batch["response_lengths"], - max_seq_lens=batch.get("max_seq_lens", None), ) values = torch.cat([value.flatten() for value in values["values"]], dim=0) @@ -1122,7 +1194,6 @@ def sft_loss_function( total_lengths=total_lengths, response_lengths=response_lengths, with_entropy=False, - max_seq_lens=batch.get("max_seq_lens", None), ) log_probs = log_probs_and_entropy["log_probs"] @@ -1183,8 +1254,6 @@ def loss_function( batch["loss_masks"], batch["rollout_mask_sums"], args.calculate_per_token_loss, - args.qkv_format, - batch.get("max_seq_lens", None), ) match args.loss_type: diff --git a/slime/backends/megatron_utils/megatron_to_hf/deepseekv3.py b/slime/backends/megatron_utils/megatron_to_hf/deepseekv3.py index 205b025556..97742c9541 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/deepseekv3.py +++ b/slime/backends/megatron_utils/megatron_to_hf/deepseekv3.py @@ -2,6 +2,8 @@ import torch +from .dtype_utils import to_model_dtype + def convert_deepseekv3_to_hf(args, name, param): if name == "module.module.embedding.word_embeddings.weight": @@ -126,9 +128,9 @@ def convert_deepseekv3_to_hf(args, name, param): elif rest == "pre_mlp_layernorm.weight": return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] elif rest == "mlp.router.weight": - return [(f"model.layers.{layer_idx}.mlp.gate.weight", param)] + return [(f"model.layers.{layer_idx}.mlp.gate.weight", to_model_dtype(args, param))] elif rest == "mlp.router.expert_bias": - return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)] + return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", to_model_dtype(args, param))] mtp_layer_pattern = r"module\.module\.mtp\.layers\.(\d+)\.(.+)" match = re.match(mtp_layer_pattern, name) diff --git a/slime/backends/megatron_utils/megatron_to_hf/dtype_utils.py b/slime/backends/megatron_utils/megatron_to_hf/dtype_utils.py new file mode 100644 index 0000000000..9394864b13 --- /dev/null +++ b/slime/backends/megatron_utils/megatron_to_hf/dtype_utils.py @@ -0,0 +1,17 @@ +import torch + + +def to_model_dtype(args, param): + """Cast a router param back to the model dtype before export. + + The MoE router runs in fp32 (--moe-router-dtype fp32), so Megatron can hold its weight / + expert_bias buffer in fp32 even when the model dtype is bf16/fp16. The HF base checkpoint + stores those buffers in the model dtype, and update_weight_from_disk_delta XORs each freshly + converted tensor against the base HF bytes — so a leftover fp32 router is a byte-width mismatch + against a bf16/fp16 base. Cast back so the exported byte shape matches the base on disk. + """ + if getattr(args, "bf16", False): + return param.to(torch.bfloat16) + if getattr(args, "fp16", False): + return param.to(torch.float16) + return param diff --git a/slime/backends/megatron_utils/megatron_to_hf/glm4moe.py b/slime/backends/megatron_utils/megatron_to_hf/glm4moe.py index 33a64e6e8f..f518d3b559 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/glm4moe.py +++ b/slime/backends/megatron_utils/megatron_to_hf/glm4moe.py @@ -2,6 +2,8 @@ import torch +from .dtype_utils import to_model_dtype + def convert_glm4moe_to_hf(args, name, param): if name == "module.module.embedding.word_embeddings.weight": @@ -108,9 +110,9 @@ def convert_glm4moe_to_hf(args, name, param): elif rest == "pre_mlp_layernorm.weight": return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] elif rest == "mlp.router.weight": - return [(f"model.layers.{layer_idx}.mlp.gate.weight", param)] + return [(f"model.layers.{layer_idx}.mlp.gate.weight", to_model_dtype(args, param))] elif rest == "mlp.router.expert_bias": - return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)] + return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", to_model_dtype(args, param))] # qk norm elif rest == "self_attention.q_layernorm.weight": diff --git a/slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py b/slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py index b90507f04b..93bdab5ef0 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py +++ b/slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py @@ -2,6 +2,8 @@ import torch +from .dtype_utils import to_model_dtype + def convert_gpt_oss_to_hf(args, name, param): """Convert Megatron GPT-OSS parameter names to HF format for weight update to SGLang.""" @@ -98,8 +100,8 @@ def convert_gpt_oss_to_hf(args, name, param): return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] # Router elif rest == "mlp.router.weight": - return [(f"model.layers.{layer_idx}.mlp.router.weight", param)] + return [(f"model.layers.{layer_idx}.mlp.router.weight", to_model_dtype(args, param))] elif rest == "mlp.router.bias": - return [(f"model.layers.{layer_idx}.mlp.router.bias", param)] + return [(f"model.layers.{layer_idx}.mlp.router.bias", to_model_dtype(args, param))] raise ValueError(f"Unknown parameter name: {name}") diff --git a/slime/backends/megatron_utils/megatron_to_hf/minimax_m2.py b/slime/backends/megatron_utils/megatron_to_hf/minimax_m2.py index 3701e30f2f..ad820817b7 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/minimax_m2.py +++ b/slime/backends/megatron_utils/megatron_to_hf/minimax_m2.py @@ -2,6 +2,8 @@ import torch +from .dtype_utils import to_model_dtype + def convert_minimax_m2_to_hf(args, name, param): """Convert Megatron parameter names/tensors to HuggingFace format for MiniMax-M2.5. @@ -79,8 +81,10 @@ def convert_minimax_m2_to_hf(args, name, param): # Router elif rest == "mlp.router.weight": - return [(f"model.layers.{layer_idx}.block_sparse_moe.gate.weight", param)] + return [(f"model.layers.{layer_idx}.block_sparse_moe.gate.weight", to_model_dtype(args, param))] elif rest == "mlp.router.expert_bias": - return [(f"model.layers.{layer_idx}.block_sparse_moe.e_score_correction_bias", param)] + return [ + (f"model.layers.{layer_idx}.block_sparse_moe.e_score_correction_bias", to_model_dtype(args, param)) + ] raise ValueError(f"Unknown parameter name: {name}") diff --git a/slime/backends/megatron_utils/megatron_to_hf/qwen3_5.py b/slime/backends/megatron_utils/megatron_to_hf/qwen3_5.py index 2aabd86eba..f2f4f4dfbc 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/qwen3_5.py +++ b/slime/backends/megatron_utils/megatron_to_hf/qwen3_5.py @@ -2,6 +2,8 @@ import torch +from .dtype_utils import to_model_dtype + def _convert_mtp_layer(args, name, param, layer_idx): """Convert MTP layer parameters from Megatron to HuggingFace format.""" @@ -158,9 +160,9 @@ def convert_qwen3_5_to_hf(args, name, param): elif rest == "pre_mlp_layernorm.weight": return [(f"{prefix}.post_attention_layernorm.weight", param)] elif rest == "mlp.router.weight": - return [(f"{prefix}.mlp.gate.weight", param)] + return [(f"{prefix}.mlp.gate.weight", to_model_dtype(args, param))] elif rest == "mlp.router.expert_bias": - return [(f"{prefix}.mlp.gate.e_score_correction_bias", param)] + return [(f"{prefix}.mlp.gate.e_score_correction_bias", to_model_dtype(args, param))] # qk norm elif rest == "self_attention.q_layernorm.weight": diff --git a/slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py b/slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py index f12f31195f..619cb84ef8 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py +++ b/slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py @@ -2,6 +2,8 @@ import torch +from .dtype_utils import to_model_dtype + def _convert_mtp_layer(args, name, param, layer_idx): """Convert MTP layer parameters from Megatron to HuggingFace format. @@ -161,9 +163,9 @@ def convert_qwen3_next_to_hf(args, name, param): elif rest == "pre_mlp_layernorm.weight": return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] elif rest == "mlp.router.weight": - return [(f"model.layers.{layer_idx}.mlp.gate.weight", param)] + return [(f"model.layers.{layer_idx}.mlp.gate.weight", to_model_dtype(args, param))] elif rest == "mlp.router.expert_bias": - return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)] + return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", to_model_dtype(args, param))] # qk norm elif rest == "self_attention.q_layernorm.weight": diff --git a/slime/backends/megatron_utils/megatron_to_hf/qwen3moe.py b/slime/backends/megatron_utils/megatron_to_hf/qwen3moe.py index 9f5b5b81a6..474929ab30 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/qwen3moe.py +++ b/slime/backends/megatron_utils/megatron_to_hf/qwen3moe.py @@ -2,6 +2,8 @@ import torch +from .dtype_utils import to_model_dtype + def convert_qwen3moe_to_hf(args, name, param): if name == "module.module.embedding.word_embeddings.weight": @@ -104,9 +106,9 @@ def convert_qwen3moe_to_hf(args, name, param): elif rest == "pre_mlp_layernorm.weight": return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] elif rest == "mlp.router.weight": - return [(f"model.layers.{layer_idx}.mlp.gate.weight", param)] + return [(f"model.layers.{layer_idx}.mlp.gate.weight", to_model_dtype(args, param))] elif rest == "mlp.router.expert_bias": - return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)] + return [(f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", to_model_dtype(args, param))] # qk norm elif rest == "self_attention.q_layernorm.weight": diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index db6020a94d..4c18963f9c 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -33,7 +33,7 @@ from .checkpoint import load_checkpoint, save_checkpoint from .cp_utils import reduce_train_step_metrics from .data import DataIterator, get_batch -from .loss import loss_function +from .loss import ROLLOUT_TOP_P_TOKEN_KEYS, get_rollout_top_p_logprob_kwargs, loss_function from .model_provider import get_model_provider_func logger = logging.getLogger(__name__) @@ -73,6 +73,12 @@ def wrapped_forward_step(*args, **kwargs): return wrapped_forward_step +def _with_rollout_top_p_token_keys(args: Namespace, keys: Sequence[str]) -> list[str]: + if args.rollout_top_p == 1.0: + return list(keys) + return [*keys, *ROLLOUT_TOP_P_TOKEN_KEYS] + + def _iter_critic_output_layers(model: Sequence[DDP]): for chunk_id, module in enumerate(unwrap_model(model)): output_layer = getattr(module, "output_layer", None) @@ -265,6 +271,7 @@ def forward_only( data_iterator: Sequence[DataIterator], num_microbatches: Sequence[int], store_prefix: str = "", + use_rollout_top_p_replay: bool = False, ) -> dict[str, list[torch.Tensor]]: """Run forward passes only and collect non-loss outputs (e.g., logprobs). @@ -284,6 +291,8 @@ def forward_only( data_iterator (Sequence[DataIterator]): Iterable(s) yielding batches for inference. num_microbatches (Sequence[int]): Number of microbatches per rollout step. store_prefix (str): Prefix to prepend to stored output keys. + use_rollout_top_p_replay (bool): Whether to pass rollout top-p token sets + to the post-forward log-prob callback when top-p rollout is enabled. Returns: dict[str, list[torch.Tensor]]: Aggregated outputs keyed by ``store_prefix + key``. @@ -294,6 +303,15 @@ def forward_only( iterator.reset() config = get_model_config(model[0]) + batch_keys = [ + "tokens", + "loss_masks", + "multimodal_train_inputs", + "total_lengths", + "response_lengths", + ] + if use_rollout_top_p_replay: + batch_keys = _with_rollout_top_p_token_keys(args, batch_keys) def forward_step( data_iterator: DataIterator, model: GPTModel, return_schedule_plan: bool = False @@ -315,16 +333,8 @@ def forward_step( # Get the batch. batch = get_batch( data_iterator, - [ - "tokens", - "loss_masks", - "multimodal_train_inputs", - "total_lengths", - "response_lengths", - "max_seq_lens", - ], + batch_keys, args.data_pad_size_multiplier, - args.qkv_format, args.allgather_cp, ) unconcat_tokens = batch["unconcat_tokens"] @@ -344,15 +354,17 @@ def forward_step( forward_kwargs.update(batch["multimodal_train_inputs"]) output_tensor = model(**forward_kwargs) - return output_tensor, partial( - f, - args=args, - unconcat_tokens=unconcat_tokens, - total_lengths=total_lengths, - response_lengths=response_lengths, - with_entropy=args.use_rollout_entropy, - max_seq_lens=batch.get("max_seq_lens", None), - ) + output_kwargs = { + "args": args, + "unconcat_tokens": unconcat_tokens, + "total_lengths": total_lengths, + "response_lengths": response_lengths, + "with_entropy": args.use_rollout_entropy, + } + if use_rollout_top_p_replay: + output_kwargs.update(get_rollout_top_p_logprob_kwargs(args, batch)) + + return output_tensor, partial(f, **output_kwargs) # Turn on evaluation mode which disables dropout. for model_module in model: @@ -486,25 +498,26 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p # Get the batch. batch = get_batch( data_iterator, - [ - "tokens", - "multimodal_train_inputs", - "packed_seq_params", - "total_lengths", - "response_lengths", - "loss_masks", - "log_probs", - "ref_log_probs", - "values", - "advantages", - "returns", - "rollout_log_probs", - "max_seq_lens", - "teacher_log_probs", - "rollout_mask_sums", - ], + _with_rollout_top_p_token_keys( + args, + [ + "tokens", + "multimodal_train_inputs", + "packed_seq_params", + "total_lengths", + "response_lengths", + "loss_masks", + "log_probs", + "ref_log_probs", + "values", + "advantages", + "returns", + "rollout_log_probs", + "teacher_log_probs", + "rollout_mask_sums", + ], + ), args.data_pad_size_multiplier, - args.qkv_format, args.allgather_cp, ) diff --git a/slime/backends/megatron_utils/sglang.py b/slime/backends/megatron_utils/sglang.py index 801217310d..97c82a31cd 100644 --- a/slime/backends/megatron_utils/sglang.py +++ b/slime/backends/megatron_utils/sglang.py @@ -13,15 +13,6 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions -try: - from sglang.srt.managers.io_struct import DeltaEncoding, DeltaParam, DeltaSpec -except ImportError: - # Older sglang images don't have delta-sync io_struct. Only --update-weight-mode=delta - # needs these; the default full-sync path runs without them. - DeltaEncoding = None - DeltaParam = None - DeltaSpec = None - from sglang.srt.utils import MultiprocessingSerializer @@ -37,7 +28,4 @@ "monkey_patch_torch_reductions", "MultiprocessingSerializer", "FlattenedTensorBucket", - "DeltaEncoding", - "DeltaParam", - "DeltaSpec", ] diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_disk.py b/slime/backends/megatron_utils/update_weight/update_weight_from_disk.py index bb0e0df72a..a5f81d9263 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_disk.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_disk.py @@ -46,6 +46,7 @@ def connect_rollout_engines( rollout_engine_lock: ActorHandle, engine_gpu_counts: Sequence[int] | None = None, engine_gpu_offsets: Sequence[int] | None = None, + all_engine_actors: Sequence[ActorHandle] | None = None, ) -> None: self.rollout_engines = rollout_engines self.rollout_engine_lock = rollout_engine_lock diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py new file mode 100644 index 0000000000..acb929ce5c --- /dev/null +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import json +import logging +import os +import queue +import shutil +from argparse import Namespace +from collections import deque +from collections.abc import Callable, Mapping, Sequence +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import ray +import safetensors.numpy +import torch +import torch.distributed as dist +import zstandard +from megatron.core import mpu +from ray.actor import ActorHandle + +from slime.utils.disk_delta import NUM_WORKERS, checksum, make_tensor_reader, overwrite_encode +from slime.utils.distributed_utils import get_gloo_group + +from .update_weight_from_distributed import UpdateWeightFromDistributed + +logger = logging.getLogger(__name__) + + +class UpdateWeightFromDiskDelta(UpdateWeightFromDistributed): + """ + Delta weight sync over a shared filesystem. PP-src ranks diff each gathered HF tensor against + a CPU snapshot of the previous sync and publish the changes as a canonical HF checkpoint dir; + every rollout host applies the delta into its local checkpoint and reloads via the ordinary + update_weights_from_disk path, so sglang needs no delta support. + """ + + def __init__( + self, + args: Namespace, + model: Sequence[torch.nn.Module], + weights_getter: Callable[[], Mapping[str, torch.Tensor]], + *, + model_name: str, + quantization_config: dict[str, int | str | list[str]] | None, + ) -> None: + super().__init__(args, model, weights_getter, model_name=model_name, quantization_config=quantization_config) + self.delta_dir = args.update_weight_disk_dir + os.makedirs(self.delta_dir, exist_ok=True) + self.delta_encoding = args.update_weight_delta_encoding + self.checksum_algorithm = args.update_weight_delta_checksum + self._snapshot: dict[str, np.ndarray] = {} + self._baseline_captured = False + # Opaque HTTP rollout: no engine handles, so publish the version to disk and let the fleet + # pull it, instead of pushing via per-engine RPCs. + self._publish_only = bool(getattr(args, "rollout_endpoint_url", None)) + self._commit_hook: Callable | None = None + if args.custom_delta_pre_push_path: + from slime.utils.misc import load_function + + self._commit_hook = load_function(args.custom_delta_pre_push_path) + + def connect_rollout_engines( + self, + rollout_engines: Sequence[ActorHandle], + rollout_engine_lock: ActorHandle, + engine_gpu_counts: Sequence[int] | None = None, + engine_gpu_offsets: Sequence[int] | None = None, + all_engine_actors: Sequence[ActorHandle] | None = None, + ) -> None: + # The local checkpoint is host-local, so every host applies its own copy: + # all_engine_actors is one actor per host, vs rollout_engines (node 0 only). The + # rollout_engine_lock the NCCL path uses isn't needed — a per-host flock serializes applies. + self.rollout_engines = rollout_engines + self.all_engine_actors = list(all_engine_actors or rollout_engines) + self._is_pp_src_rank = ( + mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0 + ) + + def disconnect_rollout_engines(self) -> None: + pass # no NCCL groups to tear down + + @torch.no_grad() + def update_weights(self) -> None: + # The first call only captures the baseline snapshot the next sync diffs against. + if not self._baseline_captured: + self._capture_baseline() + self._baseline_captured = True + return + + self.weight_version += 1 + if dist.get_rank() == 0 and not self._publish_only: + ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) + ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + + self._publish() + if self._publish_only: + self._announce_version() + else: + self._reload_engines() + self._record_metrics() + + def _capture_baseline(self) -> None: + """Capture the baseline snapshot the first delta diffs against (no publish), and clear any + stale stream from a prior run. Seeds from hf_checkpoint — what each host materializes its + base from — so the invariant ``snapshot == engine base`` holds even where the megatron->HF + round-trip trims vocab-padding rows (embed/lm_head). A tensor absent there (rare) falls back + to the gathered value.""" + # a prior run's versions would apply against the wrong base; start the dir clean + if dist.get_rank() == 0: + shutil.rmtree(self.delta_dir, ignore_errors=True) + os.makedirs(self.delta_dir, exist_ok=True) + if self._commit_hook is not None: + self._commit_hook(self.args, self.delta_dir, list(self.rollout_engines)) + dist.barrier(group=get_gloo_group()) + + read_hf = make_tensor_reader(self.args.hf_checkpoint) # index the HF headers once + for name, tensor in self._iter_hf_tensors(): + try: + self._snapshot[name] = read_hf(name) + except KeyError: + self._snapshot[name] = tensor.detach().cpu().contiguous().view(torch.uint8).numpy().reshape(-1) + logger.warning("seed: %s absent from hf_checkpoint; seeding from current weights", name) + if dist.get_rank() == 0: + logger.info( + "[disk delta] captured baseline snapshot of %d tensors from %s", + len(self._snapshot), + self.args.hf_checkpoint, + ) + + def _publish(self) -> None: + """Encode this version's changed tensors (PP-src ranks), then write it as a canonical HF dir.""" + self._encode_delta() + dist.barrier(group=get_gloo_group()) + self._write_delta_files() + + def _write_delta_files(self) -> None: + """Write this rank's changed tensors as one canonical model-NNNNN.safetensors, and on rank + 0 the HF index. The sequential file numbers and the index are coordinated over gloo (small + object gathers), not the filesystem — a shared volume may not surface one rank's writes to + another until commit.""" + group = get_gloo_group() + world, rank = dist.get_world_size(), dist.get_rank() + + # number the files sequentially across only the ranks that have one (no gaps) + counts: list = [None] * world + dist.all_gather_object(counts, int(bool(self._delta)), group=group) + offset, total = sum(counts[:rank]), sum(counts) + + fname = None + self.wire_bytes = 0 + if self._delta: + fname = f"model-{offset:05d}-of-{total:05d}.safetensors" + blob = safetensors.numpy.save(self._delta, metadata=self._checksums) + self.wire_bytes = len(blob) + _atomic_write(os.path.join(self._version_dir, fname), blob) + + maps: list = [None] * world + dist.all_gather_object(maps, {name: fname for name in self._delta}, group=group) + if rank == 0: + index = { + "metadata": { + "version": f"{self.weight_version:06d}", + "base_version": f"{self.weight_version - 1:06d}", + "delta_encoding": self.delta_encoding, + "compression_format": "zstd", + "checksum_format": self.checksum_algorithm, + }, + "weight_map": {name: f for m in maps for name, f in m.items()}, + } + _atomic_write(os.path.join(self._version_dir, "model.safetensors.index.json"), json.dumps(index).encode()) + dist.barrier(group=group) + + def _reload_engines(self) -> None: + """Commit the published files, have each host apply the delta, then reload the engines.""" + if self._commit_hook is not None: + self._commit_hook(self.args, self._version_dir, list(self.rollout_engines)) + dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + ray.get([actor.sync_local_checkpoint.remote(self.weight_version) for actor in self.all_engine_actors]) + ray.get( + [ + engine.update_weights_from_disk.remote( + model_path=self.args.update_weight_local_checkpoint_dir, + weight_version=str(self.weight_version), + ) + for engine in self.rollout_engines + ] + ) + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + + def _announce_version(self) -> None: + """Publish-only: commit the version dir and advance the latest-version pointer, so the + external fleet pulls and applies it on its own. No engine handles, hence no reload RPCs.""" + if self._commit_hook is not None: + self._commit_hook(self.args, self._version_dir, []) # opaque fleet: no engine handles + dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + _atomic_write(os.path.join(self.delta_dir, "latest"), f"{self.weight_version:06d}".encode()) + dist.barrier(group=get_gloo_group()) + + def _iter_hf_tensors(self): + """Yield (name, gathered HF tensor) for every param: base-class TP then EP gather passes.""" + for chunk_iter in (self._iter_non_expert_chunks(), self._iter_expert_chunks()): + for hf_chunk in chunk_iter: + yield from hf_chunk + dist.barrier(group=get_gloo_group()) + + def _encode_delta(self) -> None: + """Diff each gathered HF tensor against the snapshot, keeping the changed ones (compressed) + in self._delta with their checksums. The GPU->CPU gather is pipelined into a compute pool: + the main loop copies one tensor to a pinned buffer and submits it; pool workers diff and + compress in parallel (each is a few big GIL-releasing numpy/zstd calls).""" + self._version_dir = os.path.join(self.delta_dir, f"weight_v{self.weight_version:06d}") + if self._is_pp_src_rank: + os.makedirs(self._version_dir, exist_ok=True) + snapshot = self._snapshot + self._delta: dict[str, np.ndarray] = {} # changed tensor name -> compressed diff + self._checksums: dict[str, str] = {} # changed tensor name -> new-state checksum + self.changed_bytes = self.total_bytes = 0 + + # Pinned host-buffer pool: a pinned non_blocking GPU->CPU copy is far faster than .cpu(). + max_bytes = max((int(v.nbytes) for v in snapshot.values()), default=0) + free_q: queue.Queue = queue.Queue() + use_pinned = True + try: + for _ in range(max(4, min(2 * NUM_WORKERS, (32 << 30) // max(max_bytes, 1)))): + free_q.put(torch.empty(max_bytes, dtype=torch.uint8, pin_memory=True)) + except RuntimeError as e: # low memlock limit + logger.warning("pinned host buffers unavailable (%s); using pageable .cpu()", e) + use_pinned = False + + def diff_and_compress(name, buf, nbytes, pinned): + if pinned: # copy out and free the pinned buffer before the heavy diff/compress + new = np.empty(nbytes, dtype=np.uint8) + np.copyto(new, buf.numpy()[:nbytes]) + free_q.put(buf) + else: + new = buf + old = snapshot[name] + if self.delta_encoding == "xor": + diff = new ^ old + changed = int(np.count_nonzero(diff)) + elif self.delta_encoding == "overwrite": + mask = new != old + changed = int(np.count_nonzero(mask)) + diff = overwrite_encode(new, mask) + else: + raise ValueError(f"unknown delta encoding {self.delta_encoding!r}") + if not changed: + return name, new, None, None, 0 + compressed = np.frombuffer(zstandard.ZstdCompressor(level=1).compress(diff), dtype=np.uint8) + return name, new, compressed, checksum(self.checksum_algorithm, new), changed + + def collect(fut): + name, new, compressed, digest, changed = fut.result() + snapshot[name] = new # becomes the next sync's base + if changed: + self.changed_bytes += changed + self._delta[name] = compressed + self._checksums[name] = digest + + pool = ThreadPoolExecutor(max_workers=NUM_WORKERS) + inflight: deque = deque() + try: + for name, tensor in self._iter_hf_tensors(): + flat = tensor.detach().contiguous().view(torch.uint8).reshape(-1) + nbytes = int(flat.numel()) + if use_pinned and nbytes <= max_bytes: + buf = free_q.get() # blocks when all buffers are in flight -> backpressures the gather + buf[:nbytes].copy_(flat, non_blocking=True) + torch.cuda.current_stream().synchronize() + payload, pinned = buf, True + else: + payload, pinned = flat.cpu().numpy(), False + self.total_bytes += nbytes + inflight.append(pool.submit(diff_and_compress, name, payload, nbytes, pinned)) + if len(inflight) >= 2 * NUM_WORKERS: + collect(inflight.popleft()) + while inflight: + collect(inflight.popleft()) + finally: + pool.shutdown() + + def _record_metrics(self) -> None: + """All-reduce the byte counts and record changed-fraction / wire size; the actor drains + update_weight_metrics onto the step log.""" + counts = torch.tensor( + [self.changed_bytes, self.total_bytes, self.wire_bytes], + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + dist.all_reduce(counts) + changed, total, wire = counts.tolist() + m = self.update_weight_metrics + m["perf/update_weights_density"] = changed / max(total, 1) + m["perf/update_weights_wire_bytes"] = wire + if dist.get_rank() == 0: + logger.info( + "[disk delta v=%s] density=%.2f%% wire=%.2f GB", + self.weight_version, + 100.0 * changed / max(total, 1), + wire / 1e9, + ) + + +def _atomic_write(path: str, data: bytes) -> None: + tmp = path + ".tmp" + with open(tmp, "wb") as f: + f.write(data) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 1ab48fb974..14698c4309 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -16,7 +16,6 @@ from slime.utils.distributed_utils import get_gloo_group, init_process_group from ..megatron_to_hf import convert_to_hf -from ..sglang import DeltaSpec from .common import all_gather_param, named_params_and_buffers @@ -60,6 +59,7 @@ def connect_rollout_engines( rollout_engine_lock: ActorHandle, engine_gpu_counts: Sequence[int] | None = None, engine_gpu_offsets: Sequence[int] | None = None, + all_engine_actors: Sequence[ActorHandle] | None = None, ) -> None: """ Create NCCL "slime-pp_{pp_rank}" if PP source (DP=TP=0). Lock prevents concurrent broadcasts. @@ -174,18 +174,12 @@ def _iter_non_expert_chunks(self) -> Iterator[list[tuple[str, torch.Tensor]]]: if buffer: yield buffer - def _iter_expert_chunks( - self, - params: Iterator[tuple[str, torch.Tensor]] | None = None, - ) -> Iterator[list[tuple[str, torch.Tensor]]]: + def _iter_expert_chunks(self) -> Iterator[list[tuple[str, torch.Tensor]]]: """ Yield one HF chunk per EP-weighted batch of expert params: TP gather + - buffer until threshold, then EP gather + HF convert. ``params`` lets - callers restrict the iter to a subset (used by delta-sync sub-passes); - defaults to all expert params on this rank. + buffer until threshold, then EP gather + HF convert. """ - if params is None: - params = ((n, p) for n, p in named_params_and_buffers(self.args, self.model) if ".experts." in n) + params = ((n, p) for n, p in named_params_and_buffers(self.args, self.model) if ".experts." in n) buffer_size = 0 batch: list[tuple[str, torch.Tensor]] = [] for name, param in params: @@ -247,12 +241,9 @@ def _update_bucket_weights_from_distributed( converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None, load_format: str | None = None, - delta: DeltaSpec | None = None, ) -> None: """ Lock → broadcast → clear → unlock → pbar++. Lock prevents NCCL deadlock. - Delta sync passes ``load_format="delta"`` + a ``DeltaSpec`` describing the - per-param decoding of the (__positions__, __values__) bucket tensors. """ # lock the rollout engines to prevent dead lock on broadcast. while not ray.get(self.rollout_engine_lock.acquire.remote()): @@ -265,7 +256,6 @@ def _update_bucket_weights_from_distributed( self.rollout_engines, converted_named_tensors, load_format=load_format, - delta=delta, ) ray.get(refs) @@ -339,11 +329,9 @@ def update_weights_from_distributed( rollout_engines: Sequence[ActorHandle], converted_named_tensors: Sequence[tuple[str, torch.Tensor]], load_format: str | None = None, - delta: DeltaSpec | None = None, ) -> list[ObjectRef]: """ Send metadata (Ray), broadcast tensors (NCCL rank 0 → engines). - Delta sync passes ``load_format="delta"`` + ``delta`` (DeltaSpec). """ refs = [ engine.update_weights_from_distributed.remote( @@ -353,7 +341,6 @@ def update_weights_from_distributed( group_name=group_name, weight_version=str(weight_version), load_format=load_format, - delta=delta, ) for engine in rollout_engines ] diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py deleted file mode 100644 index fbe24bbc1c..0000000000 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py +++ /dev/null @@ -1,864 +0,0 @@ -""" -Delta weight sync. - -For each sync, the sender bytewise-diffs the current weights against a -pinned-CPU snapshot of the last broadcast, packs the changed positions -and values, and ships them via one of two transports: - - - "nccl": each bucket flush goes out via NCCL broadcast (low-latency, - high-bandwidth, intra-datacenter). - - "disk": each bucket flush is written to a versioned shared-FS directory - as one safetensors file; one HTTP push per sync wakes the rollout - engines to read+apply (cross-datacenter, bandwidth-limited). - -Both transports share one wire layout (``__positions__`` uint8 byte blob + -``__values__`` param-dtype tensor + per-param decoding manifest) and one -receiver-side decoder. Three encodings differ only in how positions are -packed: - - indices : int32 absolute positions - deltas : uint16 gap-deltas (uint32 fallback per param) - deltas_zstd : ``deltas`` with the safetensors blob wrapped in zstd L1 - -The receiver overwrites changed positions with the trainer's exact bytes -(no arithmetic), so the apply is lossless and there is no drift to fight -with periodic re-syncs. The first ``update_weights`` call seeds the -snapshot without contacting the rollout engines — they're assumed to have -loaded the same HF checkpoint at init. -""" - -import itertools -import json -import logging -import os -import shutil -import threading -from argparse import Namespace -from collections.abc import Callable, Iterator, Mapping, Sequence -from concurrent.futures import ThreadPoolExecutor -from dataclasses import asdict, dataclass, field, replace -from queue import Queue - -import numpy as np -import ray -import torch -import torch.distributed as dist -from megatron.core import mpu -from ray.actor import ActorHandle -from safetensors.torch import save as st_save_bytes -from tqdm import tqdm - -from slime.utils.distributed_utils import get_gloo_group -from slime.utils.timer import Timer, timer - -from ..sglang import DeltaEncoding, DeltaParam, DeltaSpec -from .update_weight_from_distributed import UpdateWeightFromDistributed - -logger = logging.getLogger(__name__) - - -# ---------- compute + encode ----------------------------------------------- - - -@dataclass -class ParamDiff: - """ - One per-param compute output. ``values`` is a reference to the full-shape - current tensor (no copy); ``mask`` is a same-shape bool marking the - positions whose bytes differ from the snapshot. - """ - - name: str - values: torch.Tensor - mask: torch.Tensor - - -@dataclass -class EncodedChunk: - """ - One HF chunk after position+value encoding, before bucket merging. - - ``pos_bytes`` and ``val_tensor`` are the chunk-local concatenations across - all params; per-param byte/element offsets live on ``params``. - """ - - pos_bytes: bytes - val_tensor: torch.Tensor - params: list[DeltaParam] - nnz: int - - @classmethod - def empty(cls) -> "EncodedChunk": - return cls(pos_bytes=b"", val_tensor=torch.empty(0, dtype=torch.bfloat16), params=[], nnz=0) - - -def _checksum(positions: torch.Tensor, values: torch.Tensor) -> int: - """ - Wire-corruption check via ``torch.hash_tensor`` (XOR-reduce over uint64 bitcast). - Sender computes pre-flush, receiver computes post-recv; mismatch indicates - corruption between encode and apply. One reduction + one ``.item()`` sync per arg. - """ - p = int(torch.hash_tensor(positions).item()) if positions.numel() else 0 - v = int(torch.hash_tensor(values).item()) if values.numel() else 0 - return p ^ (v << 1) - - -def _bytewise_diff_mask(current: torch.Tensor, snapshot: torch.Tensor) -> torch.Tensor: - """ - Per-element bool mask: True where current and snapshot bytes differ. Dtype-agnostic via view-as-integer. - """ - es = current.element_size() - int_dtype = {1: torch.uint8, 2: torch.int16, 4: torch.int32, 8: torch.int64}.get(es) - if int_dtype is None: - raise ValueError(f"unsupported element size {es}") - return current.view(int_dtype) != snapshot.view(int_dtype) - - -def _sparse_boundaries( - diffs: list[ParamDiff], -) -> tuple[torch.Tensor, list[int], torch.Tensor, list[int]]: - """ - One concat → one nonzero → one searchsorted → one ``tolist()``: collapses - per-param host syncs to one per chunk. Returns ``(big_val, bounds, big_idx, cum)``. - """ - device = diffs[0].values.device - sizes = [d.values.numel() for d in diffs] - cum = list(itertools.accumulate(sizes)) - cum_t = torch.tensor(cum, dtype=torch.int64, device=device) - - big_values = torch.cat([d.values.contiguous().view(-1) for d in diffs], dim=0) - big_mask = torch.cat([d.mask.contiguous().view(-1) for d in diffs], dim=0) - big_idx = big_mask.nonzero(as_tuple=False).view(-1) - big_val = big_values[big_idx] - bounds = torch.searchsorted(big_idx, cum_t).tolist() - return big_val, bounds, big_idx, cum - - -def encode_indices(diffs: list[ParamDiff]) -> EncodedChunk: - """ - int32 absolute positions, per-param. Position blob is uint8 bytes; pos_width=4 for all params. - """ - if not diffs: - return EncodedChunk.empty() - big_val, bounds, big_idx, cum = _sparse_boundaries(diffs) - pos_pieces: list[torch.Tensor] = [] - val_pieces: list[torch.Tensor] = [] - params: list[DeltaParam] = [] - pos_byte_off = val_off = 0 - prev_b = 0 - prev_param_start = 0 - for i, d in enumerate(diffs): - b = bounds[i] - nnz = b - prev_b - if nnz > 0: - local_idx = (big_idx[prev_b:b] - prev_param_start).to(torch.int32) - pos_pieces.append(local_idx) - val_pieces.append(big_val[prev_b:b]) - params.append( - DeltaParam( - name=d.name, - dtype=str(d.values.dtype).replace("torch.", ""), - shape=list(d.values.shape), - pos_start=pos_byte_off, - pos_end=pos_byte_off + nnz * 4, - pos_width=4, - val_start=val_off, - val_end=val_off + nnz, - ) - ) - pos_byte_off += nnz * 4 - val_off += nnz - prev_b = b - prev_param_start = cum[i] - if not params: - return EncodedChunk.empty() - positions = torch.cat(pos_pieces, dim=0) - values = torch.cat(val_pieces, dim=0) - return EncodedChunk( - pos_bytes=positions.cpu().numpy().tobytes(), - val_tensor=values, - params=params, - nnz=val_off, - ) - - -def encode_deltas(diffs: list[ParamDiff]) -> EncodedChunk: - """ - Gap-encode sorted positions: store ``idx[k] - idx[k-1] - 1`` with idx[-1] := -1 - so the first delta equals the first index. Per-param downcast to uint16 if the max - gap fits, otherwise uint32. At ~2% Bernoulli density on bf16 weights, max gap ≈ 300 - — uint16 fits; the fallback covers pathological inputs without correctness risk. - Receiver inverts: ``idx = cumsum(delta + 1) - 1``. - """ - if not diffs: - return EncodedChunk.empty() - big_val, bounds, big_idx, cum = _sparse_boundaries(diffs) - - kept: list[tuple[ParamDiff, int]] = [] # (diff, nnz) for non-empty params - per_param_deltas: list[torch.Tensor] = [] - val_pieces: list[torch.Tensor] = [] - prev_b = 0 - prev_param_start = 0 - for i, d in enumerate(diffs): - b = bounds[i] - nnz = b - prev_b - if nnz > 0: - local_idx = big_idx[prev_b:b] - prev_param_start # int64, sorted - prev = torch.cat( - [ - torch.tensor([-1], dtype=local_idx.dtype, device=local_idx.device), - local_idx[:-1], - ] - ) - per_param_deltas.append(local_idx - prev - 1) - val_pieces.append(big_val[prev_b:b]) - kept.append((d, nnz)) - prev_b = b - prev_param_start = cum[i] - - if not kept: - return EncodedChunk.empty() - - # One CPU sync for per-param width selection. - max_per_param = torch.stack([d.max() for d in per_param_deltas]).cpu().tolist() - pos_byte_pieces: list[bytes] = [] - pos_byte_off = val_off = 0 - params: list[DeltaParam] = [] - for (d, nnz), deltas, max_d in zip(kept, per_param_deltas, max_per_param, strict=True): - width = 2 if int(max_d) <= 65535 else 4 - np_dtype = np.uint16 if width == 2 else np.uint32 - b_chunk = deltas.cpu().numpy().astype(np_dtype, copy=False).tobytes() - pos_byte_pieces.append(b_chunk) - params.append( - DeltaParam( - name=d.name, - dtype=str(d.values.dtype).replace("torch.", ""), - shape=list(d.values.shape), - pos_start=pos_byte_off, - pos_end=pos_byte_off + len(b_chunk), - pos_width=width, - val_start=val_off, - val_end=val_off + nnz, - ) - ) - pos_byte_off += len(b_chunk) - val_off += nnz - - values = torch.cat(val_pieces, dim=0) - return EncodedChunk( - pos_bytes=b"".join(pos_byte_pieces), - val_tensor=values, - params=params, - nnz=val_off, - ) - - -# ---------- snapshot state ------------------------------------------------- - - -class DeltaState: - """ - Pinned-CPU snapshot of every HF tensor we've broadcast, plus the H2D/D2H - side streams that pipeline next-chunk snapshot transfer behind the current - chunk's compute. - """ - - def __init__(self) -> None: - self.snapshot: dict[str, torch.Tensor] = {} - self.d2h_stream: torch.cuda.Stream | None = None - self.h2d_stream: torch.cuda.Stream | None = None - self.snapshot_dirty = False - - def prefetch_snapshot( - self, named_tensors: list[tuple[str, torch.Tensor]] - ) -> tuple[list[torch.Tensor], torch.cuda.Event]: - """ - Start an async H2D copy of the snapshot tensors for ``named_tensors`` on a side stream. - """ - if self.h2d_stream is None: - self.h2d_stream = torch.cuda.Stream() - prev_gpu: list[torch.Tensor] = [] - with torch.cuda.stream(self.h2d_stream): - for name, tensor in named_tensors: - if name not in self.snapshot: - raise KeyError(f"missing snapshot for {name!r}; first update_weights call seeds the snapshot") - prev_gpu.append(self.snapshot[name].to(device=tensor.device, non_blocking=True)) - event = self.h2d_stream.record_event() - return prev_gpu, event - - def compute_diffs( - self, - named_tensors: list[tuple[str, torch.Tensor]], - prefetched: tuple[list[torch.Tensor], torch.cuda.Event], - ) -> list[ParamDiff]: - """ - Wait for the prefetched H2D copy, then per-param bytewise diff against the snapshot. - """ - prev_gpu, event = prefetched - event.wait() - return [ - ParamDiff(name=name, values=current, mask=_bytewise_diff_mask(current, prev)) - for (name, current), prev in zip(named_tensors, prev_gpu, strict=True) - ] - - def update_snapshot_async(self, named_tensors: list[tuple[str, torch.Tensor]]) -> None: - """ - Enqueue a D2H copy of ``named_tensors`` into the pinned-CPU snapshot on a - side stream. Non-blocking; call ``flush_snapshot`` before the next sync. - """ - if self.d2h_stream is None: - self.d2h_stream = torch.cuda.Stream() - event = torch.cuda.current_stream().record_event() - with torch.cuda.stream(self.d2h_stream): - self.d2h_stream.wait_event(event) - for name, tensor in named_tensors: - if name not in self.snapshot: - self.snapshot[name] = torch.empty_like(tensor, device=torch.device("cpu"), pin_memory=True) - self.snapshot[name].copy_(tensor.detach(), non_blocking=True) - self.snapshot_dirty = True - - def flush_snapshot(self) -> None: - """ - Block until all enqueued D2H snapshot copies have landed. - """ - if self.snapshot_dirty: - if self.d2h_stream is not None: - self.d2h_stream.synchronize() - else: - torch.cuda.synchronize() - self.snapshot_dirty = False - - -# ---------- bucket --------------------------------------------------------- - - -@dataclass -class DeltaBucket: - """ - Accumulates encoded chunks for one flush. Per-param offsets are rebased - into the bucket's growing position blob + value tensor on ``add``. - """ - - pos_pieces: list[bytes] = field(default_factory=list) - val_pieces: list[torch.Tensor] = field(default_factory=list) - params: list[DeltaParam] = field(default_factory=list) - pos_total: int = 0 - val_total: int = 0 - byte_size: int = 0 - - @property - def has_updates(self) -> bool: - return bool(self.pos_pieces) - - def should_flush_before_add(self, chunk: EncodedChunk, byte_limit: int) -> bool: - """True iff adding ``chunk`` would push the bucket past ``byte_limit``.""" - chunk_bytes = len(chunk.pos_bytes) + chunk.val_tensor.numel() * chunk.val_tensor.element_size() - return self.has_updates and self.byte_size + chunk_bytes > byte_limit - - def add(self, chunk: EncodedChunk) -> None: - """Append ``chunk``, rebasing each param's byte/element offsets into the bucket.""" - for p in chunk.params: - self.params.append( - replace( - p, - pos_start=p.pos_start + self.pos_total, - pos_end=p.pos_end + self.pos_total, - val_start=p.val_start + self.val_total, - val_end=p.val_end + self.val_total, - ) - ) - self.pos_pieces.append(chunk.pos_bytes) - self.val_pieces.append(chunk.val_tensor) - self.pos_total += len(chunk.pos_bytes) - self.val_total += chunk.val_tensor.numel() - self.byte_size += len(chunk.pos_bytes) + chunk.val_tensor.numel() * chunk.val_tensor.element_size() - - def merged_positions_cpu(self) -> torch.Tensor: - """One CPU uint8 tensor with the bucket's positions blob.""" - merged = b"".join(self.pos_pieces) - if not merged: - return torch.empty(0, dtype=torch.uint8) - return torch.from_numpy(np.frombuffer(merged, dtype=np.uint8).copy()) - - def merged_values(self) -> torch.Tensor: - """One GPU tensor with the bucket's values, concatenated across chunks.""" - if not self.val_pieces: - return torch.empty(0, dtype=torch.bfloat16) - return torch.cat(self.val_pieces, dim=0) - - def clear(self) -> None: - """Reset to empty so the bucket can be reused for the next flush.""" - self.pos_pieces.clear() - self.val_pieces.clear() - self.params.clear() - self.pos_total = 0 - self.val_total = 0 - self.byte_size = 0 - - -# ---------- async safetensors writer (disk transport only) ----------------- - - -class AsyncSafetensorsWriter: - """ - Background thread that drains a queue of file writes. Producers do GPU→CPU - on the default stream and enqueue; the writer does the slow disk I/O - (and optional zstd compress) off the critical path. End-of-sync ``drain()`` - blocks until all enqueued writes have landed. - """ - - def __init__(self, compress_with_zstd: bool, zstd_level: int = 1) -> None: - self._queue: Queue = Queue() - self._error: BaseException | None = None - self._compress_with_zstd = compress_with_zstd - self._zstd_level = zstd_level - if compress_with_zstd: - # Lazy import — non-disk users don't pay the dep. - import zstandard - - self._zstd = zstandard - self._lock = threading.Lock() - self.bytes_pre_compress = 0 - self.bytes_post_compress = 0 - self._thread = threading.Thread(target=self._run, name="delta-disk-writer", daemon=True) - self._thread.start() - - def enqueue( - self, - path: str, - tensors: dict[str, torch.Tensor], - metadata: dict[str, str], - ) -> None: - """Hand a (path, tensors, metadata) tuple to the writer thread.""" - if self._error is not None: - raise RuntimeError(f"writer thread already failed: {self._error!r}") - self._queue.put((path, tensors, metadata)) - - def drain(self) -> None: - """Block until every queued write has landed; re-raise any writer-thread error.""" - self._queue.join() - if self._error is not None: - raise RuntimeError(f"writer thread failed: {self._error!r}") from self._error - - def reset_counters(self) -> None: - """Zero the byte counters at the start of a sync.""" - with self._lock: - self.bytes_pre_compress = 0 - self.bytes_post_compress = 0 - - def _run(self) -> None: - """Writer-thread loop: safetensors-encode → (optional zstd) → atomic replace.""" - cctx = self._zstd.ZstdCompressor(level=self._zstd_level, threads=-1) if self._compress_with_zstd else None - while True: - path, tensors, metadata = self._queue.get() - try: - if self._error is None: - blob = st_save_bytes(tensors, metadata=metadata) - pre = len(blob) - if cctx is not None: - blob = cctx.compress(blob) - post = len(blob) - tmp = path + ".tmp" - with open(tmp, "wb") as f: - f.write(blob) - f.flush() - os.fsync(f.fileno()) - os.replace(tmp, path) - with self._lock: - self.bytes_pre_compress += pre - self.bytes_post_compress += post - except BaseException as e: # noqa: BLE001 - self._error = e - finally: - self._queue.task_done() - - -# ---------- main class ----------------------------------------------------- - - -class UpdateWeightFromDistributedDelta(UpdateWeightFromDistributed): - """ - Selective delta sync. ``--update-weight-transport`` picks the per-flush carrier: - "nccl" broadcasts each bucket; "disk" writes each bucket as a safetensors file under - ``--update-weight-disk-dir`` and pushes once at end-of-sync. - """ - - def __init__( - self, - args: Namespace, - model: Sequence[torch.nn.Module], - weights_getter: Callable[[], Mapping[str, torch.Tensor]], - *, - model_name: str, - quantization_config: dict[str, int | str | list[str]] | None, - ) -> None: - super().__init__( - args, - model, - weights_getter, - model_name=model_name, - quantization_config=quantization_config, - ) - self.transport = args.update_weight_transport - self.encoding = DeltaEncoding(args.update_weight_encoding) - self.delta_state = DeltaState() - self._snapshot_seeded = False - # DELTAS_ZSTD shares the gap encoder; zstd is applied at file-write time. - self._encode = encode_indices if self.encoding is DeltaEncoding.INDICES else encode_deltas - - self.writer: AsyncSafetensorsWriter | None = None - self.delta_dir: str | None = None - self._pre_push_hook: Callable | None = None - # Disk transport: each pass boundary publishes its accumulated files - # (the only globally-synced flush points, since ``_publish_batch`` - # contains collectives). ``_pre_push_hook`` may return a Future, in - # which case the receiver RPC is deferred behind it via - # ``_rpc_executor`` so the main encode thread continues immediately. - # ``_pending_publishes`` holds the resulting Future[list[ObjectRef]] - # on rank 0; ``_finalize_sync`` awaits them at end of sync. - self._pending_files: list[str] = [] - self._pending_publishes: list = [] - self._published_any: bool = False - self._rpc_executor: ThreadPoolExecutor | None = None - if self.transport == "disk": - self.delta_dir = args.update_weight_disk_dir - os.makedirs(self.delta_dir, exist_ok=True) - self.writer = AsyncSafetensorsWriter( - compress_with_zstd=(self.encoding == DeltaEncoding.DELTAS_ZSTD), - ) - self._rpc_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="delta-publish-rpc") - if getattr(args, "custom_delta_pre_push_path", None): - from slime.utils.misc import load_function - - self._pre_push_hook = load_function(args.custom_delta_pre_push_path) - - def connect_rollout_engines( - self, - rollout_engines: Sequence[ActorHandle], - rollout_engine_lock: ActorHandle, - engine_gpu_counts: Sequence[int] | None = None, - engine_gpu_offsets: Sequence[int] | None = None, - ) -> None: - """ - NCCL transport: delegate to parent (group creation). Disk transport: just - record the engines + PP-src flag (no NCCL group needed). - """ - if self.transport == "nccl": - super().connect_rollout_engines( - rollout_engines, - rollout_engine_lock, - engine_gpu_counts=engine_gpu_counts, - engine_gpu_offsets=engine_gpu_offsets, - ) - return - self.rollout_engines = rollout_engines - self.rollout_engine_lock = rollout_engine_lock - self._engine_gpu_counts = engine_gpu_counts - self._is_pp_src_rank = ( - mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0 - ) - pp_rank = mpu.get_pipeline_model_parallel_rank() - self._group_name = f"slime-pp_{pp_rank}" - - def disconnect_rollout_engines(self) -> None: - if self.transport == "nccl": - super().disconnect_rollout_engines() - - @torch.no_grad() - def update_weights(self) -> None: - """ - First call: seed the CPU snapshot from current model state, no engine RPCs. - Subsequent calls: pause → diff/encode → finalize → resume. ``delta_encode`` - covers the sender's per-param TP/EP gather + diff + sparse encode + per-publish - commit/RPC handoff; ``delta_finalize`` covers the tail wait for the last - batch's receiver-apply. Their sum is the sync latency the user observes. - """ - if not self._snapshot_seeded: - self._seed_snapshot() - self._snapshot_seeded = True - return - - self.weight_version += 1 - if self.transport == "disk": - self._version_dir = os.path.join(self.delta_dir, f"weight_v{self.weight_version:06d}") - if self._is_pp_src_rank: - os.makedirs(self._version_dir, exist_ok=True) - - if dist.get_rank() == 0: - ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) - ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) - - self.density_nnz = self.density_numel = self.wire_bytes = self._flush_idx = 0 - self._pending_files.clear() - self._pending_publishes.clear() - self._published_any = False - if self.writer is not None: - self.writer.reset_counters() - pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None - - with timer("delta_encode"): - self._send_weights(pbar) - if self.writer is not None: - self.writer.drain() - self.delta_state.flush_snapshot() - dist.barrier(group=get_gloo_group()) - - with timer("delta_finalize"): - self._finalize_sync() - - self._record_metrics() - - def _seed_snapshot(self) -> None: - """ - Populate the snapshot from current model state (TP/EP gather + HF - convert on PP-src ranks, D2H pinned copy). Cost is one full pass over - params — ~50s blocking on 355B at init. - """ - for chunk_iter in (self._iter_non_expert_chunks(), self._iter_expert_chunks()): - for hf_chunk in chunk_iter: - if hf_chunk: - self.delta_state.update_snapshot_async(hf_chunk) - dist.barrier(group=get_gloo_group()) - self.delta_state.flush_snapshot() - - def _send_weights(self, pbar: tqdm | None) -> None: - """ - Non-expert pass then expert pass, each followed by a barrier + (disk-only) - publish. The expert pass is split into ``_EXPERT_SUBPASSES`` sub-passes so - receiver apply for an earlier batch overlaps with later expert encoding, - instead of bottlenecking at end-of-sync. Megatron splits MoE layers - uniformly across PP ranks, so a per-rank slice of the expert param list - keeps the publish count identical on every rank (no barrier desync). - """ - from .common import named_params_and_buffers - - bucket = DeltaBucket() - self._pipeline_pass(self._iter_non_expert_chunks(), bucket, pbar) - self._flush_and_publish(bucket, pbar) - - expert_params = [(n, p) for n, p in named_params_and_buffers(self.args, self.model) if ".experts." in n] - n = len(expert_params) - for i in range(self._EXPERT_SUBPASSES): - lo = i * n // self._EXPERT_SUBPASSES - hi = (i + 1) * n // self._EXPERT_SUBPASSES - self._pipeline_pass(self._iter_expert_chunks(iter(expert_params[lo:hi])), bucket, pbar) - self._flush_and_publish(bucket, pbar) - - _EXPERT_SUBPASSES = 4 - - def _flush_and_publish(self, bucket: DeltaBucket, pbar: tqdm | None) -> None: - """ - End-of-sub-pass: drain the in-flight bucket, barrier all PP ranks, then - (disk-only) fire one publish RPC for everything since the last call. - """ - if bucket.has_updates: - self._flush_bucket(bucket, pbar) - dist.barrier(group=get_gloo_group()) - if self.transport == "disk": - self._publish_batch() - - def _pipeline_pass( - self, - chunk_iter: Iterator[list[tuple[str, torch.Tensor]]], - bucket: DeltaBucket, - pbar: tqdm | None, - ) -> None: - """ - 1-step H2D snapshot prefetch lookahead: chunk N+1's snapshot transfer - overlaps chunk N's compute+encode on the default stream. - """ - pending_chunk: list[tuple[str, torch.Tensor]] | None = None - pending_prefetch: tuple[list[torch.Tensor], torch.cuda.Event] | None = None - for hf_chunk in chunk_iter: - if not hf_chunk: - continue - next_prefetch = self.delta_state.prefetch_snapshot(hf_chunk) - if pending_prefetch is not None: - self._enqueue_chunk(pending_chunk, pending_prefetch, bucket, pbar) - pending_chunk, pending_prefetch = hf_chunk, next_prefetch - if pending_prefetch is not None: - self._enqueue_chunk(pending_chunk, pending_prefetch, bucket, pbar) - - def _enqueue_chunk( - self, - hf_chunk: list[tuple[str, torch.Tensor]], - prefetched: tuple[list[torch.Tensor], torch.cuda.Event], - bucket: DeltaBucket, - pbar: tqdm | None, - ) -> None: - """ - compute diffs → snapshot new prev → encode → bucket.add (flushing if full). - """ - diffs = self.delta_state.compute_diffs(hf_chunk, prefetched=prefetched) - self.delta_state.update_snapshot_async(hf_chunk) - chunk = self._encode(diffs) - self.density_numel += sum(d.values.numel() for d in diffs) - self.density_nnz += chunk.nnz - self.wire_bytes += len(chunk.pos_bytes) + chunk.val_tensor.numel() * chunk.val_tensor.element_size() - if not chunk.params: - return - if bucket.should_flush_before_add(chunk, self.args.update_weight_buffer_size): - self._flush_bucket(bucket, pbar) - bucket.add(chunk) - - def _flush_bucket(self, bucket: DeltaBucket, pbar: tqdm | None) -> None: - """ - NCCL: broadcast (__positions__, __values__) with a DeltaSpec. - Disk: enqueue one safetensors file with the same payload + metadata. - Both paths embed a checksum the receiver verifies before apply. - """ - if not bucket.has_updates: - return - positions_cpu = bucket.merged_positions_cpu() - values_gpu = bucket.merged_values() - params = list(bucket.params) - bucket.clear() - - # GPU-resident checksum: positions go to the device the values already live on - # (NCCL needs the same move anyway; disk gets it for free at the reduction). - positions_gpu = positions_cpu.to(values_gpu.device, non_blocking=True) - checksum = _checksum(positions_gpu, values_gpu) - - if self.transport == "nccl": - spec = DeltaSpec(encoding=self.encoding, params=params, checksum=checksum) - self._update_bucket_weights_from_distributed( - [("__positions__", positions_gpu), ("__values__", values_gpu)], - pbar=pbar, - load_format="delta", - delta=spec, - ) - else: # disk - tensors = {"__positions__": positions_cpu, "__values__": values_gpu.cpu()} - metadata = { - "encoding": self.encoding.value, - "params": json.dumps([asdict(p) for p in params]), - "current_version": str(self.weight_version), - "checksum": str(checksum), - } - filename = f"rank{dist.get_rank():04d}_flush{self._flush_idx:06d}.safetensors" - path = os.path.join(self._version_dir, filename) - self.writer.enqueue(path, tensors, metadata) - self._pending_files.append(filename) - if pbar is not None: - pbar.update(1) - self._flush_idx += 1 - - def _publish_batch(self) -> None: - """ - Drain pending fsyncs, invoke the pre-push hook (may return a Future for an - async durability step on shared FS), then defer rank 0's - ``update_weights_from_disk`` RPC behind that Future via ``_rpc_executor``. - Each deferred dispatch lands in ``_pending_publishes`` as a - Future[list[ObjectRef]]; ``_finalize_sync`` awaits both layers. Safe to call - with empty ``_pending_files``: the all_gather still synchronizes and rank 0 - skips the dispatch when no rank produced files. - """ - self.writer.drain() - dist.barrier(group=get_gloo_group()) - - commit_future = None - if self._pre_push_hook is not None: - commit_future = self._pre_push_hook(self.args, self._version_dir, list(self.rollout_engines)) - dist.barrier(group=get_gloo_group()) - - # Collect every rank's batch filenames at rank 0; payload is ~KB, gather is cheap. - all_files: list[list[str]] = [None] * dist.get_world_size() # type: ignore[list-item] - dist.all_gather_object(all_files, list(self._pending_files), group=get_gloo_group()) - flat = [f for sub in all_files for f in sub] - self._pending_files.clear() - - if dist.get_rank() == 0 and flat: - version_dir = self._version_dir - engines = list(self.rollout_engines) - weight_version = str(self.weight_version) - self._published_any = True - - def _fire_when_committed() -> list: - if commit_future is not None: - commit_future.result() - return [ - engine.update_weights_from_disk.remote( - model_path=version_dir, - files=flat, - load_format="delta", - weight_version=weight_version, - ) - for engine in engines - ] - - self._pending_publishes.append(self._rpc_executor.submit(_fire_when_committed)) - - def _finalize_sync(self) -> None: - """ - Per-transport end-of-sync. NCCL: each flush already broadcasted; just resume. - Disk: publish the trailing files, wait for all streamed applies to land, then - cleanup + resume. - """ - if self.transport == "nccl": - if dist.get_rank() == 0: - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) - return - - if self._pending_files: - self._publish_batch() - if dist.get_rank() == 0: - # Each entry is a Future returning a list of ObjectRefs. Awaiting the - # Futures unblocks the (commit-then-RPC) chain; ray.get waits for the - # receivers' apply to finish. - object_refs = [ref for fut in self._pending_publishes for ref in fut.result()] - ray.get(object_refs) - self._pending_publishes.clear() - if not self._published_any: - # No delta files needed publishing this sync (e.g. all-zero diff). - # Engines never saw the new version via update_weights_from_disk, so - # bump it explicitly to keep their recorded version in sync with ours. - weight_version = str(self.weight_version) - ray.get([engine.set_weight_version.remote(weight_version) for engine in self.rollout_engines]) - if not self.args.update_weight_delta_keep_files: - shutil.rmtree(self._version_dir, ignore_errors=True) - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) - - def _record_metrics(self) -> None: - """ - Allreduce density/byte counters across PP-src ranks; stash on - ``update_weight_metrics`` for the actor to drain into the next step log. - Wall-clock timings come from the slime ``Timer`` (``delta_encode`` / - ``delta_finalize`` blocks above + the outer ``update_weights`` decorator). - """ - pre_bytes = self.writer.bytes_pre_compress if self.writer is not None else 0 - post_bytes = self.writer.bytes_post_compress if self.writer is not None else 0 - counts = torch.tensor( - [self.density_nnz, self.density_numel, self.wire_bytes, pre_bytes, post_bytes], - dtype=torch.int64, - device=torch.cuda.current_device(), - ) - dist.all_reduce(counts) - nnz, numel, wire_bytes, pre_bytes, post_bytes = counts.tolist() - - density = nnz / max(numel, 1) - compression_ratio = (pre_bytes / post_bytes) if post_bytes > 0 else 1.0 - - m = self.update_weight_metrics - m["perf/update_weights_density"] = density - m["perf/update_weights_wire_bytes"] = wire_bytes - m["perf/update_weights_flushes_per_rank"] = float(self._flush_idx) - if self.transport == "disk": - m["perf/update_weights_disk_bytes_pre_compress"] = pre_bytes - m["perf/update_weights_disk_bytes_post_compress"] = post_bytes - m["perf/update_weights_compression_ratio"] = compression_ratio - - if dist.get_rank() == 0: - t = Timer().log_dict() - logger.info( - "[delta sync v=%s] transport=%s enc=%s density=%.3f%% " "encode=%.2fs finalize=%.2fs flushes/rank=%d", - self.weight_version, - self.transport, - self.encoding.value, - 100.0 * density, - t.get("delta_encode", 0.0), - t.get("delta_finalize", 0.0), - self._flush_idx, - ) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 724e05355b..723c0c4b9a 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -65,6 +65,7 @@ def connect_rollout_engines( rollout_engine_lock: ActorHandle, engine_gpu_counts: Sequence[int] | None = None, engine_gpu_offsets: Sequence[int] | None = None, + all_engine_actors: Sequence[ActorHandle] | None = None, ) -> None: """ Split colocated/distributed engines. Global source rank (DP=TP=PP=0) creates NCCL diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index d4d7867539..dc7c109cb2 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -230,3 +230,24 @@ def start_external_rollout_servers(args, *, start_router) -> tuple[dict[str, Ext ) } return servers, init_handles + + +def normalize_rollout_endpoint_url(url: str) -> str: + """Normalize an opaque HTTP rollout endpoint base URL (drop trailing slash).""" + url = url.rstrip("/") + parsed = urlparse(url) + if parsed.scheme not in ("http", "https") or parsed.netloc == "": + raise ValueError(f"Invalid --rollout-endpoint-url {url!r}. Use an absolute http:// or https:// URL.") + return url + + +def uses_rollout_endpoint(args) -> bool: + return bool(getattr(args, "rollout_endpoint_url", None)) + + +def rollout_endpoint_servers(args) -> tuple[dict[str, ExternalRolloutServer], list]: + """Rollout served by an opaque HTTP endpoint behind one URL. The fleet is elastic, so slime holds + no per-engine handles — hence no engines (weights are published to disk, not pushed) and generation + routes to the URL via get_model_url.""" + logger.info("Rollout served by opaque HTTP endpoint: %s", args.rollout_endpoint_url) + return {"default": ExternalRolloutServer(engines=[], engine_gpu_counts=[], engine_gpu_offsets=[])}, [] diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 15c4dd7231..f30f18e03e 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -3,6 +3,7 @@ import logging import multiprocessing import os +import threading import time from urllib.parse import quote @@ -167,6 +168,19 @@ def _format_v6_uri(addr): else: self._init_normal(server_args_dict) + # Warm the host-local base off the actor's main thread: sglang serves the first rollout from + # its init-loaded weights, so the materialize (a full base copy) only has to finish before + # the first delta reload. init_local_checkpoint is idempotent and flock-guarded, so the first + # sync_local_checkpoint either finds it done or blocks on the same lock — no join needed. + if self.args.update_weight_mode == "delta" and self.args.update_weight_transport == "disk": + from slime.utils.disk_delta import init_local_checkpoint + + threading.Thread( + target=init_local_checkpoint, + args=(self.args.update_weight_local_checkpoint_dir, self.args.hf_checkpoint), + daemon=True, + ).start() + def _init_external(self, expect_server_args, external_engine_need_check_fields): logger.info(f"Use external SGLang engine (rank={self.rank}, expect_server_args={expect_server_args})") @@ -379,6 +393,25 @@ def resume_memory_occupation(self, tags: list[str] = None): def check_weights(self, action: str): return self._make_request("weights_checker", {"action": action}) + def sync_local_checkpoint(self, target_version: int): + """Apply the published deltas into this host's local checkpoint up to target_version; the + engine reloads it afterwards. Assumes this actor shares the checkpoint filesystem with the + sglang it drives (true for slime-launched engines).""" + from slime.utils.disk_delta import apply_deltas, init_local_checkpoint + + init_local_checkpoint(self.args.update_weight_local_checkpoint_dir, self.args.hf_checkpoint) # idempotent + # non-POSIX filesystems lack cross-host read-after-write consistency, so the trainer's + # just-written delta isn't visible on this mount until the hook refreshes it. + if self.args.custom_delta_pre_read_path: + from slime.utils.misc import load_function + + load_function(self.args.custom_delta_pre_read_path)(self.args.update_weight_disk_dir, target_version) + apply_deltas( + self.args.update_weight_local_checkpoint_dir, + self.args.update_weight_disk_dir, + target_version, + ) + def update_weights_from_disk( self, model_path: str, @@ -437,7 +470,6 @@ def update_weights_from_distributed( flush_cache=False, weight_version: str | None = None, load_format: str | None = None, - delta=None, ): payload = { "names": names, @@ -450,19 +482,6 @@ def update_weights_from_distributed( payload["weight_version"] = weight_version if load_format is not None: payload["load_format"] = load_format - if delta is not None: - # DeltaSpec → JSON string. Receiver reconstructs via DeltaEncoding(...) + - # DeltaParam(**p); avoids depending on FastAPI's nested-dataclass coercion. - import json - from dataclasses import asdict - - payload["delta"] = json.dumps( - { - "encoding": delta.encoding.value, - "params": [asdict(p) for p in delta.params], - "checksum": delta.checksum, - } - ) return self._make_request( "update_weights_from_distributed", payload, diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index c181c8e7f0..4d7f313ed2 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -103,7 +103,7 @@ def _get_placement_group_layout(args) -> tuple[int, int]: if args.debug_train_only: return actor_num_gpus, 0 - if args.rollout_external: + if args.rollout_external or getattr(args, "rollout_endpoint_url", None): if args.debug_rollout_only: return 0, 0 return actor_num_gpus, actor_num_gpus diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 5766d6b171..cb4f52b7f5 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -14,7 +14,11 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS -from slime.backends.sglang_utils.external import start_external_rollout_servers +from slime.backends.sglang_utils.external import ( + rollout_endpoint_servers, + start_external_rollout_servers, + uses_rollout_endpoint, +) from slime.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig from slime.backends.sglang_utils.sglang_engine import SGLangEngine from slime.rollout.base_types import call_rollout_fn @@ -40,6 +44,8 @@ "tokens": torch.long, "loss_masks": torch.int, "rollout_log_probs": torch.float32, + "rollout_top_p_token_ids": torch.int32, + "rollout_top_p_token_offsets": torch.int32, "teacher_log_probs": torch.float32, "rollout_routed_experts": None, } @@ -535,7 +541,8 @@ def get_updatable_engines_and_lock(self): gpu_counts = srv.engine_gpu_counts if srv else [] gpu_offsets = srv.engine_gpu_offsets if srv else [] num_new = srv.num_new_engines if srv else 0 - return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets + all_engine_actors = srv.all_engines if srv else [] + return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets, all_engine_actors def get_num_rollout_per_epoch(self): assert self.args.rollout_global_dataset @@ -791,6 +798,21 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl if samples[0].rollout_log_probs is not None: train_data["rollout_log_probs"] = [sample.rollout_log_probs for sample in samples] + if samples[0].rollout_top_p_token_ids is not None: + for sample in samples: + assert sample.rollout_top_p_token_ids is not None + assert sample.rollout_top_p_token_offsets is not None + assert len(sample.rollout_top_p_token_offsets) == sample.response_length + 1, ( + f"top-p token offsets length {len(sample.rollout_top_p_token_offsets)} " + f"!= response length + 1 {sample.response_length + 1}" + ) + assert sample.rollout_top_p_token_offsets[-1] == len(sample.rollout_top_p_token_ids), ( + f"top-p token offsets[-1] {sample.rollout_top_p_token_offsets[-1]} " + f"!= token ids length {len(sample.rollout_top_p_token_ids)}" + ) + train_data["rollout_top_p_token_ids"] = [sample.rollout_top_p_token_ids for sample in samples] + train_data["rollout_top_p_token_offsets"] = [sample.rollout_top_p_token_offsets for sample in samples] + if samples[0].rollout_routed_experts is not None: train_data["rollout_routed_experts"] = [sample.rollout_routed_experts for sample in samples] @@ -849,6 +871,8 @@ def _split_train_data_by_dp(self, data): "rollout_ids", "rollout_mask_sums", "rollout_log_probs", + "rollout_top_p_token_ids", + "rollout_top_p_token_offsets", "rollout_routed_experts", "prompt", "teacher_log_probs", @@ -1080,6 +1104,9 @@ def start_rollout_servers(args, pg) -> tuple[dict[str, Any], list[Any]]: Note: ``init_http_client`` should be called separately before this, as the HTTP client is shared across all servers. """ + if uses_rollout_endpoint(args): + return rollout_endpoint_servers(args) + if args.rollout_external: return start_external_rollout_servers(args, start_router=_start_router) @@ -1295,6 +1322,7 @@ def compute_metrics_from_samples(args, samples): log_dict |= _compute_spec_metrics(args, samples) log_dict |= _compute_prefix_cache_metrics(args, samples) log_dict |= _compute_reward_cat_metrics(args, samples) + log_dict |= _compute_top_p_kept_vocab_metrics(args, samples) log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item() log_dict["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item() return log_dict @@ -1403,6 +1431,20 @@ def _is_zero_std(samples: list[Sample]): return {f"zero_std/count_{reward}": len(items) for reward, items in group_by(interesting_rewards).items()} +def _compute_top_p_kept_vocab_metrics(args, all_samples: list[Sample]): + total_kept = 0 + total_tokens = 0 + for sample in all_samples: + offsets = sample.rollout_top_p_token_offsets + if not offsets or sample.response_length == 0: + continue + total_kept += offsets[-1] - offsets[0] + total_tokens += sample.response_length + if total_tokens == 0: + return {} + return {"top_p_kept_vocab_per_token": total_kept / total_tokens} + + def _compute_spec_metrics(args, all_samples: list[Sample]): if getattr(args, "sglang_speculative_algorithm", None) is None: return {} diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index bb87360639..daf7e1b1f6 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -14,6 +14,7 @@ from packaging.version import parse from tqdm import tqdm +from slime.backends.sglang_utils.external import uses_rollout_endpoint from slime.backends.sglang_utils.server_control import abort_servers_until_idle from slime.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput from slime.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter @@ -38,6 +39,8 @@ logger = logging.getLogger(__name__) _PROCESSOR_PROMPT_KEYS = {"input_ids", "attention_mask"} +_TOP_P_TOKEN_ID_META_KEYS = ("top_p_token_ids", "top_p_kept_token_ids") +_TOP_P_TOKEN_OFFSET_META_KEYS = ("top_p_token_offsets", "top_p_kept_token_offsets") def _prepare_prompt_ids(sample: Sample, tokenizer, processor: Any) -> list[int]: @@ -62,6 +65,78 @@ def _prepare_prompt_ids(sample: Sample, tokenizer, processor: Any) -> list[int]: return tokenizer.encode(sample.prompt, add_special_tokens=False) +def _decode_int32_meta_array(meta_info: dict[str, Any], keys: tuple[str, ...]) -> list[int] | None: + for key in keys: + if key in meta_info: + value = meta_info[key] + break + else: + return None + + if value is None: + return None + if isinstance(value, str): + value = pybase64.b64decode(value.encode("ascii")) + if isinstance(value, bytes): + return np.frombuffer(value, dtype=np.int32).tolist() + if isinstance(value, np.ndarray): + return value.astype(np.int32, copy=False).tolist() + return [int(x) for x in value] + + +def _extract_rollout_top_p_token_data( + meta_info: dict[str, Any], + *, + expected_num_tokens: int | None = None, +) -> tuple[list[int], list[int]] | None: + token_ids = _decode_int32_meta_array(meta_info, _TOP_P_TOKEN_ID_META_KEYS) + offsets = _decode_int32_meta_array(meta_info, _TOP_P_TOKEN_OFFSET_META_KEYS) + if token_ids is None and offsets is None: + return None + if token_ids is None or offsets is None: + raise ValueError("SGLang top-p token replay must include both token ids and offsets.") + if not offsets or offsets[0] != 0: + raise ValueError(f"SGLang top-p token offsets must start with 0, got {offsets[:1]}.") + if offsets[-1] != len(token_ids): + raise ValueError( + f"SGLang top-p token ids/offsets mismatch: offsets[-1]={offsets[-1]}, len(token_ids)={len(token_ids)}." + ) + if expected_num_tokens is not None and len(offsets) != expected_num_tokens + 1: + raise ValueError( + "SGLang top-p token offsets length must equal generated token count + 1: " + f"len(offsets)={len(offsets)}, generated={expected_num_tokens}." + ) + return token_ids, offsets + + +def _merge_rollout_top_p_token_data( + base_token_ids: list[int] | None, + base_offsets: list[int] | None, + token_ids: list[int], + offsets: list[int], +) -> tuple[list[int], list[int]]: + base_token_ids = list(base_token_ids or []) + base_offsets = list(base_offsets or [0]) + base_offset = base_offsets[-1] + return base_token_ids + token_ids, base_offsets + [base_offset + offset for offset in offsets[1:]] + + +def _append_rollout_top_p_token_data( + sample: Sample, + meta_info: dict[str, Any], + *, + expected_num_tokens: int | None = None, +) -> None: + top_p_data = _extract_rollout_top_p_token_data(meta_info, expected_num_tokens=expected_num_tokens) + if top_p_data is None: + return + sample.rollout_top_p_token_ids, sample.rollout_top_p_token_offsets = _merge_rollout_top_p_token_data( + sample.rollout_top_p_token_ids, + sample.rollout_top_p_token_offsets, + *top_p_data, + ) + + def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") -> str: """Return the router URL for a named model. @@ -72,8 +147,11 @@ def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") resp = await post(url, json=payload) Falls back to the default router if *model_name* is not found or - ``sglang_model_routers`` is not set. + ``sglang_model_routers`` is not set. With ``--rollout-endpoint-url`` set, returns that opaque + endpoint with *endpoint* appended (no router APIs are assumed to exist). """ + if uses_rollout_endpoint(args): + return f"{args.rollout_endpoint_url}{endpoint}" routers = getattr(args, "sglang_model_routers", None) if routers and model_name in routers: ip, port = routers[model_name] @@ -81,6 +159,37 @@ def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") return f"http://{args.sglang_router_ip}:{args.sglang_router_port}{endpoint}" +async def apply_rollout_request_hook( + args: Namespace, + url: str, + payload: dict[str, Any], + *, + headers: dict | None, + sample: Sample, +) -> dict[str, Any]: + """Run ``custom_rollout_request_hook_path`` on one outgoing /generate request. + + The hook receives ``request = {"url", "payload", "headers", "max_retries", "retry_sleep"}`` along + with ``args`` and ``sample`` (which carries its own context, e.g. ``sample.index``) — everything + about how this one request is sent, and nothing about the rollout itself. It mutates ``request`` + in place and returns None, or returns a dict of updates; this returns the resulting request. + Callers invoke this only when a hook is set, so the default path keeps calling ``post`` directly. + """ + request = {"url": url, "payload": payload, "headers": headers, "max_retries": 60, "retry_sleep": 1.0} + hook = load_function(args.custom_rollout_request_hook_path) + result = hook(args, sample, request) + if inspect.isawaitable(result): + result = await result + if result is not None: + if not isinstance(result, dict): + raise TypeError( + f"{args.custom_rollout_request_hook_path} must return None or a dict of request updates, " + f"got {type(result).__name__}" + ) + request.update(result) + return request + + class GenerateState(metaclass=SingletonMeta): """ The global state for the generation process. @@ -104,6 +213,8 @@ def __init__(self, args: Namespace) -> None: no_stop_trim=True, spaces_between_special_tokens=False, ) + if args.rollout_top_p != 1.0: + self.sampling_params["custom_params"] = {"return_top_p_token_ids": True} if getattr(args, "sglang_enable_deterministic_inference", False): sampling_seed_base = args.rollout_seed @@ -154,7 +265,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A assert isinstance(sample.prompt, str) state = GenerateState(args) - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + url = get_model_url(args, "default", "/generate") assert ( sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED @@ -197,7 +308,17 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A headers = {"X-SMG-Routing-Key": sample.session_id} with trace_span(sample, "sglang_generate", attrs={"max_new_tokens": sampling_params["max_new_tokens"]}) as span: - output = await post(url, payload, headers=headers) + if getattr(args, "custom_rollout_request_hook_path", None): + request = await apply_rollout_request_hook(args, url, payload, headers=headers, sample=sample) + output = await post( + request["url"], + request["payload"], + headers=request["headers"], + max_retries=request["max_retries"], + retry_sleep=request["retry_sleep"], + ) + else: + output = await post(url, payload, headers=headers) span.update(build_sglang_meta_trace_attrs(output["meta_info"])) if "output_token_logprobs" in output["meta_info"]: @@ -219,6 +340,11 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A if sample.rollout_log_probs is None: sample.rollout_log_probs = [] sample.rollout_log_probs += new_response_log_probs + _append_rollout_top_p_token_data( + sample, + output["meta_info"], + expected_num_tokens=len(new_response_tokens), + ) if "routed_experts" in output["meta_info"]: sample.rollout_routed_experts = np.frombuffer( @@ -355,14 +481,26 @@ async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: assert not state.aborted state.aborted = True - if parse(sglang_router.__version__) <= parse("0.2.1"): - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") - urls = response["urls"] - else: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") - urls = [worker["url"] for worker in response["workers"]] - - await abort_servers_until_idle(urls) + if uses_rollout_endpoint(args) and not args.partial_rollout: + # Opaque endpoint, surplus discarded: cancel locally — the client disconnect aborts the + # request on the fleet. No worker API to call, and nothing to collect. + for task in state.pendings: + task.cancel() + await asyncio.gather(*state.pendings, return_exceptions=True) + state.pendings = set() + return aborted_samples + + if not uses_rollout_endpoint(args): + # Router: explicitly abort in-flight requests on each worker. + if parse(sglang_router.__version__) <= parse("0.2.1"): + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") + urls = response["urls"] + else: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") + urls = [worker["url"] for worker in response["workers"]] + await abort_servers_until_idle(urls) + # Opaque endpoint + partial-rollout: the streaming tasks self-break on state.aborted and return + # their partials below; closing each stream disconnects, which aborts the request on the fleet. # make sure all the pending tasks are finished count = 0 diff --git a/slime/rollout/sglang_streaming_rollout.py b/slime/rollout/sglang_streaming_rollout.py index e9d380dc58..01e1dcd0ba 100644 --- a/slime/rollout/sglang_streaming_rollout.py +++ b/slime/rollout/sglang_streaming_rollout.py @@ -32,7 +32,14 @@ import numpy as np import pybase64 -from slime.rollout.sglang_rollout import GenerateState, _prepare_prompt_ids +from slime.rollout.sglang_rollout import ( + GenerateState, + _extract_rollout_top_p_token_data, + _merge_rollout_top_p_token_data, + _prepare_prompt_ids, + apply_rollout_request_hook, + get_model_url, +) from slime.utils import http_utils from slime.utils.processing_utils import encode_image_for_rollout_engine from slime.utils.trace_utils import build_sglang_meta_trace_attrs, trace_span @@ -53,7 +60,7 @@ async def generate_streaming(args: Namespace, sample: Sample, sampling_params: d assert isinstance(sample.prompt, str) state = GenerateState(args) - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + url = get_model_url(args, "default", "/generate") assert sample.status in ( Sample.Status.PENDING, @@ -91,6 +98,14 @@ async def generate_streaming(args: Namespace, sample: Sample, sampling_params: d if sample.session_id and getattr(args, "router_policy", None) == "consistent_hashing": headers = {"X-SMG-Routing-Key": sample.session_id} + # Let a user hook mutate the request (e.g. custom headers or a weight_version gate) before the + # stream opens. Only when one is configured — the default path opens the stream unchanged. A + # stream is one connection, not a retry loop, so any max_retries/retry_sleep the hook sets are + # ignored here; only url/payload/headers apply. + if getattr(args, "custom_rollout_request_hook_path", None): + request = await apply_rollout_request_hook(args, url, payload, headers=headers, sample=sample) + url, payload, headers = request["url"], request["payload"], request["headers"] + # Snapshot pre-call sample state. sglang's SSE chunks are cumulative # *within this call*; on each chunk we rebuild the post-call view of the # sample = prior state + chunk delta. That way a mid-stream break leaves @@ -99,6 +114,8 @@ async def generate_streaming(args: Namespace, sample: Sample, sampling_params: d base_response = sample.response or "" base_response_length = sample.response_length base_log_probs = list(sample.rollout_log_probs or []) + base_top_p_token_ids = list(sample.rollout_top_p_token_ids or []) + base_top_p_token_offsets = list(sample.rollout_top_p_token_offsets or [0]) base_loss_mask = list(sample.loss_mask) if sample.loss_mask is not None else None last_meta_info: dict[str, Any] = {} @@ -141,6 +158,15 @@ async def generate_streaming(args: Namespace, sample: Sample, sampling_params: d sample.response = base_response + call_text sample.response_length = base_response_length + len(call_tokens) sample.rollout_log_probs = base_log_probs + call_log_probs + top_p_data = _extract_rollout_top_p_token_data(meta, expected_num_tokens=len(call_tokens)) + if top_p_data is not None: + sample.rollout_top_p_token_ids, sample.rollout_top_p_token_offsets = ( + _merge_rollout_top_p_token_data( + base_top_p_token_ids, + base_top_p_token_offsets, + *top_p_data, + ) + ) if base_loss_mask is not None: assert args.partial_rollout and args.mask_offpolicy_in_partial_rollout sample.loss_mask = base_loss_mask + [1] * len(call_tokens) @@ -166,5 +192,9 @@ async def generate_streaming(args: Namespace, sample: Sample, sampling_params: d sample.update_from_meta_info(args, last_meta_info) elif state.aborted: sample.status = Sample.Status.ABORTED + # Record the version of the partial's tokens (every streaming chunk carries it) so off-policy + # correction can weight it — update_from_meta_info is skipped without a finish_reason. + if "weight_version" in last_meta_info: + sample.weight_versions.append(last_meta_info["weight_version"]) return sample diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 6efe85eae7..7f846b98ba 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -3,14 +3,13 @@ import json import logging import os -import warnings from typing import Any import yaml from slime.backends.sglang_utils.arguments import sglang_parse_args from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args -from slime.backends.sglang_utils.external import apply_external_engine_info_to_args +from slime.backends.sglang_utils.external import apply_external_engine_info_to_args, normalize_rollout_endpoint_url from slime.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from slime.utils.logging_utils import configure_logger @@ -107,13 +106,6 @@ def add_cluster_arguments(parser): def add_train_arguments(parser): # --train-backend is parsed early in _pre_parse_mode() and merged later. - parser.add_argument( - "--qkv-format", - type=str, - choices=["thd", "bshd"], - default="thd", - help="The qkv layout for Megatron backend.", - ) parser.add_argument( "--qwen-gdn-backend", type=str, @@ -146,8 +138,8 @@ def add_train_arguments(parser): default="full", help=( "Weight sync strategy. 'full' (default) broadcasts every parameter " - "every sync. 'delta' detects byte-level changes against a pinned-CPU " - "snapshot of the previous broadcast and ships only the changed positions + values." + "every sync. 'delta' diffs each sync against a pinned-CPU snapshot of the " + "previous one and ships only the changed bytes (disk transport only)." ), ) parser.add_argument( @@ -157,9 +149,8 @@ def add_train_arguments(parser): help=( "Carrier for weight sync. In full mode, 'nccl' broadcasts chunks and " "'disk' writes a complete HF checkpoint under --update-weight-disk-dir " - "before engines reload it. In delta mode, 'nccl' broadcasts sparse deltas; " - "'disk' writes sparse safetensors under --update-weight-disk-dir and pushes " - "once at end-of-sync." + "before engines reload it. Delta mode is 'disk' only: each host applies the " + "published deltas into its local checkpoint and reloads via update_weights_from_disk." ), ) parser.add_argument( @@ -169,7 +160,7 @@ def add_train_arguments(parser): help=( "Filesystem directory for disk-backed weight sync. In --update-weight-mode=full, " "one complete HF checkpoint directory is written per sync. In delta mode, " - "one sparse-delta directory is written per sync." + "one delta directory (changed tensors only) is written per sync." ), ) parser.add_argument( @@ -182,41 +173,64 @@ def add_train_arguments(parser): ), ) parser.add_argument( - "--update-weight-encoding", - choices=["indices", "deltas", "deltas_zstd"], - default="indices", + "--update-weight-delta-encoding", + choices=["xor", "overwrite"], + default="xor", help=( - "Position encoding for partial flushes. 'indices': int32 absolute " - "positions (largest, lowest compute). 'deltas': uint16 gap-deltas " - "with uint32 fallback (smaller). 'deltas_zstd': 'deltas' with the " - "safetensors blob wrapped in zstd L1 (smallest, heaviest compute — " - "best for shared-FS bandwidth ≤ ~300 MB/s)." + "On-disk delta encoding for --update-weight-mode=delta --update-weight-transport=disk. " + "'xor' (default): new ^ old — smallest wire and fastest, but an involution that must be " + "applied exactly once against the correct base (applying it twice reverts). 'overwrite': " + "changed positions + new absolute values — larger, but idempotent (re-applicable any " + "number of times). Both are byte-level and dtype-blind; the engine reads the choice from " + "each version's index metadata." ), ) parser.add_argument( - "--update-weight-delta-dir", + "--update-weight-delta-checksum", + choices=["xxh3-128", "blake3", "adler32"], + default="xxh3-128", + help=( + "Per-tensor integrity checksum for disk delta apply. The checksum is not the " + "apply bottleneck (the apply is decompress + XOR bound), so this is a digest-" + "property choice, not a speed one. 'xxh3-128' (default): widest fast non-" + "cryptographic digest, negligible accidental-corruption collisions. 'blake3': " + "cryptographic digest, for untrusted storage. 'adler32': 32-bit, for interop " + "with systems that expect it. The engine reads the choice from each version's " + "index metadata." + ), + ) + parser.add_argument( + "--custom-delta-pre-push-path", type=str, default=None, help=( - "Deprecated alias for --update-weight-disk-dir and will be removed in a future " - "release. Prefer the transport-level directory flag for both full and delta disk sync." + "Path to a custom function called on each trainer rank after its delta files " + "are written, before the engines read them — to publish the writes on a " + "non-POSIX filesystem (no cross-host visibility without an explicit sync). " + "Signature: ``def hook(args, version_dir: str, rollout_engines) -> None``; the hook gates itself." ), ) parser.add_argument( - "--update-weight-delta-keep-files", - action="store_true", - default=False, - help="Skip post-apply cleanup of per-sync version directories. Useful for debugging.", + "--custom-delta-pre-read-path", + type=str, + default=None, + help=( + "Path to a custom function called on each rollout host before it reads the " + "published delta directory — refreshes the mount so the just-published version " + "is visible on a non-POSIX filesystem (no cross-host read-after-write consistency). " + "Signature: ``def hook(delta_dir: str, target_version: int) -> None``." + ), ) parser.add_argument( - "--custom-delta-pre-push-path", + "--update-weight-local-checkpoint-dir", type=str, default=None, help=( - "Path to a custom function called by --update-weight-transport=disk after each " - "trainer rank's files are durably on local disk, before rank 0 fires the engine " - "RPCs. Signature: ``def hook(args, version_dir: str, rollout_engines) -> None``. " - "Called from every trainer rank; the hook gates itself." + "Rollout-host-local directory (NVMe) holding a full HF checkpoint that " + "disk-delta sync patches in place. Each host materializes it from " + "--hf-checkpoint at engine start, applies each version's delta there, and " + "the engines reload from it. Required for --update-weight-mode=delta " + "--update-weight-transport=disk." ), ) parser.add_argument( @@ -545,6 +559,30 @@ def add_rollout_arguments(parser): nargs="+", help="Address and ports of the external engines.", ) + parser.add_argument( + "--rollout-endpoint-url", + type=str, + default=None, + help=( + "Base URL of an opaque HTTP rollout endpoint (an elastic fleet behind one URL). " + "slime launches no engines and sends /generate here; weights are published to " + "--update-weight-disk-dir for the fleet to pull (requires delta + disk transport)." + ), + ) + parser.add_argument( + "--custom-rollout-request-hook-path", + type=str, + default=None, + help=( + "Path to a hook that can mutate each outgoing generate request before it is sent. " + "Signature: def hook(args, sample, request) -> None | dict (may be async), where " + "request holds 'url', 'payload', 'headers', 'max_retries' and 'retry_sleep'. Mutate " + "request in place and return None, or return a dict of updates to apply. Use it to " + "add custom headers, or for weight-version gating against an opaque rollout endpoint " + "by setting request['payload']['weight_version'] (the hook supplies the target " + "version) and raising max_retries/retry_sleep to wait for the fleet to load it." + ), + ) return parser def add_fault_tolerance_arguments(parser): @@ -1050,7 +1088,7 @@ def add_algo_arguments(parser): "--custom-pg-loss-reducer-function-path", type=str, default=None, - help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. (e.g., examples/Dr.GRPO/custom_reducer.py:get_pg_loss_reducer).", + help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean.", ) parser.add_argument( @@ -1692,57 +1730,6 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: return eval_datasets -def _resolve_update_weight_disk_dir(args) -> None: - """Normalize disk-sync directory args. - - ``--update-weight-delta-dir`` is kept only as a compatibility alias. New - code should use ``--update-weight-disk-dir`` because the directory belongs - to the transport, not to the delta encoding mode. - """ - disk_dir = args.update_weight_disk_dir - delta_dir = args.update_weight_delta_dir - if disk_dir and delta_dir and disk_dir != delta_dir: - raise ValueError( - "--update-weight-delta-dir is deprecated alias for --update-weight-disk-dir; " - "please set only one of them or set both to the same path." - ) - - if delta_dir: - warnings.warn( - "--update-weight-delta-dir is deprecated and will be removed in a future release; " - "use --update-weight-disk-dir instead.", - UserWarning, - stacklevel=2, - ) - - disk_dir = disk_dir or delta_dir - if args.update_weight_transport == "disk": - if not disk_dir: - raise ValueError( - "--update-weight-transport=disk requires --update-weight-disk-dir to point at " - "a filesystem shared between the trainer and the rollout engines." - ) - args.update_weight_disk_dir = disk_dir - args.update_weight_delta_dir = disk_dir - - -def _validate_update_weight_args(args) -> None: - _resolve_update_weight_disk_dir(args) - - if args.update_weight_mode == "delta": - if args.update_weight_transport not in ("nccl", "disk"): - raise ValueError( - "--update-weight-mode=delta supports only --update-weight-transport=nccl or disk, " - f"got {args.update_weight_transport!r}." - ) - if args.colocate: - raise ValueError( - "--update-weight-mode=delta is not supported with --colocate. Colocate transfers " - "weights via CUDA IPC (only a handle crosses processes), so the delta bookkeeping " - "(snapshot + diff + sparse encode) is pure overhead." - ) - - def slime_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) @@ -1881,6 +1868,29 @@ def slime_validate_args(args): if args.rollout_external and not args.debug_train_only: apply_external_engine_info_to_args(args, logger=logger) + if args.rollout_endpoint_url is not None: + args.rollout_endpoint_url = normalize_rollout_endpoint_url(args.rollout_endpoint_url) + if args.rollout_external: + raise ValueError("--rollout-endpoint-url and --rollout-external-engine-addrs are mutually exclusive.") + if not (args.update_weight_mode == "delta" and args.update_weight_transport == "disk"): + raise ValueError( + "--rollout-endpoint-url requires --update-weight-mode=delta --update-weight-transport=disk: " + "weights are published to disk for the external fleet to pull." + ) + # One logical endpoint; client-side concurrency is sglang_server_concurrency * this. Must be + # set before init_http_client (which otherwise derives 0 engines from the 0 rollout GPUs). + if getattr(args, "rollout_num_engines", None) is None: + args.rollout_num_engines = 1 + # Partial-rollout against an opaque endpoint can only capture partials client-side, which + # needs the streaming rollout (the non-streaming path yields nothing to a disconnected client). + if getattr(args, "partial_rollout", False) and "sglang_streaming_rollout" not in ( + getattr(args, "custom_generate_function_path", None) or "" + ): + raise ValueError( + "--rollout-endpoint-url with --partial-rollout requires the streaming rollout: set " + "--custom-generate-function-path slime.rollout.sglang_streaming_rollout.generate_streaming" + ) + args.use_critic = args.advantage_estimator == "ppo" # Critic always uses the same GPU count as actor. args.critic_num_gpus_per_node = args.actor_num_gpus_per_node @@ -2002,13 +2012,30 @@ def slime_validate_args(args): args.rollout_max_prompt_len <= args.rollout_max_context_len - 1 ), f"args.rollout_max_prompt_len ({args.rollout_max_prompt_len}) must be smaller than args.rollout_max_context_len ({args.rollout_max_context_len}) so that there is at least one generated token to compute loss." - if args.qkv_format == "bshd": - assert args.train_backend == "megatron", "bshd format is only supported for megatron backend." - assert ( - args.use_dynamic_batch_size is False - ), "Dynamic batch size is not supported for bshd format. Please specify --micro-batch-size instead." - if args.only_train_params_name_list and args.freeze_params_name_list: raise ValueError("You can only specify ONE of: --only-train-params-name-list, or --freeze-params-name-list.") - _validate_update_weight_args(args) + # disk-backed sync (full or delta) writes on the trainer and reads on the engines: needs a shared dir + if args.update_weight_transport == "disk" and not args.update_weight_disk_dir: + raise ValueError( + "--update-weight-transport=disk requires --update-weight-disk-dir to point at " + "a filesystem shared between the trainer and the rollout engines." + ) + if args.update_weight_mode == "delta": + if args.update_weight_transport != "disk": + raise ValueError( + "--update-weight-mode=delta requires --update-weight-transport=disk, " + f"got {args.update_weight_transport!r}." + ) + if args.colocate: + raise ValueError( + "--update-weight-mode=delta is not supported with --colocate. Colocate transfers " + "weights via CUDA IPC (only a handle crosses processes), so the delta bookkeeping " + "(snapshot + diff + encode) is pure overhead." + ) + if not args.update_weight_local_checkpoint_dir and args.rollout_endpoint_url is None: + # publish-only (--rollout-endpoint-url) applies on the fleet, not on a slime-host-local copy + raise ValueError( + "--update-weight-mode=delta requires --update-weight-local-checkpoint-dir " + "(a rollout-host-local NVMe directory)." + ) diff --git a/slime/utils/disk_delta.py b/slime/utils/disk_delta.py new file mode 100644 index 0000000000..3ba8cef12f --- /dev/null +++ b/slime/utils/disk_delta.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import fcntl +import glob +import io +import json +import logging +import mmap +import os +import shutil +import struct +import threading +import zlib +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager + +import numpy as np +import zstandard + +logger = logging.getLogger(__name__) + +# The delta phases (XOR/scatter, zstd, checksum) are memory-bandwidth bound and release the GIL, +# so a thread pool over tensors recovers the bandwidth one thread leaves idle. +NUM_WORKERS = min(32, (os.cpu_count() or 8)) + +SYNC_DIR = ".delta_sync" # per-checkpoint dir holding the applied-version marker and the apply lock + + +def overwrite_encode(new: np.ndarray, changed_mask: np.ndarray) -> np.ndarray: + """The 'overwrite' delta: changed-position count (u4), positions (u4 each), then new values. + Idempotent to apply, unlike xor (an involution); the trainer picks the encoding per the docs.""" + pos = np.flatnonzero(changed_mask).astype(" None: + self._value = zlib.adler32(data, self._value) + + def hexdigest(self) -> str: + return f"{self._value:08x}" + + +def _new_hasher(algorithm: str): + if algorithm == "xxh3-128": + import xxhash + + return xxhash.xxh3_128() + if algorithm == "blake3": + import blake3 + + return blake3.blake3() + if algorithm == "adler32": + return _Adler32() + raise KeyError(f"unknown checksum algorithm {algorithm!r}") + + +def checksum(algorithm: str, buf) -> str: + hasher = _new_hasher(algorithm) + hasher.update(buf) + return hasher.hexdigest() + + +@contextmanager +def _apply_lock(local_ckpt_dir: str): + sync = os.path.join(local_ckpt_dir, SYNC_DIR) + os.makedirs(sync, exist_ok=True) + with open(os.path.join(sync, "lock"), "w") as f: + fcntl.flock(f, fcntl.LOCK_EX) + try: + yield + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + +def _read_applied_version(local_ckpt_dir: str) -> str | None: + try: + with open(os.path.join(local_ckpt_dir, SYNC_DIR, "state.json")) as f: + return json.load(f)["version"] + except FileNotFoundError: + return None + + +def _write_applied_version(local_ckpt_dir: str, version: str) -> None: + path = os.path.join(local_ckpt_dir, SYNC_DIR, "state.json") + tmp = path + ".tmp" + with open(tmp, "w") as f: + json.dump({"version": version}, f) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) + + +def drop_page_cache(path: str) -> None: + """Evict a file from the page cache (POSIX_FADV_DONTNEED).""" + try: + fd = os.open(path, os.O_RDONLY) + try: + os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_DONTNEED) + finally: + os.close(fd) + except OSError: + pass + + +def init_local_checkpoint(local_ckpt_dir: str, base_dir: str) -> None: + """Copy the base HF checkpoint into local_ckpt_dir once if absent (run at engine start). Each + later delta is applied on top of this copy in place.""" + with _apply_lock(local_ckpt_dir): + if _read_applied_version(local_ckpt_dir) is not None: + return + logger.info("Materializing base checkpoint %s -> %s", base_dir, local_ckpt_dir) + os.makedirs(local_ckpt_dir, exist_ok=True) + for entry in os.scandir(base_dir): + if entry.is_file(): + shutil.copy2(entry.path, os.path.join(local_ckpt_dir, entry.name)) + drop_page_cache(entry.path) # don't let the source evict the local copy we keep resident + _write_applied_version(local_ckpt_dir, "000000") + + +def _tensor_locations(ckpt_dir: str) -> dict[str, tuple[str, int, int]]: + """Map each tensor name to (file, byte offset, nbytes) by reading every safetensors header.""" + locations: dict[str, tuple[str, int, int]] = {} + for path in glob.glob(os.path.join(ckpt_dir, "*.safetensors")): + with open(path, "rb") as f: + (header_len,) = struct.unpack(" uint8 bytes`` that seeks straight to the + tensor — for reading many tensors without rescanning every header. KeyError if absent.""" + locations = _tensor_locations(ckpt_dir) + + def read(name: str) -> np.ndarray: + path, offset, nbytes = locations[name] + with open(path, "rb") as f: + f.seek(offset) + return np.frombuffer(f.read(nbytes), dtype=np.uint8) + + return read + + +def _apply_version(local_ckpt_dir: str, version_dir: str) -> None: + """Apply one version's delta in place: decompress + apply + checksum each tensor across a thread + pool (each writes a distinct mmap region, so the writes don't conflict). Any mismatch raises.""" + with open(os.path.join(version_dir, "model.safetensors.index.json")) as f: + meta = json.load(f)["metadata"] + applied = _read_applied_version(local_ckpt_dir) + if applied == meta["version"]: + return + if applied != meta["base_version"]: + raise RuntimeError(f"out-of-order delta: local at {applied}, delta builds on {meta['base_version']}") + if meta["compression_format"] != "zstd": + raise NotImplementedError(f"compression {meta['compression_format']!r} not supported") + encoding = meta["delta_encoding"] + algorithm = meta["checksum_format"] + locations = _tensor_locations(local_ckpt_dir) + open_mmaps: dict[str, tuple] = {} + mismatches: list[str] = [] + lock = threading.Lock() + file_bytes: list[bytes] = [] # keep alive: items hold zero-copy views into these + items: list[tuple] = [] # (name, compressed_view, path, offset, nbytes, want_checksum) + try: + for delta_file in sorted(glob.glob(os.path.join(version_dir, "*.safetensors"))): + with open(delta_file, "rb") as f: + blob = f.read() + file_bytes.append(blob) + (header_len,) = struct.unpack(" None: + name, compressed, path, offset, nbytes, want = item + region = np.ndarray((nbytes,), dtype=np.uint8, buffer=open_mmaps[path][1], offset=offset) + hasher = _new_hasher(algorithm) + reader = zstandard.ZstdDecompressor().stream_reader(io.BytesIO(bytes(compressed))) + pos = 0 + while pos < nbytes: # 2 MB chunks stay L2-resident across decompress -> XOR -> checksum + block = reader.read(min(2 << 20, nbytes - pos)) + if not block: + break + chunk = np.frombuffer(block, dtype=np.uint8) + region[pos : pos + chunk.size] ^= chunk + hasher.update(region[pos : pos + chunk.size]) + pos += chunk.size + if hasher.hexdigest() != want: + with lock: + mismatches.append(name) + + def apply_overwrite(item) -> None: + name, compressed, path, offset, nbytes, want = item + delta = np.frombuffer(zstandard.ZstdDecompressor().decompress(bytes(compressed)), dtype=np.uint8) + region = np.ndarray((nbytes,), dtype=np.uint8, buffer=open_mmaps[path][1], offset=offset) + count = int.from_bytes(delta[:4].tobytes(), "little") + positions = np.frombuffer(delta[4 : 4 + 4 * count].tobytes(), dtype=" None: + """Apply the delta chain in order to bring the local checkpoint up to target_version, in place. + A per-tensor checksum guards every write and any mismatch raises (fail loud, never serve bad + weights). Serialized per host by the lock (co-located actors collapse to one apply).""" + with _apply_lock(local_ckpt_dir): + applied = _read_applied_version(local_ckpt_dir) + if applied is None: + raise RuntimeError("local checkpoint not materialized") + for version in range(int(applied) + 1, target_version + 1): + _apply_version(local_ckpt_dir, os.path.join(delta_root, f"weight_v{version:06d}")) diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index b8c3a30fb3..56167c4b25 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -162,7 +162,7 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60, headers=None): +async def _post(client, url, payload, max_retries=60, headers=None, retry_sleep=1.0): retry_count = 0 while retry_count < max_retries: response = None @@ -188,7 +188,7 @@ async def _post(client, url, payload, max_retries=60, headers=None): if retry_count >= max_retries: logger.info(f"Max retries ({max_retries}) reached, failing... (url={url})") raise e - await asyncio.sleep(1) + await asyncio.sleep(retry_sleep) continue finally: if response is not None: @@ -262,8 +262,8 @@ def __init__(self, concurrency: int): trust_env=False, # internal SGLang comm only — never route through system proxy ) - async def do_post(self, url, payload, max_retries=60, headers=None): - return await _post(self._client, url, payload, max_retries, headers=headers) + async def do_post(self, url, payload, max_retries=60, headers=None, retry_sleep=1.0): + return await _post(self._client, url, payload, max_retries, headers=headers, retry_sleep=retry_sleep) # Create actors per node created = [] @@ -288,7 +288,7 @@ async def do_post(self, url, payload, max_retries=60, headers=None): _post_actors = created -async def post(url, payload, max_retries=60, headers=None): +async def post(url, payload, max_retries=60, headers=None, retry_sleep=1.0): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: @@ -300,13 +300,13 @@ async def post(url, payload, max_retries=60, headers=None): # `min(32, cpu+4)`), which becomes a hard upper bound on the # number of in-flight POSTs that can be waited on in parallel # and produces large tail latencies under high concurrency. - obj_ref = actor.do_post.remote(url, payload, max_retries, headers=headers) + obj_ref = actor.do_post.remote(url, payload, max_retries, headers=headers, retry_sleep=retry_sleep) return await obj_ref except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries, headers=headers) + return await _post(_http_client, url, payload, max_retries, headers=headers, retry_sleep=retry_sleep) async def get(url): diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 327dec2de6..a4dd0c0181 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -171,10 +171,30 @@ def compute_cispo_loss( return pg_losses, clipfrac -def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None): +def compute_log_probs( + logits: torch.Tensor, + tokens: torch.Tensor, + process_group: dist.ProcessGroup | None, + keep_mask: torch.Tensor | None = None, +): # TODO: when megatron is not installed, fall back to naive implementation from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy + if keep_mask is not None: + from megatron.core import mpu + + # Force-keep the sampled token on its TP shard so replay remains finite + # even if an engine-side path records a nucleus that misses the target. + keep_mask = keep_mask.clone() + vocab_local = keep_mask.size(-1) + vocab_start = mpu.get_tensor_model_parallel_rank() * vocab_local + local_tokens = tokens - vocab_start + on_shard = (local_tokens >= 0) & (local_tokens < vocab_local) + rows = torch.nonzero(on_shard, as_tuple=False).squeeze(-1) + if rows.numel() > 0: + keep_mask[rows, local_tokens[rows]] = True + logits = logits.masked_fill(~keep_mask, float("-inf")) + # convert to [seq_len, batch_size, vocab_size] as expected by fused_vocab_parallel_cross_entropy logits = logits.unsqueeze(1) tokens = tokens.unsqueeze(1) @@ -669,7 +689,9 @@ def chunked_gae( return advantages, returns -def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1): +def calculate_log_probs_and_entropy( + logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1, log_prob_keep_mask=None +): logits = logits.contiguous() entropy = None if logits.size(0) != 0: @@ -677,6 +699,9 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool num_chunks = (logits.size(0) - 1) // chunk_size + 1 logits_chunks = logits.chunk(num_chunks, dim=0) tokens_chunks = tokens.chunk(num_chunks, dim=0) + mask_chunks = ( + log_prob_keep_mask.chunk(num_chunks, dim=0) if log_prob_keep_mask is not None else [None] * num_chunks + ) if with_entropy: entropys = [] @@ -686,8 +711,8 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool entropy = torch.cat(entropys, dim=0) log_probs = [] - for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): - log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group) + for tokens_chunk, logits_chunk, mask_chunk in zip(tokens_chunks, logits_chunks, mask_chunks, strict=True): + log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group, keep_mask=mask_chunk) log_probs.append(log_prob) log_prob = torch.cat(log_probs, dim=0) else: @@ -695,7 +720,7 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool entropy_input = logits.clone() entropy = compute_entropy_from_logits(entropy_input, tp_group) - log_prob = compute_log_probs(logits.clone(), tokens, tp_group) + log_prob = compute_log_probs(logits.clone(), tokens, tp_group, keep_mask=log_prob_keep_mask) else: log_prob = logits.new_zeros((0,)) if with_entropy: diff --git a/slime/utils/types.py b/slime/utils/types.py index 2052ad5269..00bac587e8 100644 --- a/slime/utils/types.py +++ b/slime/utils/types.py @@ -33,6 +33,10 @@ class Sample: loss_mask: list[int] | None = None weight_versions: list[str] = field(default_factory=list) rollout_log_probs: list[float] | None = None # Log probabilities from rollout engine + # Ragged top-p nucleus token ids replayed from rollout sampling. For response + # token i, kept ids are rollout_top_p_token_ids[offsets[i]:offsets[i + 1]]. + rollout_top_p_token_ids: list[int] | None = None + rollout_top_p_token_offsets: list[int] | None = None rollout_routed_experts: list[list[int]] | None = None # Routed experts from rollout engine remove_sample: bool = False teacher_log_probs: list[float] | None = None # Log probabilities from teacher model for OPD diff --git a/slime_plugins/megatron_bridge/glm4v_moe.py b/slime_plugins/megatron_bridge/glm4v_moe.py index a639b905f7..2f78285782 100644 --- a/slime_plugins/megatron_bridge/glm4v_moe.py +++ b/slime_plugins/megatron_bridge/glm4v_moe.py @@ -33,10 +33,10 @@ # --------------------------------------------------------------------------- -# THD ↔ BSHD helpers (cf. Qwen3VL bridge) +# THD ↔ batch-sequence helpers (cf. Qwen3VL bridge) # --------------------------------------------------------------------------- -def _thd_to_bshd(packed: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - """Unpack THD-format [1, T, ...] to BSHD [bs, max_seq, ...] using cu_seqlens.""" +def _thd_to_batch_seq(packed: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + """Unpack THD-format [1, T, ...] to [bs, max_seq, ...] using cu_seqlens.""" seqlens = cu_seqlens[1:] - cu_seqlens[:-1] max_seq = seqlens.max().item() bs = len(cu_seqlens) - 1 @@ -46,8 +46,8 @@ def _thd_to_bshd(packed: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor return out -def _bshd_to_thd(unpacked: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - """Pack BSHD [bs, max_seq, ...] back to THD [1, T, ...].""" +def _batch_seq_to_thd(unpacked: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + """Pack [bs, max_seq, ...] back to THD [1, T, ...].""" seqlens = cu_seqlens[1:] - cu_seqlens[:-1] total = cu_seqlens[-1].item() out = unpacked.new_zeros(1, total, *unpacked.shape[2:]) @@ -275,7 +275,7 @@ def _get_vision_position_ids( def _compute_mrope_position_ids( self, - input_ids_bshd: torch.Tensor, + input_ids_batch_seq: torch.Tensor, image_grid_thw: torch.Tensor | None, ) -> torch.Tensor: """Compute 3D M-RoPE position IDs from input_ids in [bs, seq] format. @@ -283,8 +283,8 @@ def _compute_mrope_position_ids( Image regions are detected by looking for consecutive runs of ``image_token_id`` in each sequence — no ``mm_token_type_ids`` needed. """ - bs, seq_len = input_ids_bshd.shape - device = input_ids_bshd.device + bs, seq_len = input_ids_batch_seq.shape + device = input_ids_batch_seq.device spatial_merge_size = self.spatial_merge_size position_ids = torch.zeros(3, bs, seq_len, dtype=torch.long, device=device) @@ -300,7 +300,7 @@ def _compute_mrope_position_ids( grid_iter = iter(image_grid_thw) for b in range(bs): - ids = input_ids_bshd[b] + ids = input_ids_batch_seq[b] is_image = ids == self.image_token_id # Find contiguous groups: text (0) vs image (1) @@ -430,9 +430,9 @@ def forward( full_input_ids = _gather_input_ids_from_cp(input_ids, cu_seqlens) else: full_input_ids = input_ids - input_ids_bshd = _thd_to_bshd(full_input_ids, cu_seqlens) - pos_bshd = self._compute_mrope_position_ids(input_ids_bshd, image_grid_thw) - pos_packed = _bshd_to_thd(pos_bshd.permute(1, 2, 0), cu_seqlens) + input_ids_batch_seq = _thd_to_batch_seq(full_input_ids, cu_seqlens) + pos_batch_seq = self._compute_mrope_position_ids(input_ids_batch_seq, image_grid_thw) + pos_packed = _batch_seq_to_thd(pos_batch_seq.permute(1, 2, 0), cu_seqlens) position_ids = pos_packed.permute(2, 0, 1).contiguous() # [3, 1, T_global] else: position_ids = self._compute_mrope_position_ids(input_ids, image_grid_thw) diff --git a/tests/test_glm4.7_30B_A3B_pd_mooncake.py b/tests/test_glm4.7_30B_A3B_pd_mooncake.py index 66d0558a5b..75614e48ce 100644 --- a/tests/test_glm4.7_30B_A3B_pd_mooncake.py +++ b/tests/test_glm4.7_30B_A3B_pd_mooncake.py @@ -70,7 +70,8 @@ def execute(): "--rollout-batch-size 4 " "--n-samples-per-prompt 2 " "--rollout-max-response-len 512 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1.0 " + "--rollout-top-p 0.95 " "--global-batch-size 8 " ) optimizer_args = ( diff --git a/tests/test_logprob_response_spans.py b/tests/test_logprob_response_spans.py new file mode 100644 index 0000000000..51f22016f0 --- /dev/null +++ b/tests/test_logprob_response_spans.py @@ -0,0 +1,91 @@ +import _cp_dist_helpers # noqa: F401 +import pytest +import torch + +from megatron.core import mpu +from slime.backends.megatron_utils.loss import _build_topp_keep_mask + + +NUM_GPUS = 0 + + +def _set_cp(monkeypatch, *, size: int, rank: int) -> None: + monkeypatch.setattr(mpu, "get_context_parallel_world_size", lambda: size) + monkeypatch.setattr(mpu, "get_context_parallel_rank", lambda: rank) + monkeypatch.setattr(mpu, "get_tensor_model_parallel_rank", lambda: 0, raising=False) + + +def _kept_ids(row: torch.Tensor) -> list[int]: + return row.nonzero(as_tuple=False).squeeze(-1).tolist() + + +@pytest.mark.unit +@pytest.mark.parametrize( + ("rank", "expected"), + [ + (0, {2: [107]}), + (1, {1: [104], 2: [105], 3: [106]}), + ], +) +def test_top_p_mask_aligns_with_zigzag_cp_response_rows(monkeypatch, rank, expected): + _set_cp(monkeypatch, size=2, rank=rank) + keep = _build_topp_keep_mask( + 4, + 200, + torch.device("cpu"), + top_p_token_ids=[[104, 105, 106, 107]], + top_p_token_offsets=[[0, 1, 2, 3, 4]], + total_lengths=[8], + response_lengths=[4], + allgather_cp=False, + ) + + masked_rows = {row: _kept_ids(keep[row]) for row in range(keep.size(0)) if not keep[row].all()} + assert masked_rows == expected + + +@pytest.mark.unit +@pytest.mark.parametrize( + ("rank", "expected"), + [ + (0, {1: [102], 2: [103]}), + (1, {0: [104], 1: [105]}), + ], +) +def test_top_p_mask_aligns_with_allgather_cp_response_rows(monkeypatch, rank, expected): + _set_cp(monkeypatch, size=2, rank=rank) + keep = _build_topp_keep_mask( + 3, + 200, + torch.device("cpu"), + top_p_token_ids=[[102, 103, 104, 105]], + top_p_token_offsets=[[0, 1, 2, 3, 4]], + total_lengths=[6], + response_lengths=[4], + allgather_cp=True, + ) + + masked_rows = {row: _kept_ids(keep[row]) for row in range(keep.size(0)) if not keep[row].all()} + assert masked_rows == expected + + +@pytest.mark.unit +def test_top_p_mask_aligns_with_cp1_response_rows(monkeypatch): + _set_cp(monkeypatch, size=1, rank=0) + keep = _build_topp_keep_mask( + 9, + 30, + torch.device("cpu"), + top_p_token_ids=[[13, 99, 14], [21, 22, 99, 23]], + top_p_token_offsets=[[0, 2, 3], [0, 1, 3, 4]], + total_lengths=[5, 4], + response_lengths=[2, 3], + allgather_cp=False, + ) + + masked_rows = {row: _kept_ids(keep[row]) for row in range(keep.size(0)) if not keep[row].all()} + assert masked_rows == {2: [13], 3: [14], 5: [21], 6: [22], 7: [23]} + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index 1f435cb577..5641dd1450 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -49,6 +49,7 @@ def load_slime_arguments_module(monkeypatch): sglang_arguments_mod.sglang_parse_args = lambda *args, **kwargs: None sglang_arguments_mod.validate_args = lambda args: args sglang_external_mod.apply_external_engine_info_to_args = lambda *args, **kwargs: None + sglang_external_mod.normalize_rollout_endpoint_url = lambda url: url.rstrip("/") logging_utils_mod.configure_logger = lambda *args, **kwargs: None monkeypatch.setitem(sys.modules, "sglang_router", router_pkg_mod) @@ -170,58 +171,10 @@ def test_allgather_cp_ignores_cp_size_one(monkeypatch): @pytest.mark.unit def test_update_weight_disk_dir_required_for_disk_transport(monkeypatch): module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( - update_weight_transport="disk", - update_weight_disk_dir=None, - update_weight_delta_dir=None, - ) + args = make_slime_validate_args(update_weight_transport="disk", update_weight_disk_dir=None) with pytest.raises(ValueError, match="update-weight-disk-dir"): - module._resolve_update_weight_disk_dir(args) - - -@pytest.mark.unit -def test_update_weight_disk_dir_normalizes_delta_alias(monkeypatch): - module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( - update_weight_transport="disk", - update_weight_disk_dir=None, - update_weight_delta_dir="/shared/delta", - ) - - with pytest.warns(UserWarning, match="will be removed in a future release"): - module._resolve_update_weight_disk_dir(args) - - assert args.update_weight_disk_dir == "/shared/delta" - assert args.update_weight_delta_dir == "/shared/delta" - - -@pytest.mark.unit -def test_update_weight_disk_dir_backfills_legacy_delta_field(monkeypatch): - module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( - update_weight_transport="disk", - update_weight_disk_dir="/shared/updates", - update_weight_delta_dir=None, - ) - - module._resolve_update_weight_disk_dir(args) - - assert args.update_weight_disk_dir == "/shared/updates" - assert args.update_weight_delta_dir == "/shared/updates" - - -@pytest.mark.unit -def test_update_weight_disk_dir_rejects_conflicting_alias(monkeypatch): - module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( - update_weight_transport="disk", - update_weight_disk_dir="/shared/full", - update_weight_delta_dir="/shared/delta", - ) - - with pytest.raises(ValueError, match="deprecated alias"): - module._resolve_update_weight_disk_dir(args) + module.slime_validate_args(args) def make_slime_validate_args(**overrides): @@ -267,6 +220,7 @@ def make_slime_validate_args(**overrides): save_debug_train_data=None, load_debug_rollout_data=None, rollout_external_engine_addrs=None, + rollout_endpoint_url=None, debug_train_only=False, actor_num_gpus_per_node=8, actor_num_nodes=1, @@ -296,13 +250,12 @@ def make_slime_validate_args(**overrides): eval_max_context_len=None, rollout_max_context_len=None, rollout_max_prompt_len=None, - qkv_format="thd", train_backend="megatron", only_train_params_name_list=None, freeze_params_name_list=None, update_weight_transport="nccl", update_weight_disk_dir=None, - update_weight_delta_dir=None, + update_weight_local_checkpoint_dir=None, update_weight_mode="full", ) values.update(overrides) @@ -353,33 +306,45 @@ def test_slime_validate_args_preserves_zero_rollout_gpus_without_colocate(monkey @pytest.mark.unit -def test_update_weight_delta_rejects_colocate(monkeypatch): +def test_update_weight_delta_requires_disk_transport(monkeypatch): module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( + args = make_slime_validate_args( update_weight_mode="delta", update_weight_transport="nccl", - update_weight_disk_dir=None, - update_weight_delta_dir=None, + update_weight_local_checkpoint_dir="/local/ckpt", + ) + + with pytest.raises(ValueError, match="requires --update-weight-transport=disk"): + module.slime_validate_args(args) + + +@pytest.mark.unit +def test_update_weight_delta_rejects_colocate(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = make_slime_validate_args( + update_weight_mode="delta", + update_weight_transport="disk", + update_weight_disk_dir="/shared/delta", + update_weight_local_checkpoint_dir="/local/ckpt", colocate=True, ) with pytest.raises(ValueError, match="not supported with --colocate"): - module._validate_update_weight_args(args) + module.slime_validate_args(args) @pytest.mark.unit -def test_update_weight_delta_rejects_unknown_transport(monkeypatch): +def test_update_weight_delta_requires_local_checkpoint_dir(monkeypatch): module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( + args = make_slime_validate_args( update_weight_mode="delta", - update_weight_transport="tensor", - update_weight_disk_dir=None, - update_weight_delta_dir=None, - colocate=False, + update_weight_transport="disk", + update_weight_disk_dir="/shared/delta", + update_weight_local_checkpoint_dir=None, ) - with pytest.raises(ValueError, match="supports only --update-weight-transport=nccl or disk"): - module._validate_update_weight_args(args) + with pytest.raises(ValueError, match="requires --update-weight-local-checkpoint-dir"): + module.slime_validate_args(args) if __name__ == "__main__": diff --git a/tests/test_megatron_to_hf_router_dtype.py b/tests/test_megatron_to_hf_router_dtype.py new file mode 100644 index 0000000000..e7959393b3 --- /dev/null +++ b/tests/test_megatron_to_hf_router_dtype.py @@ -0,0 +1,117 @@ +"""Router params must be exported in the model dtype, not their fp32 training dtype. + +With ``--moe-router-dtype fp32`` Megatron keeps the MoE router weight / expert_bias in fp32 even +when the model is bf16/fp16. ``update_weight_from_disk_delta`` XORs every freshly converted HF +tensor against the raw bytes of the base HF checkpoint (which stores the router in the model +dtype), so a leftover fp32 router is a 4-vs-2 byte-width mismatch that breaks the delta. Every +converter that emits a router must cast it back to the model dtype. This is pure tensor-metadata +work (no model load, no I/O), so the whole module runs in milliseconds. +""" + +import importlib.util +import sys +import types +from pathlib import Path + +import pytest +import torch + +NUM_GPUS = 0 + +CONVERTER_DIR = ( + Path(__file__).resolve().parents[1] / "slime" / "backends" / "megatron_utils" / "megatron_to_hf" +) + +# Synthetic parent package so the converters' ``from .dtype_utils import ...`` relative import +# resolves to the real dtype_utils.py without running the package __init__ (which imports +# megatron/mbridge and is unavailable in the unit-test env). +_PKG = "_mhf_router_dtype_test_pkg" + +# (converter module, converter function, router param names that must be cast). +# gpt-oss has no expert_bias; its second router buffer is mlp.router.bias. +CONVERTERS = [ + ("deepseekv3", "convert_deepseekv3_to_hf", ["mlp.router.weight", "mlp.router.expert_bias"]), + ("glm4moe", "convert_glm4moe_to_hf", ["mlp.router.weight", "mlp.router.expert_bias"]), + ("minimax_m2", "convert_minimax_m2_to_hf", ["mlp.router.weight", "mlp.router.expert_bias"]), + ("qwen3moe", "convert_qwen3moe_to_hf", ["mlp.router.weight", "mlp.router.expert_bias"]), + ("qwen3_5", "convert_qwen3_5_to_hf", ["mlp.router.weight", "mlp.router.expert_bias"]), + ("qwen3_next", "convert_qwen3_next_to_hf", ["mlp.router.weight", "mlp.router.expert_bias"]), + ("gpt_oss", "convert_gpt_oss_to_hf", ["mlp.router.weight", "mlp.router.bias"]), +] + + +def _load_converter(module_name, func_name): + if _PKG not in sys.modules: + pkg = types.ModuleType(_PKG) + pkg.__path__ = [str(CONVERTER_DIR)] + sys.modules[_PKG] = pkg + full_name = f"{_PKG}.{module_name}" + sys.modules.pop(full_name, None) + spec = importlib.util.spec_from_file_location(full_name, CONVERTER_DIR / f"{module_name}.py") + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[full_name] = module # let the relative import find the in-progress module + spec.loader.exec_module(module) + return getattr(module, func_name) + + +def _args(**overrides): + # Fields every converter dereferences before reaching the router branch (head_dim / + # value_num_per_group). q_lora_rank is read in the package wrapper, not the converter itself. + base = dict( + kv_channels=None, + hidden_size=16, + num_attention_heads=4, + num_query_groups=4, + num_layers=2, + bf16=False, + fp16=False, + ) + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _cases(): + for module_name, func_name, router_params in CONVERTERS: + for router_param in router_params: + for flag, dtype in (("bf16", torch.bfloat16), ("fp16", torch.float16)): + yield pytest.param( + module_name, func_name, router_param, flag, dtype, + id=f"{module_name}-{router_param.split('.')[-1]}-{flag}", + ) + + +@pytest.mark.unit +@pytest.mark.parametrize("module_name,func_name,router_param,flag,model_dtype", list(_cases())) +def test_router_param_cast_to_model_dtype(module_name, func_name, router_param, flag, model_dtype): + convert = _load_converter(module_name, func_name) + args = _args(**{flag: True}) + + name = f"module.module.decoder.layers.0.{router_param}" + fp32_param = torch.randn(8, 16, dtype=torch.float32) + + converted = convert(args, name, fp32_param) + + assert len(converted) == 1, f"{func_name} should map {router_param} to a single HF tensor" + _, out = converted[0] + # A fp32 router would be 2x the bytes of a bf16/fp16 base, breaking the disk-delta XOR. + assert out.dtype == model_dtype + + +@pytest.mark.unit +@pytest.mark.parametrize("module_name,func_name,router_param,_flag,_dtype", list(_cases())) +def test_router_param_preserved_in_full_precision(module_name, func_name, router_param, _flag, _dtype): + # No bf16/fp16 flag => no model dtype to match; leave the param untouched. + convert = _load_converter(module_name, func_name) + args = _args() + + name = f"module.module.decoder.layers.0.{router_param}" + fp32_param = torch.randn(8, 16, dtype=torch.float32) + + _, out = convert(args, name, fp32_param)[0] + + assert out.dtype == torch.float32 + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_qwen3.5_0.8B_gsm8k_async_short.py b/tests/test_qwen3.5_0.8B_gsm8k_async_short.py index 7b7609ad56..3f822945f7 100644 --- a/tests/test_qwen3.5_0.8B_gsm8k_async_short.py +++ b/tests/test_qwen3.5_0.8B_gsm8k_async_short.py @@ -36,6 +36,7 @@ def execute(): "--n-samples-per-prompt 4 " "--rollout-max-response-len 1024 " "--rollout-temperature 0.8 " + "--rollout-top-p 0.95 " "--over-sampling-batch-size 8 " "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 16 " diff --git a/tests/test_qwen3.5_0.8B_gsm8k_short.py b/tests/test_qwen3.5_0.8B_gsm8k_short.py index 8bf2c46d0c..93306eadd3 100644 --- a/tests/test_qwen3.5_0.8B_gsm8k_short.py +++ b/tests/test_qwen3.5_0.8B_gsm8k_short.py @@ -36,6 +36,7 @@ def execute(): "--n-samples-per-prompt 4 " "--rollout-max-response-len 1024 " "--rollout-temperature 0.8 " + "--rollout-top-p 0.95 " "--rollout-data-transport nixl " "--over-sampling-batch-size 8 " "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " diff --git a/tests/test_rollout_request_hook.py b/tests/test_rollout_request_hook.py new file mode 100644 index 0000000000..5572815ba6 --- /dev/null +++ b/tests/test_rollout_request_hook.py @@ -0,0 +1,268 @@ +import asyncio +import json +import sys +from argparse import Namespace +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +try: + import ray # noqa: F401 +except ImportError: + pass + +try: + from tests.plugin_contracts._shared import install_stubs +except ImportError: + from plugin_contracts._shared import install_stubs + +install_stubs(with_sglang_router=True, with_transformers=True) + +NUM_GPUS = 0 + +from slime.rollout import sglang_rollout # noqa: E402 +from slime.rollout.sglang_rollout import generate # noqa: E402 +from slime.utils import http_utils # noqa: E402 +from slime.utils.types import Sample # noqa: E402 + + +def _args(**overrides): + values = { + "ci_test": False, + "rollout_endpoint_url": None, + "sglang_router_ip": "10.0.0.1", + "sglang_router_port": 30000, + "sglang_model_routers": None, + "router_policy": None, + "use_rollout_routing_replay": False, + "partial_rollout": False, + "mask_offpolicy_in_partial_rollout": False, + "sglang_speculative_algorithm": None, + "custom_rollout_request_hook_path": None, + } + values.update(overrides) + return Namespace(**values) + + +class _Tokenizer: + def encode(self, prompt, add_special_tokens=False): + assert add_special_tokens is False + return [101, len(prompt)] + + +class _GenerateState: + def __init__(self, args): + self.args = args + self.tokenizer = _Tokenizer() + self.processor = None + self.pendings = set() + self.aborted = False + + +def _fake_generate_response(): + return { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 1, + }, + } + + +def _run_generate(args, monkeypatch, sample_index=5): + captured = {} + + async def fake_post(url, payload, headers=None, max_retries=60, retry_sleep=1.0): + captured.update(url=url, payload=payload, headers=headers, max_retries=max_retries, retry_sleep=retry_sleep) + return _fake_generate_response() + + monkeypatch.setattr(sglang_rollout, "GenerateState", _GenerateState) + monkeypatch.setattr(sglang_rollout, "post", fake_post) + sample = asyncio.run(generate(args, Sample(index=sample_index, prompt="hi"), {"max_new_tokens": 8})) + return sample, captured + + +def test_generate_without_hook_sends_plain_request(monkeypatch): + """With no hook configured, generate calls post directly — no request dict, no extra fields.""" + args = _args(rollout_endpoint_url="https://rollout.example") + sample, captured = _run_generate(args, monkeypatch) + + assert captured["url"] == "https://rollout.example/generate" + assert "weight_version" not in captured["payload"] + # default path uses post's defaults — the hook never ran to change them + assert captured["max_retries"] == 60 + assert captured["retry_sleep"] == 1.0 + assert sample.status == Sample.Status.COMPLETED + + +def test_request_hook_can_mutate_request_in_place(monkeypatch): + def hook(args, sample, request): + # request carries how-to-send fields (incl. retry knobs); rollout context comes off the sample. + assert set(request) == {"url", "payload", "headers", "max_retries", "retry_sleep"} + request["headers"] = {**(request["headers"] or {}), "Authorization": "Bearer t"} + request["payload"]["weight_version"] = {"min_version": sample.index} + request["max_retries"], request["retry_sleep"] = 120, 0.5 + + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: hook) + args = _args(rollout_endpoint_url="https://rollout.example", custom_rollout_request_hook_path="example.hook") + sample, captured = _run_generate(args, monkeypatch, sample_index=5) + + assert captured["headers"]["Authorization"] == "Bearer t" + assert captured["payload"]["weight_version"] == {"min_version": 5} + assert captured["max_retries"] == 120 + assert captured["retry_sleep"] == 0.5 + assert sample.status == Sample.Status.COMPLETED + + +def test_request_hook_can_return_dict_of_updates_and_be_async(monkeypatch): + async def hook(args, sample, request): + payload = dict(request["payload"]) + payload["weight_version"] = {"min_version": sample.index} + return {"payload": payload} + + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: hook) + args = _args(rollout_endpoint_url="https://rollout.example", custom_rollout_request_hook_path="example.hook") + sample, captured = _run_generate(args, monkeypatch, sample_index=3) + + assert captured["payload"]["weight_version"] == {"min_version": 3} + assert sample.status == Sample.Status.COMPLETED + + +def test_request_hook_must_return_none_or_dict(monkeypatch): + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: lambda a, s, r: "nope") + args = _args(rollout_endpoint_url="https://rollout.example", custom_rollout_request_hook_path="example.hook") + with pytest.raises(TypeError, match="None or a dict"): + _run_generate(args, monkeypatch) + + +class _FakeStreamResponse: + def __init__(self, lines): + self._lines = lines + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + return False + + def raise_for_status(self): + pass + + async def aiter_lines(self): + for line in self._lines: + yield line + + +class _FakeStreamClient: + def __init__(self, captured, lines): + self._captured = captured + self._lines = lines + + def stream(self, method, url, json=None, headers=None): + self._captured.update(method=method, url=url, json=json, headers=headers) + return _FakeStreamResponse(self._lines) + + +def test_streaming_generate_routes_to_endpoint_and_applies_hook(monkeypatch): + """The streaming path funnels through the same hook and uses get_model_url, so it reaches the + opaque endpoint (not the unset router) and a hook can shape the request.""" + from slime.rollout import sglang_streaming_rollout + + chunk = { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 0, + }, + } + captured = {} + client = _FakeStreamClient(captured, [f"data: {json.dumps(chunk)}", "data: [DONE]"]) + + def hook(args, sample, request): + request["headers"] = {**(request["headers"] or {}), "Authorization": "Bearer t"} + request["payload"]["weight_version"] = {"min_version": sample.index} + + # The hook runs inside apply_rollout_request_hook, which resolves load_function from sglang_rollout. + monkeypatch.setattr(sglang_streaming_rollout, "GenerateState", _GenerateState) + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: hook) + monkeypatch.setattr(http_utils, "_http_client", client) + + args = _args(rollout_endpoint_url="https://rollout.example", custom_rollout_request_hook_path="example.hook") + sample = asyncio.run( + sglang_streaming_rollout.generate_streaming(args, Sample(index=7, prompt="hi"), {"max_new_tokens": 8}) + ) + + assert captured["url"] == "https://rollout.example/generate" + assert captured["headers"]["Authorization"] == "Bearer t" + assert captured["json"]["weight_version"] == {"min_version": 7} + assert sample.status == Sample.Status.COMPLETED + + +def test_streaming_generate_without_hook_opens_stream_unchanged(monkeypatch): + """No hook configured: the stream is opened with the original payload/headers, no hook detour.""" + from slime.rollout import sglang_streaming_rollout + + chunk = { + "text": " answer", + "meta_info": {"output_token_logprobs": [[-0.25, 42]], "finish_reason": {"type": "stop"}}, + } + captured = {} + client = _FakeStreamClient(captured, [f"data: {json.dumps(chunk)}", "data: [DONE]"]) + + monkeypatch.setattr(sglang_streaming_rollout, "GenerateState", _GenerateState) + monkeypatch.setattr(http_utils, "_http_client", client) + + args = _args(rollout_endpoint_url="https://rollout.example") + asyncio.run(sglang_streaming_rollout.generate_streaming(args, Sample(index=0, prompt="hi"), {"max_new_tokens": 8})) + + assert captured["url"] == "https://rollout.example/generate" + assert "weight_version" not in captured["json"] + + +def test_post_retries_until_version_available_with_backoff(monkeypatch): + """A gating hook relies on this: the fleet rejects a not-yet-loaded version, and post retries + with retry_sleep backoff until it is served.""" + import httpx + + attempts = [] + sleeps = [] + + class _Resp: + def __init__(self, status): + self.status_code = status + self.text = "weight version not loaded" + + def raise_for_status(self): + if self.status_code >= 400: + raise httpx.HTTPStatusError("not ready", request=None, response=self) + + async def aread(self): + return b'{"ok": true}' + + async def aclose(self): + pass + + class _Client: + async def post(self, url, json=None, headers=None): + attempts.append(json) + return _Resp(409 if len(attempts) <= 2 else 200) + + async def fake_sleep(seconds): + sleeps.append(seconds) + + monkeypatch.setattr(http_utils.asyncio, "sleep", fake_sleep) + + payload = {"weight_version": {"min_version": 11}} + out = asyncio.run(http_utils._post(_Client(), "http://fleet/generate", payload, max_retries=5, retry_sleep=0.01)) + + assert out == {"ok": True} + assert len(attempts) == 3 + assert sleeps == [0.01, 0.01] # backed off once per rejection, honoring retry_sleep diff --git a/tests/test_sample.py b/tests/test_sample.py index f42ff4e9cc..3cce9aeb41 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -50,6 +50,8 @@ def _make_sample(**overrides) -> Sample: loss_mask=[1, 1, 0, 1, 1], weight_versions=["v1"], rollout_log_probs=[-0.1, -0.2], + rollout_top_p_token_ids=[10, 11, 12, 20], + rollout_top_p_token_offsets=[0, 3, 4], rollout_routed_experts=[[0, 1], [2, 3]], remove_sample=False, teacher_log_probs=[-0.3, -0.4], @@ -131,6 +133,8 @@ def test_round_trip_preserves_every_field(): "loss_mask", "weight_versions", "rollout_log_probs", + "rollout_top_p_token_ids", + "rollout_top_p_token_offsets", "rollout_routed_experts", "remove_sample", "teacher_log_probs", diff --git a/tests/test_value_temperature.py b/tests/test_value_temperature.py index 0de0b793c6..195d89e88f 100644 --- a/tests/test_value_temperature.py +++ b/tests/test_value_temperature.py @@ -26,7 +26,7 @@ def test_get_values_does_not_apply_rollout_temperature(monkeypatch): try: from slime.backends.megatron_utils.loss import get_values - args = Namespace(qkv_format="thd", rollout_temperature=0.5, allgather_cp=False) + args = Namespace(rollout_temperature=0.5, allgather_cp=False) logits = torch.tensor([[[1.0], [2.0], [3.0], [4.0]]], dtype=torch.float32) tokens = [torch.tensor([10, 11, 12, 13], dtype=torch.long)]