From d05429f0fbe3f6f72644ed37168c593675d7c061 Mon Sep 17 00:00:00 2001 From: Nan Date: Tue, 16 Jun 2026 04:01:23 +0000 Subject: [PATCH 1/8] feat: disk-level delta weight sync Ship only the changed bytes between weight syncs as a canonical HF delta checkpoint; rollout hosts apply it into a host-local checkpoint and reload via the vanilla update_weights_from_disk path. Replaces the NCCL delta transport from #1806 with a disk-only path that needs no engine-side delta support. --- docs/en/advanced/delta-weight-sync.md | 175 ++-- docs/en/advanced/external-rollout-engines.md | 18 +- docs/en/index.rst | 2 +- docs/zh/advanced/delta-weight-sync.md | 111 +-- docs/zh/advanced/external-rollout-engines.md | 18 +- docs/zh/index.rst | 2 +- examples/README.md | 2 +- examples/delta_weight_sync/README.md | 83 +- .../run-glm4.7-30B-A3B-delta.sh | 109 +++ .../run-glm4.7-355B-A32B-delta.sh | 192 ---- requirements.txt | 3 + slime/backends/megatron_utils/actor.py | 24 +- slime/backends/megatron_utils/sglang.py | 12 - .../update_weight/update_weight_from_disk.py | 1 + .../update_weight_from_disk_delta.py | 291 ++++++ .../update_weight_from_distributed.py | 21 +- .../update_weight_from_distributed_delta.py | 864 ------------------ .../update_weight_from_tensor.py | 1 + slime/backends/sglang_utils/sglang_engine.py | 67 +- slime/ray/rollout.py | 3 +- slime/utils/arguments.py | 117 +-- slime/utils/disk_delta.py | 264 ++++++ 22 files changed, 944 insertions(+), 1436 deletions(-) create mode 100644 examples/delta_weight_sync/run-glm4.7-30B-A3B-delta.sh delete mode 100755 examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh create mode 100644 slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py delete mode 100644 slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py create mode 100644 slime/utils/disk_delta.py diff --git a/docs/en/advanced/delta-weight-sync.md b/docs/en/advanced/delta-weight-sync.md index a1e670f81d..d421597297 100644 --- 
a/docs/en/advanced/delta-weight-sync.md +++ b/docs/en/advanced/delta-weight-sync.md @@ -1,111 +1,86 @@ # Delta Weight Sync -- [Why](#why) -- [Quick Start](#quick-start) -- [Mode vs Transport](#mode-vs-transport) -- [How It Works](#how-it-works) -- [Encoding Choice](#encoding-choice) -- [Why Not Colocated](#why-not-colocated) +Delta weight sync keeps non-colocated rollout engines up to date by shipping only the bytes +that changed between two syncs, instead of a full checkpoint each time. It targets large-model +training/inference disaggregation across clusters or datacenters, where writing the whole actor +every sync is the dominant cost. -## Why +It is **disk-transport only** and reloads through the **ordinary** `update_weights_from_disk` +endpoint, so the inference engine needs no delta-specific support. -Slime's default sync broadcasts every parameter every step. The cost scales linearly with model size and dominates the sync phase, even though only a few percent of weights change between consecutive RL steps. Delta sync keeps a pinned-CPU snapshot of the last broadcast and ships only the positions whose bytes differ. - -The motivating use case is **training/inference disaggregation** — running the trainer and the rollout engines in *different datacenters* over a shared filesystem with bandwidth on the order of 100s of MB/s, where a full broadcast is infeasible but a sparse delta (~3% density, ~5 GB for a 355B model) is. The same delta machinery also runs over NCCL inside a single datacenter, where it serves as the validation baseline that proves the wire encoding and apply logic are correct. - -Prior art: selective overwrite is inspired by [arXiv:2509.19128](https://arxiv.org/abs/2509.19128); the cross-DC disaggregation motivation is from [Fireworks AI — Frontier RL Is Cheaper Than You Think](https://fireworks.ai/blog/frontier-rl-is-cheaper-than-you-think). 
Another public production-shaped reference is the [Composer 2 technical report by the Cursor Research Team](https://arxiv.org/html/2603.24477v2), which describes Cursor partnering with Fireworks AI for RL inference and syncing every training-step update through shared S3, delta compression, and cross-region inference-cluster reconstruction. - -## Quick Start - -Disk transport (training/inference disaggregation — the main use case): +## Configuration ```bash --update-weight-mode delta --update-weight-transport disk ---update-weight-encoding deltas_zstd # best for ≤ 300 MB/s shared FS --update-weight-disk-dir /shared/fs/delta-updates +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt +--update-weight-delta-encoding xor # or: overwrite +--update-weight-delta-checksum xxh3-128 # or: blake3, adler32 ``` -NCCL transport (intra-datacenter validation baseline): - -```bash ---update-weight-mode delta ---update-weight-transport nccl ---update-weight-encoding indices # lowest compute, no compression -``` - -Full-checkpoint disk transport (simple external-engine fallback): - -```bash ---update-weight-mode full ---update-weight-transport disk ---update-weight-disk-dir /shared/fs/full-updates -``` - -This writes a complete HF checkpoint under `weight_v{N:06d}/` for every sync, -then asks each SGLang engine to reload it with `update_weights_from_disk`. It is -useful when the trainer cannot form an NCCL group with pre-launched rollout -engines, but it is much heavier than delta sync for large models. - -Receiver-side delta tuning (applies to delta NCCL and delta disk): - -```bash ---sglang-update-weight-delta-chunk-bytes $((2 * 1024 * 1024 * 1024)) # byte cap per load_weights call ---sglang-update-weight-delta-read-workers 4 # parallel I/O threads (disk only) -``` - -See [examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh](../../../examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh) for a complete launcher. 
- -## Mode vs Transport - -`--update-weight-mode` decides **what** gets sent; `--update-weight-transport` -decides **how** it reaches SGLang. - -| mode | transport | behavior | -|---|---|---| -| `full` | `nccl` | default path: broadcast every HF weight chunk over a trainer-engine NCCL group | -| `full` | `disk` | write a complete HF checkpoint under `--update-weight-disk-dir`, then call `update_weights_from_disk` | -| `delta` | `nccl` | broadcast sparse changed positions + values over NCCL | -| `delta` | `disk` | write sparse safetensors under `--update-weight-disk-dir`, then call `update_weights_from_disk(load_format="delta")` | - -`--update-weight-delta-dir` is kept only as a backward-compatible alias for -`--update-weight-disk-dir`; new launchers should use the transport-level name. - -## How It Works - -Delta NCCL and delta disk share one sender pipeline, one wire layout, and one receiver-side decoder; only the per-flush carrier differs. - -**Sender (per sync, PP-source rank only):** - -1. **Diff** the current weights against the pinned-CPU snapshot via bytewise compare (`current.view(int_dtype) != snapshot.view(int_dtype)`) — lossless, dtype-agnostic, no arithmetic. -2. **Encode** changed (position, value) pairs into a packed `__positions__` byte blob + `__values__` tensor + per-param decoding manifest. The encoding (`indices`, `deltas`, `deltas_zstd`) governs only how positions are packed; values are sent verbatim in the param's dtype. -3. **Bucket** per-chunk encodes up to `--update-weight-buffer-size` bytes, then flush: - - NCCL: broadcast `(__positions__, __values__)` to the rollout engines with a `DeltaSpec` (encoding + per-param manifest) carried in the Ray RPC. - - Disk: write one safetensors file per flush under `weight_v{N:06d}/`. Async background thread does the I/O + optional zstd compression off the critical path. -4. **Snapshot the just-sent values** via a D2H copy on a side stream so it overlaps with the next chunk's encode. 
- -**End-of-sync (disk only):** write a `DONE` marker, then rank 0 fires one HTTP push per engine and removes the directory after every engine acknowledges. - -**Receiver:** - -For both transports, the receiver ends up calling the same `_apply_delta_payload(encoding, params, positions, values)` helper. It decodes each param's slice into a full-shape tensor with NaN at unchanged positions, then routes it through `model.load_weights(...)` under a `_delta_apply_context` that patches `Tensor.copy_` / `Tensor.fill_` to perform NaN-masked overwrite. Auxiliary writes (scratch buffers, fp8 scales, MoE biases via `post_load_weights`) keep their normal semantics. - -Selective overwrite has no arithmetic — the receiver writes the trainer's exact bytes at changed positions — so it's lossless by construction and there's no notion of drift to fight with periodic base re-syncs. - -## Encoding Choice - -`--update-weight-encoding` picks how positions are packed. All three share the same on-wire layout (`__positions__` uint8 blob + `__values__` tensor + per-param manifest); decoder dispatches on the metadata. - -| value | positions | when to pick | -|---|---|---| -| `indices` | int32 absolute positions (4 bytes / nnz) | NCCL or fast intra-cluster FS (≥ ~600 MB/s) | -| `deltas` | uint16 gap-deltas with uint32 fallback (~2 bytes / nnz at 2% density) | medium FS bandwidth (~300-500 MB/s) | -| `deltas_zstd` | `deltas` wrapped in zstd L1 on disk | cross-DC / cross-region shared FS (≤ ~300 MB/s) | - -**Why gap-encoded positions are smaller**: positions come out of `mask.nonzero()` already sorted ascending. At density `p`, the expected gap between consecutive nonzero positions is `1/p`, and `P(gap > 65535) ≈ exp(-p · 65535)`. At p = 2% that's effectively zero, so uint16 fits with a uint32 per-param fallback for pathological inputs. Half the position bytes of `indices`, lossless. 
- -**Break-even with `indices`** at our density (~2%): `deltas` halves the positions blob (which dominates the wire); `zstd` shaves another ~35-40% on top by compressing the gap byte stream, at the cost of ~250ms/file compress + ~150ms/file decompress. The crossover with `indices` is where compress/decompress compute exceeds the bandwidth savings — empirically around 500 MB/s for `deltas` and 300 MB/s for `deltas_zstd`. - -## Why Not Colocated - -Colocated weight sync uses CUDA IPC: only a memory handle (~64 B) crosses processes. Delta encoding's "bytes saved on the wire" benefit is zero, while the bookkeeping (snapshot + diff + sparse encode) is pure overhead. Slime rejects `--update-weight-mode delta --colocate` at argparse time. +| Flag | Role | +|---|---| +| `--update-weight-disk-dir` | Shared filesystem directory the trainer publishes deltas to and the rollout hosts read from. | +| `--update-weight-local-checkpoint-dir` | Host-local (e.g. NVMe) full HF checkpoint that the delta is applied into in place. Each host materializes it from `--hf-checkpoint` at engine start. | +| `--update-weight-delta-encoding` | On-disk delta encoding: `xor` (default) or `overwrite`. | +| `--update-weight-delta-checksum` | Per-tensor integrity checksum: `xxh3-128` (default), `blake3`, or `adler32`. | + +Deltas are always zstd-compressed (level 1); profiling showed it dominates lz4 / gzip / snappy / brotli on both wire size and decompress speed for this data, so it is not a knob. + +## How it works + +1. **Seed.** On the first sync the trainer captures a CPU snapshot of every parameter — seeded + from `--hf-checkpoint`, which is exactly what each rollout host materializes its local + checkpoint from. Nothing is published; this snapshot is the base the next sync diffs against. +2. 
**Publish.** On every later sync the trainer diffs each gathered HF tensor against the + snapshot, encodes and compresses the change, and writes a new version directory + `weight_v{N:06d}/` under `--update-weight-disk-dir`. The directory is a canonical HF + checkpoint — `model-NNNNN.safetensors` files holding the compressed diff tensors plus a + `model.safetensors.index.json` (tensor name → file) carrying the apply metadata — so the + artifact is portable, not tied to the trainer's parallelism layout. The snapshot is then + advanced to the new values for the next diff. +3. **Apply.** Each rollout host applies the new version's delta into its local checkpoint in + place. The apply is parallelized across tensors and verified per-tensor (see Integrity). +4. **Reload.** The engines reload the patched local checkpoint through the vanilla + `update_weights_from_disk` path — they never see the delta format. + +Because the snapshot is seeded from `--hf-checkpoint` (the engine's actual base) rather than +from the current GPU weights, the scheme is correct for any model even where the Megatron→HF +round-trip is not byte-exact (e.g. trimmed vocab-padding rows in the embedding / LM head). + +## Encodings + +Both encodings are byte-level and dtype-blind, so the same path works for quantized checkpoints. +The engine reads the choice from each version's index metadata. + +- **`xor`** (default): writes `new ^ old`. Smallest wire and fastest to apply (sequential, + cache-friendly; the unchanged bytes are zeros the compressor crushes). It is an involution, + so it must be applied **exactly once** against the correct base — applying it twice reverts. +- **`overwrite`**: writes the changed positions and their new absolute values. Larger on the + wire and a less cache-friendly scattered apply, but **idempotent**: re-applying it (or + finishing a partially-applied delta) converges to the same state regardless of how many times + it runs. 
Use it when re-applicability matters more than wire size. + +## Integrity + +The trainer stores a per-tensor checksum of each tensor's new state in the version. After +applying, every host recomputes the checksum and **raises on any mismatch**, so a corrupt delta +or a wrong base fails loudly instead of serving bad weights. The apply also refuses to run out of +order: a version only applies on top of its declared base version. + +`--update-weight-delta-checksum` selects the algorithm. The checksum is not the apply bottleneck +(the apply is bound by decompression + XOR), so this is a digest-property choice, not a speed one: +`xxh3-128` (default) is the widest fast non-cryptographic digest; `blake3` is cryptographic, for +untrusted storage; `adler32` is for interop with systems that expect it. + +## Shared-filesystem visibility hooks + +On a POSIX shared filesystem (NFS, Lustre, …) no extra step is needed. Object-store-backed +volumes that need an explicit commit/refresh to make writes visible across hosts can supply two +optional hooks, loaded by import path — no vendor-specific code lives in slime: + +- `--custom-delta-pre-push-path`: called after a version's files are written, before the engines + are told to read it (e.g. commit the volume). Signature: `hook(args, version_dir, rollout_engines)`. +- `--custom-delta-pre-read-path`: called on each rollout host before it reads the delta directory + (e.g. refresh the volume). Signature: `hook(delta_dir, target_version)`. diff --git a/docs/en/advanced/external-rollout-engines.md b/docs/en/advanced/external-rollout-engines.md index 498afa0c5f..6ff6f6da72 100644 --- a/docs/en/advanced/external-rollout-engines.md +++ b/docs/en/advanced/external-rollout-engines.md @@ -79,28 +79,16 @@ This keeps the full-checkpoint directories after engines acknowledge the load. ## Update With Delta -Delta update targets large-model training/inference disaggregation across clusters or datacenters. 
Instead of writing a full checkpoint, the trainer keeps a pinned-CPU snapshot of the previous sync, detects byte-level changes, and sends only changed positions and values. - -Recommended for cross-cluster / shared-filesystem deployments: +Delta update targets large-model training/inference disaggregation across clusters or datacenters. Instead of writing a full checkpoint every sync, the trainer keeps a CPU snapshot of the previous sync, diffs each parameter against it, and publishes only the changed bytes; every rollout host applies the delta into its local checkpoint and reloads via the vanilla `update_weights_from_disk` endpoint. ```bash --update-weight-mode delta --update-weight-transport disk ---update-weight-encoding deltas_zstd --update-weight-disk-dir /shared/fs/delta-updates +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt ``` -With disk transport, each sync writes sparse safetensors under `weight_v{N:06d}/`, then calls `update_weights_from_disk(load_format="delta")`. SGLang overwrites only changed positions in the current weights; unchanged positions stay in place. - -For intra-datacenter validation or bandwidth-rich environments, NCCL transport is also available: - -```bash ---update-weight-mode delta ---update-weight-transport nccl ---update-weight-encoding indices -``` - -For encoding choices, wire layout, receiver-side selective overwrite, and tuning parameters, see [Delta Weight Sync](delta-weight-sync.md). +See [Delta Weight Sync](delta-weight-sync.md) for the mechanism, encodings, integrity checks, and shared-filesystem visibility hooks. 
## Deployment Checklist diff --git a/docs/en/index.rst b/docs/en/index.rst index 4617d5beaf..91f18ab4c6 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -38,9 +38,9 @@ Start by Use Case - Build agentic RL workflows: :doc:`get_started/agent` - Configure production SGLang rollout topology: :doc:`advanced/sglang-config` - Connect external rollout engines: :doc:`advanced/external-rollout-engines` +- Sync weights as byte-level deltas: :doc:`advanced/delta-weight-sync` - Use PD disaggregation: :doc:`advanced/pd-disaggregation` - Use BF16 training with FP8 rollout or FP8 KV cache: :doc:`advanced/low-precision` -- Use delta weight sync: :doc:`advanced/delta-weight-sync` - Understand CI and reliability coverage: :doc:`developer_guide/ci` - Debug, trace, and profile long-running jobs: :doc:`developer_guide/debug`, :doc:`developer_guide/trace`, :doc:`developer_guide/profiling` diff --git a/docs/zh/advanced/delta-weight-sync.md b/docs/zh/advanced/delta-weight-sync.md index f009dc954a..0aa5434472 100644 --- a/docs/zh/advanced/delta-weight-sync.md +++ b/docs/zh/advanced/delta-weight-sync.md @@ -1,107 +1,54 @@ # Delta 权重同步 -- [背景](#背景) -- [快速开始](#快速开始) -- [同步模式与传输方式](#同步模式与传输方式) -- [工作原理](#工作原理) -- [编码选择](#编码选择) -- [为何不支持 colocated](#为何不支持-colocated) +Delta 权重同步只发送两次同步之间发生变化的字节,而不是每次都写一份完整 checkpoint,以此让非 colocate 的 rollout engine 保持最新。它面向大模型、跨集群或跨数据中心的训推解耦场景——这种场景下每次都写整份 actor 权重是主要开销。 -## 背景 +它**只支持 disk transport**,并且通过**原生**的 `update_weights_from_disk` 端点 reload,因此推理引擎不需要任何 delta 相关的支持。 -slime 默认的权重同步会在每一步广播全部参数,开销随模型规模线性增长,即使每步真正变化的权重只有几个百分点。Delta 同步在内存中保留上一次同步后的参数快照(pinned CPU),只发送字节发生变化的位置。 - -最主要的应用场景是 **训练 / 推理跨数据中心解耦** —— 训练器和推理引擎运行在不同数据中心,通过共享文件系统通信(带宽通常在百 MB/s 级别)。在这种环境下,全量广播不可行,而 ~3% 密度的稀疏 delta(355B 模型约 5 GB)是可行的。同一套 delta 机制在数据中心内部跑 NCCL,作为验证基线,确认 wire 编码和 apply 逻辑正确。 - -参考资料:选择性覆写借鉴自 [arXiv:2509.19128](https://arxiv.org/abs/2509.19128),跨数据中心的动机来自 [Fireworks AI — Frontier RL Is Cheaper Than You 
Think](https://fireworks.ai/blog/frontier-rl-is-cheaper-than-you-think)。另一个接近生产形态的公开参考是 [Cursor Research Team 的 Composer 2 技术报告](https://arxiv.org/html/2603.24477v2):其中描述了 Cursor 与 Fireworks AI 合作运行 RL inference,并通过共享 S3、delta compression 和跨区域 inference 集群重建来同步每步训练权重。 - -## 快速开始 - -磁盘传输(跨数据中心训推解耦,主要场景): +## 配置 ```bash --update-weight-mode delta --update-weight-transport disk ---update-weight-encoding deltas_zstd # ≤ 300 MB/s 共享 FS 推荐 --update-weight-disk-dir /shared/fs/delta-updates +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt +--update-weight-delta-encoding xor # 或: overwrite +--update-weight-delta-checksum xxh3-128 # 或: blake3, adler32 ``` -NCCL 传输(数据中心内部验证基线): - -```bash ---update-weight-mode delta ---update-weight-transport nccl ---update-weight-encoding indices # 计算最少,无压缩 -``` - -全量 checkpoint 磁盘传输(外部引擎的简单兜底路径): - -```bash ---update-weight-mode full ---update-weight-transport disk ---update-weight-disk-dir /shared/fs/full-updates -``` - -这会在每次同步时写一个完整 HF checkpoint 到 `weight_v{N:06d}/`,然后让每个 -SGLang engine 通过 `update_weights_from_disk` 重新加载。它适用于训练器无法和预启动 -rollout engine 建 NCCL group 的场景,但对大模型来说比 delta 同步重很多。 - -接收端 delta 调优(适用于 delta NCCL 和 delta 磁盘): - -```bash ---sglang-update-weight-delta-chunk-bytes $((2 * 1024 * 1024 * 1024)) # 每次 load_weights 字节上限 ---sglang-update-weight-delta-read-workers 4 # 并行 I/O 线程数(仅磁盘传输) -``` - -完整启动脚本见 [examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh](../../../examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh)。 - -## 同步模式与传输方式 - -`--update-weight-mode` 决定**发送什么**,`--update-weight-transport` 决定**如何送到 SGLang**。 +| 参数 | 作用 | +|---|---| +| `--update-weight-disk-dir` | 训练端发布 delta、rollout host 读取 delta 的共享文件系统目录。 | +| `--update-weight-local-checkpoint-dir` | host 本地(如 NVMe)的完整 HF checkpoint,delta 原地 apply 到这里。每个 host 在 engine 启动时由 `--hf-checkpoint` 物化。 | +| `--update-weight-delta-encoding` | 磁盘上的 delta 编码:`xor`(默认)或 `overwrite`。 | +| `--update-weight-delta-checksum` | 逐 tensor 完整性 
checksum:`xxh3-128`(默认)、`blake3` 或 `adler32`。 | -| 同步模式 (`mode`) | 传输方式 (`transport`) | 行为 | -|---|---|---| -| `full` | `nccl` | 默认路径:通过训练器和 engine 之间的 NCCL group 广播所有 HF 权重 chunk | -| `full` | `disk` | 在 `--update-weight-disk-dir` 下写完整 HF checkpoint,然后调用 `update_weights_from_disk` | -| `delta` | `nccl` | 通过 NCCL 广播稀疏变化位置和值 | -| `delta` | `disk` | 在 `--update-weight-disk-dir` 下写稀疏 safetensors,然后调用 `update_weights_from_disk(load_format="delta")` | - -`--update-weight-delta-dir` 只保留为 `--update-weight-disk-dir` 的向后兼容 alias; -新启动脚本应该使用传输方式级别的目录参数。 +delta 始终用 zstd(level 1)压缩;profiling 显示对这类数据它在 wire 大小和解压速度上都优于 lz4 / gzip / snappy / brotli,所以不做成可配置项。 ## 工作原理 -Delta NCCL 和 delta 磁盘共用同一条发送管线、同一种 wire 布局以及同一套接收端解码器;只有每个 bucket 的承载层不同。 - -**发送端(每次同步,仅 PP 源 rank):** - -1. **求差**:通过逐字节比较 `current.view(int_dtype) != snapshot.view(int_dtype)` 检测变化。无算术、无损、与 dtype 无关。 -2. **编码**:将变化的 (位置, 值) 对打包成 `__positions__` 字节块 + `__values__` 张量 + per-param 解码 manifest。编码方式(`indices` / `deltas` / `deltas_zstd`)只影响位置如何打包,值始终按参数本身的 dtype 原样发送。 -3. **打包并发送**:每个 chunk 编码后累积至 `--update-weight-buffer-size` 字节再 flush: - - NCCL:广播 `(__positions__, __values__)`,Ray RPC 同时携带 `DeltaSpec`(编码 + per-param manifest)。 - - 磁盘:每个 flush 写一个 safetensors 文件到 `weight_v{N:06d}/` 目录,后台线程负责 I/O 和可选的 zstd 压缩,不阻塞关键路径。 -4. **更新快照**:刚发送的值在 side stream 上 D2H 拷贝,与下一个 chunk 的编码重叠。 +1. **Seed。** 第一次同步时,训练端为每个参数捕获一份 CPU snapshot——从 `--hf-checkpoint` seed,而这正是每个 rollout host 物化本地 checkpoint 的来源。此次不发布任何东西;这份 snapshot 就是下一次同步 diff 的基准。 +2. **Publish。** 之后每次同步,训练端把每个 gather 出的 HF tensor 与 snapshot 做 diff,编码、压缩,写到 `--update-weight-disk-dir` 下的新版本目录 `weight_v{N:06d}/`。该目录是一份 canonical HF checkpoint——`model-NNNNN.safetensors` 文件装着压缩后的 diff tensor,外加 `model.safetensors.index.json`(tensor 名 → 文件)承载 apply 元数据——所以这个产物是可移植的,不绑定训练端的并行 layout。随后 snapshot 推进到新值,供下次 diff。 +3. **Apply。** 每个 rollout host 把新版本的 delta 原地 apply 进它的本地 checkpoint。apply 在 tensor 之间并行,并逐 tensor 校验(见“完整性”)。 +4. 
**Reload。** engine 通过原生 `update_weights_from_disk` 路径 reload 打过补丁的本地 checkpoint——它从不接触 delta 格式。 -**同步结束(仅磁盘):** 写 `DONE` 标记,rank 0 对每个引擎触发一次 HTTP push,所有引擎确认后清理目录。 +由于 snapshot 是从 `--hf-checkpoint`(engine 真正的 base)seed,而不是从当前 GPU 权重 seed,即使 Megatron→HF 往返不是逐字节相等(例如 embedding / LM head 中被裁掉的 vocab padding 行),该方案对任意模型也都正确。 -**接收端:** 两种传输最终都进入同一个 `_apply_delta_payload(encoding, params, positions, values)` 帮助函数。它把每个参数的切片解码成全形状张量,未变化位置填 NaN,然后通过 `model.load_weights(...)` 应用;过程中 `_delta_apply_context` 替换 `Tensor.copy_` / `Tensor.fill_`,对参数存储执行 NaN 掩码覆写。辅助写入(scratch buffer、fp8 scale、MoE bias 等通过 `post_load_weights` 写入的派生张量)保留正常语义。 +## 编码 -选择性覆写没有任何算术运算 —— 接收端在变化位置直接写入训练端的精确字节 —— 因此天然无损,也不存在数值漂移问题,无需周期性 base 同步。 +两种编码都是字节级、与 dtype 无关的,所以量化 checkpoint 也走同一条路径。engine 从每个版本的 index 元数据读取所用编码。 -## 编码选择 +- **`xor`**(默认):写 `new ^ old`。wire 最小、apply 最快(顺序访问、对 cache 友好;未变化的字节是 0,被压缩器压到极小)。它是一个对合(involution),所以必须**恰好对正确的 base apply 一次**——apply 两次会还原。 +- **`overwrite`**:写变化的位置及其新的绝对值。wire 更大、apply 是对 cache 不友好的分散写,但**幂等**:重复 apply(或把部分 apply 的 delta 补完)无论执行多少次都收敛到同一状态。当“可重复 apply”比 wire 大小更重要时用它。 -`--update-weight-encoding` 决定位置如何打包。三种编码共用同一种 wire 布局(`__positions__` uint8 块 + `__values__` 张量 + per-param manifest),解码端根据 metadata 分派。 +## 完整性 -| 取值 | 位置编码 | 推荐场景 | -|---|---|---| -| `indices` | int32 绝对位置(4 字节 / nnz) | NCCL 或高速集群内 FS(≥ ~600 MB/s) | -| `deltas` | uint16 增量(异常时 uint32 兜底,2% 密度下约 2 字节 / nnz) | 中等带宽 FS(~300-500 MB/s) | -| `deltas_zstd` | `deltas` 文件再用 zstd L1 压缩 | 跨数据中心 / 跨区共享 FS(≤ ~300 MB/s) | +训练端把每个 tensor 新状态的逐 tensor checksum 存进版本里。apply 之后每个 host 重新计算 checksum,**任何不匹配都会 raise**,所以损坏的 delta 或错误的 base 会直接报错失败,而不会把坏权重提供出去。apply 还拒绝乱序执行:一个版本只会在它声明的 base 版本之上 apply。 -**为何 gap 编码更省**:`mask.nonzero()` 返回的位置已经升序排列。密度 `p` 时连续非零位置的期望间隔为 `1/p`,且 `P(gap > 65535) ≈ exp(-p · 65535)`,p = 2% 时这个概率实际上为零,所以 uint16 完全够用,uint32 仅作 per-param 兜底。位置开销比 `indices` 减半,且无损。 +`--update-weight-delta-checksum` 选择算法。checksum 不是 apply 的瓶颈(apply 受解压 + XOR 限制),所以这是一个 digest 
属性的选择,而非速度选择:`xxh3-128`(默认)是最宽的快速非加密 digest;`blake3` 是加密 digest,用于不可信存储;`adler32` 用于与期望它的系统互操作。 -**`deltas_zstd` 的额外收益**:在 gap 字节流上做 zstd L1 还能再减少 ~35-40%,代价是每文件约 250ms 压缩 + 150ms 解压。当共享 FS 带宽 ≤ 300 MB/s 时,带宽节省超过额外计算开销。 +## 共享文件系统可见性 hook -## 为何不支持 colocated +在 POSIX 共享文件系统(NFS、Lustre……)上不需要额外步骤。对于需要显式 commit/refresh 才能让写入跨 host 可见的对象存储卷,可以提供两个可选 hook(通过 import 路径加载——slime 里不存在任何厂商特定代码): -Colocated 同步通过 CUDA IPC:进程间传递的只是一个内存句柄(~64 B)。Delta 编码的"wire 节省"在此为零,而其簿记开销(快照 + 求差 + 稀疏编码)反而是纯损失。slime 在参数校验阶段拒绝 `--update-weight-mode delta --colocate`。 +- `--custom-delta-pre-push-path`:在一个版本的文件写完之后、通知 engine 读取之前调用(例如 commit volume)。签名:`hook(args, version_dir, rollout_engines)`。 +- `--custom-delta-pre-read-path`:在每个 rollout host 读取 delta 目录之前调用(例如 refresh volume)。签名:`hook(delta_dir, target_version)`。 diff --git a/docs/zh/advanced/external-rollout-engines.md b/docs/zh/advanced/external-rollout-engines.md index 9aae0ef5ec..0007bcff69 100644 --- a/docs/zh/advanced/external-rollout-engines.md +++ b/docs/zh/advanced/external-rollout-engines.md @@ -79,28 +79,16 @@ full checkpoint update from disk 是 external 场景最简单的兜底路径: ## Update With Delta -delta update 面向大模型、跨集群或跨数据中心训推解耦。它不写完整 checkpoint,而是在训练端保留上一次同步后的 pinned CPU snapshot,逐字节检测变化,只发送变化位置和值。 - -跨集群 / 共享文件系统推荐: +delta update 面向大模型、跨集群或跨数据中心训推解耦。它不每次都写完整 checkpoint,而是在训练端保留上一次同步的 CPU snapshot,逐参数比对,只发布变化的字节;每个 rollout host 把 delta apply 进自己的本地 checkpoint,再通过原生 `update_weights_from_disk` 端点 reload。 ```bash --update-weight-mode delta --update-weight-transport disk ---update-weight-encoding deltas_zstd --update-weight-disk-dir /shared/fs/delta-updates +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt ``` -在 disk transport 下,每次同步会写一组稀疏 safetensors 到 `weight_v{N:06d}/`,然后调用 `update_weights_from_disk(load_format="delta")`。SGLang 侧只把变化位置覆写到当前权重上,不变位置保持原值。 - -在同一数据中心内做实现验证或带宽不紧张时,也可以用 NCCL transport: - -```bash ---update-weight-mode delta ---update-weight-transport nccl ---update-weight-encoding indices -``` - -编码如何选择、delta 
wire layout、接收端 selective overwrite 以及调优参数见 [Delta 权重同步](delta-weight-sync.md)。 +机制、编码、完整性校验以及共享文件系统可见性 hook 详见 [Delta 权重同步](delta-weight-sync.md)。 ## 部署检查清单 diff --git a/docs/zh/index.rst b/docs/zh/index.rst index edcd8be376..b55f1ed88e 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -38,9 +38,9 @@ slime 的设计目标,是让这两大能力彼此强化,同时避免把系 - 构建 agentic RL workflow::doc:`get_started/agent` - 配置生产级 SGLang rollout topology::doc:`advanced/sglang-config` - 接入 external rollout engines::doc:`advanced/external-rollout-engines` +- 以字节级 delta 同步权重::doc:`advanced/delta-weight-sync` - 使用 PD disaggregation::doc:`advanced/pd-disaggregation` - 使用 BF16 训练 + FP8 rollout 或 FP8 KV cache::doc:`advanced/low-precision` -- 使用 delta weight sync::doc:`advanced/delta-weight-sync` - 了解 CI 和可靠性覆盖::doc:`developer_guide/ci` - 调试、trace 和 profiling 长时间任务::doc:`developer_guide/debug`、:doc:`developer_guide/trace`、:doc:`developer_guide/profiling` diff --git a/examples/README.md b/examples/README.md index 128b1562d4..4618f6414c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -11,7 +11,7 @@ These examples provide concrete examples to leverage slime in your own RL workfl - **[low_precision](./low_precision)**: Examples of FP8 training and inference for improved throughput and stability. - **[multi_agent](./multi_agent)**: Example of running multi-agent RL with `slime`. - **[on_policy_distillation](./on_policy_distillation)**: Example implementation for on-policy distillation, extending the reinforcement learning pipeline to support teacher–student distillation directly within on-policy training. -- **[delta_weight_sync](./delta_weight_sync)**: Non-colocated weight sync that ships only changed positions + values over disk (training/inference disaggregation) or NCCL. +- **[delta_weight_sync](./delta_weight_sync)**: Non-colocated weight sync that ships only the changed bytes over a shared filesystem (training/inference disaggregation), reloading via the vanilla `update_weights_from_disk` path. 
- **[reproducibility](./reproducibility)**: Guides on achieving bitwise experiment reproduction using deterministic modes. - **[retool](./retool)**: Demonstrates the retool functionality for tool-enabled language model generation. - **[search-r1](./search-r1)**: A minimal reproduction of Search-R1, featuring multi-turn conversation and tool-calling. diff --git a/examples/delta_weight_sync/README.md b/examples/delta_weight_sync/README.md index b2c7521578..0879ba9fcb 100644 --- a/examples/delta_weight_sync/README.md +++ b/examples/delta_weight_sync/README.md @@ -1,67 +1,40 @@ # Delta Weight Sync -Non-colocated weight sync that ships only changed positions + values instead of every parameter. Two transports over one wire format and one receiver-side decoder: +Non-colocated weight sync that ships only the **changed bytes** between two syncs instead of a +full checkpoint, for training/inference disaggregation across clusters or datacenters. The +trainer publishes per-tensor deltas to a shared filesystem as a canonical HF checkpoint +directory; each rollout host applies them into a host-local checkpoint and the engines reload +through the ordinary `update_weights_from_disk` path — the inference engine needs no +delta-specific support. -- **Disk** (the point) — write per-flush safetensors to a shared filesystem; one HTTP push per sync. Designed for **training/inference disaggregation** across datacenters where bandwidth between trainer and rollout is on the order of 100s of MB/s. -- **NCCL** (the baseline) — broadcast each per-flush bucket directly. Used intra-datacenter to validate that the wire encoding and apply logic are correct, separate from any shared-FS variable. +See [Delta Weight Sync](../../docs/en/advanced/delta-weight-sync.md) for the full mechanism, +encodings, integrity checks, and shared-filesystem visibility hooks. -Both modes are lossless by construction (selective overwrite via NaN sentinel; no arithmetic). 
+## Try it -## Files +`run-glm4.7-30B-A3B-delta.sh` runs the disk delta path on GLM-4.7-Flash, non-colocated across a +2-node (16-GPU) Ray cluster. See its header for prerequisites. -- `run-glm4.7-355B-A32B-delta.sh`: 16-node (8 actor + 8 rollout) GLM-4.7-355B-A32B launcher. Disk transport active by default; NCCL block commented below it. +## Minimal flags -## Usage +Add to a non-colocated training run (the trainer and engines only need to share the filesystem +at `--update-weight-disk-dir`): ```bash -bash examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh +--update-weight-mode delta \ +--update-weight-transport disk \ +--update-weight-disk-dir /shared/fs/delta-updates \ +--update-weight-local-checkpoint-dir /local/nvme/rollout-ckpt \ +--update-weight-delta-encoding xor \ +--update-weight-delta-checksum xxh3-128 ``` -**Disk (default):** +- `--update-weight-disk-dir` — shared directory the trainer writes deltas to and the hosts read. +- `--update-weight-local-checkpoint-dir` — host-local full HF checkpoint the delta patches in + place; materialized from `--hf-checkpoint` at engine start. +- `--update-weight-delta-encoding` — `xor` (smallest/fastest) or `overwrite` (idempotent). +- `--update-weight-delta-checksum` — `xxh3-128` (default), `blake3`, or `adler32`. -```bash -DELTA_ARGS=( - --update-weight-mode delta - --update-weight-transport disk - --update-weight-encoding deltas_zstd - --update-weight-disk-dir /shared/fs/delta-updates -) -``` - -**NCCL (baseline):** - -```bash -DELTA_ARGS=( - --update-weight-mode delta - --update-weight-transport nccl - --update-weight-encoding indices -) -``` - -Receiver-side byte cap (both transports): - -```bash ---sglang-update-weight-delta-chunk-bytes $((2 * 1024 * 1024 * 1024)) -``` - -See [docs/en/advanced/delta-weight-sync.md](../../docs/en/advanced/delta-weight-sync.md) for the wire protocol, encoding choice, and design. 
- -## Results - -W&B traces comparing delta sync against the full-sync baseline on GLM-4.7-355B-A32B / DAPO-Math-17k. - -![Raw reward](./raw_reward.png) - -![Train/rollout logprob abs diff](./train_rollout_logprob_abs_diff.png) - -![Update weights time](./update_weights_time.png) - -> **Note on the small curve-to-curve gap.** RL training is inherently non-deterministic (cuBLAS reductions, FlashAttention split-K, NCCL all-reduce ordering, dynamic-batch token assignment). Two identically-configured *full*-sync runs would diverge the same way. Delta sync's selective overwrite is bit-exact with full sync per step (no arithmetic, no drift); the trajectory matches, the bits don't. - -![Update weights density](./update_weights_density.png) - -*Per-sync change density (`perf/update_weights_density`) — fraction of weight positions that moved between consecutive syncs. Sync 0 is omitted: it's the snapshot-seeding pass with density = 1.0, which would compress the y-axis.* - -## Why these encoding defaults - -Per-sync change density during RL fine-tuning at conservative LRs sits around **2-3%** ([arXiv:2602.03839](https://arxiv.org/pdf/2602.03839) reports ~1% on a related setup; we measured ~2-3% on this run). Below the 3.125% break-even point, gap-encoded positions are smaller than absolute indices — the disk default `deltas_zstd` adds zstd L1 on top to squeeze the gap byte stream further (~35-40%), which is the right tradeoff when shared-FS bandwidth is ≤ 300 MB/s. Intra-datacenter NCCL has no bandwidth pressure, so `indices` (lowest compute, biggest payload) is the cleaner default there. +For object-store-backed volumes that need an explicit commit/refresh to make writes visible +across hosts, supply `--custom-delta-pre-push-path` / `--custom-delta-pre-read-path` (no +vendor-specific code lives in slime; see the doc). 
diff --git a/examples/delta_weight_sync/run-glm4.7-30B-A3B-delta.sh b/examples/delta_weight_sync/run-glm4.7-30B-A3B-delta.sh new file mode 100644 index 0000000000..a399b20bbe --- /dev/null +++ b/examples/delta_weight_sync/run-glm4.7-30B-A3B-delta.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# Disk delta weight-sync demo on GLM-4.7-Flash (30B-A3B), non-colocated, 2 nodes x 8 GPU. +# The trainer publishes per-tensor deltas to --update-weight-disk-dir as a canonical HF directory; +# each rollout host applies them into --update-weight-local-checkpoint-dir and reloads via the +# vanilla update_weights_from_disk path. +# +# Prerequisites: +# - A 2-node (16-GPU) Ray cluster, this script run on the head node. +# - GLM-4.7-Flash HF checkpoint + its torch_dist conversion (tools/convert_hf_to_torch_dist.py). +# - dapo-math-17k.jsonl. +# - --update-weight-disk-dir on a filesystem both nodes share. On an object-store-backed volume +# that needs an explicit commit/refresh to surface writes across hosts, also pass +# --custom-delta-pre-push-path / --custom-delta-pre-read-path (see the doc). + +set -ex +export PYTHONUNBUFFERED=1 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/glm4.7-30B-A3B.sh" + +MODEL_DIR=${MODEL_DIR:-/root/models/GLM-4.7-Flash} +DATA_PATH=${DATA_PATH:-/root/datasets/dapo-math-17k/dapo-math-17k.jsonl} + +CKPT_ARGS=( + --hf-checkpoint "${MODEL_DIR}" + --ref-load "${MODEL_DIR}_torch_dist" +) + +ROLLOUT_ARGS=( + --prompt-data "${DATA_PATH}" + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3 + --rollout-batch-size 32 + --n-samples-per-prompt 4 + --rollout-max-response-len 8192 + --global-batch-size 128 +) + +# Disk delta weight sync (the point of this example). 
+WEIGHT_SYNC_ARGS=( + --update-weight-mode delta + --update-weight-transport disk + --update-weight-disk-dir /shared/fs/glm47-delta-updates + --update-weight-local-checkpoint-dir /local/nvme/glm47-rollout-ckpt + --update-weight-delta-encoding xor + --update-weight-delta-checksum xxh3-128 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --pipeline-model-parallel-size 2 + --context-parallel-size 2 + --expert-model-parallel-size 8 + --expert-tensor-parallel-size 1 + --sequence-parallel + --use-dynamic-batch-size + --max-tokens-per-gpu 32768 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.0 + --kl-loss-type low_var_kl +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 8 + --sglang-mem-fraction-static 0.8 + --sglang-enable-dp-attention + --sglang-dp-size 8 +) + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" + } +}" + +# Non-colocated: 16 actor GPUs (2 x 8) train while a 16-GPU rollout pool generates (delta mode +# requires non-colocation). +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 2 \ + --actor-num-gpus-per-node 8 \ + --rollout-num-gpus 16 \ + ${MODEL_ARGS[@]} \ + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${WEIGHT_SYNC_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" diff --git a/examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh b/examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh deleted file mode 100755 index 9df77c4eff..0000000000 --- a/examples/delta_weight_sync/run-glm4.7-355B-A32B-delta.sh +++ /dev/null @@ -1,192 +0,0 @@ -#!/bin/bash - -# Non-colocated GLM-4.7-355B-A32B with delta weight sync. 
-# 8 actor nodes (TP=8, PP=4, EP=16) + 64 rollout GPUs (8 H100 nodes worth), 16 nodes total. -# Disk transport is active by default; the NCCL block below it is commented out. - -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -ex - -export PYTHONUNBUFFERED=1 -unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -source "/root/slime/scripts/models/glm4.5-355B-A32B.sh" - -CKPT_ARGS=( - --hf-checkpoint /root/GLM-4.7-355B-A32B - --ref-load /root/GLM-4.7-355B-A32B_torch_dist/ -) - -ROLLOUT_ARGS=( - --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 64 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 1 - - --num-steps-per-rollout 4 - --balance-data - --rollout-stop-token-ids 151329 151336 151338 -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 8 - --eval-max-response-len 8192 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 8 - --sequence-parallel - --pipeline-model-parallel-size 4 - --context-parallel-size 2 - --expert-model-parallel-size 16 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 16384 -) - -GRPO_ARGS=( - --advantage-estimator gspo - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --kl-coef 0.00 - --entropy-coef 0.00 - --eps-clip 1e-4 - --eps-clip-high 2e-4 - --use-tis -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - 
--weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 - - --optimizer-cpu-offload - --overlap-cpu-optimizer-d2h-h2d - --use-precision-aware-optimizer -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project slime-delta - # --wandb-group glm4.7-355B-delta -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 32 - --sglang-mem-fraction-static 0.7 - --sglang-enable-dp-attention - --sglang-dp-size 4 - --sglang-ep-size 32 - --sglang-enable-dp-lm-head - --sglang-moe-dense-tp-size 1 - - # Receiver batches up to this many bytes per model.load_weights call. Bigger - # amortizes per-call cost (name resolution, MoE expert remap) but raises peak HBM. - --sglang-update-weight-delta-chunk-bytes $((2 * 1024 * 1024 * 1024)) - - # Max parallel I/O threads for reading delta files from disk (disk transport only). - --sglang-update-weight-delta-read-workers 4 - - # mtp - --sglang-speculative-algorithm EAGLE - --sglang-speculative-num-steps 3 - --sglang-speculative-eagle-topk 1 - --sglang-speculative-num-draft-tokens 4 -) - -# Delta weight sync. Pick one of the two blocks below. - -# ── Disk (default) — for training/inference disaggregation across datacenters ──── -# `deltas_zstd` is the right pick when shared-FS bandwidth is ≤ ~300 MB/s. 
-DELTA_ARGS=( - --update-weight-mode delta - --update-weight-transport disk - --update-weight-encoding deltas_zstd - --update-weight-disk-dir /shared/fs/delta-updates -) - -# ── NCCL (baseline) — intra-datacenter, no shared FS ──────────────────────────── -# DELTA_ARGS=( -# --update-weight-mode delta -# --update-weight-transport nccl -# --update-weight-encoding indices -# ) - -MISC_ARGS=( - --attention-dropout 0.0 - --hidden-dropout 0.0 - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - --attention-backend flash - --moe-token-dispatcher-type flex - --moe-enable-deepep - --update-weight-buffer-size $((2 * 1024 * 1024 * 1024)) -) - -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -RUNTIME_ENV_JSON=$(cat <=0.2.3 tensorboard transformers wandb +xxhash # disk delta weight sync (checksum + codec) +zstandard diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 9830cc19cf..5f0ec234f9 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -143,12 +143,14 @@ def init( ), "--update-weight-mode=delta is not supported with --colocate" update_weight_cls = UpdateWeightFromTensor elif self.args.update_weight_mode == "delta": - # Lazy import: the delta module pulls DeltaEncoding/DeltaParam/DeltaSpec from - # sglang, which only exist on newer images. Importing eagerly would break old - # images even when delta mode is unused. - from .update_weight.update_weight_from_distributed_delta import UpdateWeightFromDistributedDelta + # Delta sync is disk-transport only: each host applies the published deltas into + # its local checkpoint and the engines reload via vanilla update_weights_from_disk. 
+ assert ( + self.args.update_weight_transport == "disk" + ), "--update-weight-mode=delta requires --update-weight-transport=disk" + from .update_weight.update_weight_from_disk_delta import UpdateWeightFromDiskDelta - update_weight_cls = UpdateWeightFromDistributedDelta + update_weight_cls = UpdateWeightFromDiskDelta else: assert self.args.update_weight_mode == "full" if self.args.update_weight_transport == "disk": @@ -612,9 +614,14 @@ def update_weights(self) -> None: ray.get(self.rollout_manager.recover_updatable_engines.remote()) dist.barrier(group=get_gloo_group()) - rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get( - self.rollout_manager.get_updatable_engines_and_lock.remote() - ) + ( + rollout_engines, + rollout_engine_lock, + num_new_engines, + engine_gpu_counts, + engine_gpu_offsets, + all_engine_actors, + ) = ray.get(self.rollout_manager.get_updatable_engines_and_lock.remote()) reconnect_rollout_engines = self.args.offload_train and self.args.use_critic and not self.args.colocate @@ -634,6 +641,7 @@ def update_weights(self) -> None: rollout_engine_lock, engine_gpu_counts=engine_gpu_counts, engine_gpu_offsets=engine_gpu_offsets, + all_engine_actors=all_engine_actors, ) dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: diff --git a/slime/backends/megatron_utils/sglang.py b/slime/backends/megatron_utils/sglang.py index 801217310d..97c82a31cd 100644 --- a/slime/backends/megatron_utils/sglang.py +++ b/slime/backends/megatron_utils/sglang.py @@ -13,15 +13,6 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions -try: - from sglang.srt.managers.io_struct import DeltaEncoding, DeltaParam, DeltaSpec -except ImportError: - # Older sglang images don't have delta-sync io_struct. Only --update-weight-mode=delta - # needs these; the default full-sync path runs without them. 
- DeltaEncoding = None - DeltaParam = None - DeltaSpec = None - from sglang.srt.utils import MultiprocessingSerializer @@ -37,7 +28,4 @@ "monkey_patch_torch_reductions", "MultiprocessingSerializer", "FlattenedTensorBucket", - "DeltaEncoding", - "DeltaParam", - "DeltaSpec", ] diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_disk.py b/slime/backends/megatron_utils/update_weight/update_weight_from_disk.py index bb0e0df72a..a5f81d9263 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_disk.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_disk.py @@ -46,6 +46,7 @@ def connect_rollout_engines( rollout_engine_lock: ActorHandle, engine_gpu_counts: Sequence[int] | None = None, engine_gpu_offsets: Sequence[int] | None = None, + all_engine_actors: Sequence[ActorHandle] | None = None, ) -> None: self.rollout_engines = rollout_engines self.rollout_engine_lock = rollout_engine_lock diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py new file mode 100644 index 0000000000..0a56920724 --- /dev/null +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import json +import logging +import os +import queue +import shutil +from argparse import Namespace +from collections import deque +from collections.abc import Callable, Mapping, Sequence +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import ray +import safetensors.numpy +import torch +import torch.distributed as dist +import zstandard +from megatron.core import mpu +from ray.actor import ActorHandle + +from slime.utils.disk_delta import NUM_WORKERS, checksum, make_tensor_reader, overwrite_encode +from slime.utils.distributed_utils import get_gloo_group + +from .update_weight_from_distributed import UpdateWeightFromDistributed + 
+logger = logging.getLogger(__name__) + + +class UpdateWeightFromDiskDelta(UpdateWeightFromDistributed): + """ + Delta weight sync over a shared filesystem. PP-src ranks diff each gathered HF tensor against + a CPU snapshot of the previous sync and publish the changes as a canonical HF checkpoint dir; + every rollout host applies the delta into its local checkpoint and reloads via the ordinary + update_weights_from_disk path, so sglang needs no delta support. + """ + + def __init__( + self, + args: Namespace, + model: Sequence[torch.nn.Module], + weights_getter: Callable[[], Mapping[str, torch.Tensor]], + *, + model_name: str, + quantization_config: dict[str, int | str | list[str]] | None, + ) -> None: + super().__init__(args, model, weights_getter, model_name=model_name, quantization_config=quantization_config) + self.delta_dir = args.update_weight_disk_dir + os.makedirs(self.delta_dir, exist_ok=True) + self.delta_encoding = args.update_weight_delta_encoding + self.checksum_algorithm = args.update_weight_delta_checksum + self._snapshot: dict[str, np.ndarray] = {} + self._baseline_captured = False + self._commit_hook: Callable | None = None + if args.custom_delta_pre_push_path: + from slime.utils.misc import load_function + + self._commit_hook = load_function(args.custom_delta_pre_push_path) + + def connect_rollout_engines( + self, + rollout_engines: Sequence[ActorHandle], + rollout_engine_lock: ActorHandle, + engine_gpu_counts: Sequence[int] | None = None, + engine_gpu_offsets: Sequence[int] | None = None, + all_engine_actors: Sequence[ActorHandle] | None = None, + ) -> None: + # The local checkpoint is host-local, so every host applies its own copy: + # all_engine_actors is one actor per host, vs rollout_engines (node 0 only). The + # rollout_engine_lock the NCCL path uses isn't needed — a per-host flock serializes applies. 
+ self.rollout_engines = rollout_engines + self.all_engine_actors = list(all_engine_actors or rollout_engines) + self._is_pp_src_rank = ( + mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0 + ) + + def disconnect_rollout_engines(self) -> None: + pass # no NCCL groups to tear down + + @torch.no_grad() + def update_weights(self) -> None: + # The first call only captures the baseline snapshot the next sync diffs against. + if not self._baseline_captured: + self._capture_baseline() + self._baseline_captured = True + return + + self.weight_version += 1 + if dist.get_rank() == 0: + ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) + ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + + self._publish() + self._reload_engines() + self._record_metrics() + + def _capture_baseline(self) -> None: + """Capture the baseline snapshot the first delta diffs against (no publish), and clear any + stale stream from a prior run. Seeds from hf_checkpoint — what each host materializes its + base from — so the invariant ``snapshot == engine base`` holds even where the megatron->HF + round-trip trims vocab-padding rows (embed/lm_head). 
A tensor absent there (rare) falls back + to the gathered value.""" + # a prior run's versions would apply against the wrong base; start the dir clean + if dist.get_rank() == 0: + shutil.rmtree(self.delta_dir, ignore_errors=True) + os.makedirs(self.delta_dir, exist_ok=True) + if self._commit_hook is not None: + self._commit_hook(self.args, self.delta_dir, list(self.rollout_engines)) + dist.barrier(group=get_gloo_group()) + + read_hf = make_tensor_reader(self.args.hf_checkpoint) # index the HF headers once + for name, tensor in self._iter_hf_tensors(): + try: + self._snapshot[name] = read_hf(name) + except KeyError: + self._snapshot[name] = tensor.detach().cpu().contiguous().view(torch.uint8).numpy().reshape(-1) + logger.warning("seed: %s absent from hf_checkpoint; seeding from current weights", name) + if dist.get_rank() == 0: + logger.info( + "[disk delta] captured baseline snapshot of %d tensors from %s", + len(self._snapshot), + self.args.hf_checkpoint, + ) + + def _publish(self) -> None: + """Encode this version's changed tensors (PP-src ranks), then write it as a canonical HF dir.""" + self._encode_delta() + dist.barrier(group=get_gloo_group()) + self._write_delta_files() + + def _write_delta_files(self) -> None: + """Write this rank's changed tensors as one canonical model-NNNNN.safetensors, and on rank + 0 the HF index. 
The sequential file numbers and the index are coordinated over gloo (small + object gathers), not the filesystem — a shared volume may not surface one rank's writes to + another until commit.""" + group = get_gloo_group() + world, rank = dist.get_world_size(), dist.get_rank() + + # number the files sequentially across only the ranks that have one (no gaps) + counts: list = [None] * world + dist.all_gather_object(counts, int(bool(self._delta)), group=group) + offset, total = sum(counts[:rank]), sum(counts) + + fname = None + self.wire_bytes = 0 + if self._delta: + fname = f"model-{offset:05d}-of-{total:05d}.safetensors" + blob = safetensors.numpy.save(self._delta, metadata=self._checksums) + self.wire_bytes = len(blob) + _atomic_write(os.path.join(self._version_dir, fname), blob) + + maps: list = [None] * world + dist.all_gather_object(maps, {name: fname for name in self._delta}, group=group) + if rank == 0: + index = { + "metadata": { + "version": f"{self.weight_version:06d}", + "base_version": f"{self.weight_version - 1:06d}", + "delta_encoding": self.delta_encoding, + "compression_format": "zstd", + "checksum_format": self.checksum_algorithm, + }, + "weight_map": {name: f for m in maps for name, f in m.items()}, + } + _atomic_write(os.path.join(self._version_dir, "model.safetensors.index.json"), json.dumps(index).encode()) + dist.barrier(group=group) + + def _reload_engines(self) -> None: + """Commit the published files, have each host apply the delta, then reload the engines.""" + if self._commit_hook is not None: + self._commit_hook(self.args, self._version_dir, list(self.rollout_engines)) + dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + ray.get([actor.sync_weights.remote(self.weight_version) for actor in self.all_engine_actors]) + ray.get( + [ + engine.update_weights_from_disk.remote( + model_path=self.args.update_weight_local_checkpoint_dir, + weight_version=str(self.weight_version), + ) + for engine in self.rollout_engines + ] + ) + 
ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + + def _iter_hf_tensors(self): + """Yield (name, gathered HF tensor) for every param: base-class TP then EP gather passes.""" + for chunk_iter in (self._iter_non_expert_chunks(), self._iter_expert_chunks()): + for hf_chunk in chunk_iter: + yield from hf_chunk + dist.barrier(group=get_gloo_group()) + + def _encode_delta(self) -> None: + """Diff each gathered HF tensor against the snapshot, keeping the changed ones (compressed) + in self._delta with their checksums. The GPU->CPU gather is pipelined into a compute pool: + the main loop copies one tensor to a pinned buffer and submits it; pool workers diff and + compress in parallel (each is a few big GIL-releasing numpy/zstd calls).""" + self._version_dir = os.path.join(self.delta_dir, f"weight_v{self.weight_version:06d}") + if self._is_pp_src_rank: + os.makedirs(self._version_dir, exist_ok=True) + snapshot = self._snapshot + self._delta: dict[str, np.ndarray] = {} # changed tensor name -> compressed diff + self._checksums: dict[str, str] = {} # changed tensor name -> new-state checksum + self.changed_bytes = self.total_bytes = 0 + + # Pinned host-buffer pool: a pinned non_blocking GPU->CPU copy is far faster than .cpu(). 
+ max_bytes = max((int(v.nbytes) for v in snapshot.values()), default=0) + free_q: queue.Queue = queue.Queue() + use_pinned = True + try: + for _ in range(max(4, min(2 * NUM_WORKERS, (32 << 30) // max(max_bytes, 1)))): + free_q.put(torch.empty(max_bytes, dtype=torch.uint8, pin_memory=True)) + except RuntimeError as e: # low memlock limit + logger.warning("pinned host buffers unavailable (%s); using pageable .cpu()", e) + use_pinned = False + + def diff_and_compress(name, buf, nbytes, pinned): + if pinned: # copy out and free the pinned buffer before the heavy diff/compress + new = np.empty(nbytes, dtype=np.uint8) + np.copyto(new, buf.numpy()[:nbytes]) + free_q.put(buf) + else: + new = buf + old = snapshot[name] + if self.delta_encoding == "xor": + diff = new ^ old + changed = int(np.count_nonzero(diff)) + elif self.delta_encoding == "overwrite": + mask = new != old + changed = int(np.count_nonzero(mask)) + diff = overwrite_encode(new, mask) + else: + raise ValueError(f"unknown delta encoding {self.delta_encoding!r}") + if not changed: + return name, new, None, None, 0 + compressed = np.frombuffer(zstandard.ZstdCompressor(level=1).compress(diff), dtype=np.uint8) + return name, new, compressed, checksum(self.checksum_algorithm, new), changed + + def collect(fut): + name, new, compressed, digest, changed = fut.result() + snapshot[name] = new # becomes the next sync's base + if changed: + self.changed_bytes += changed + self._delta[name] = compressed + self._checksums[name] = digest + + pool = ThreadPoolExecutor(max_workers=NUM_WORKERS) + inflight: deque = deque() + try: + for name, tensor in self._iter_hf_tensors(): + flat = tensor.detach().contiguous().view(torch.uint8).reshape(-1) + nbytes = int(flat.numel()) + if use_pinned and nbytes <= max_bytes: + buf = free_q.get() # blocks when all buffers are in flight -> backpressures the gather + buf[:nbytes].copy_(flat, non_blocking=True) + torch.cuda.current_stream().synchronize() + payload, pinned = buf, True + else: + 
payload, pinned = flat.cpu().numpy(), False + self.total_bytes += nbytes + inflight.append(pool.submit(diff_and_compress, name, payload, nbytes, pinned)) + if len(inflight) >= 2 * NUM_WORKERS: + collect(inflight.popleft()) + while inflight: + collect(inflight.popleft()) + finally: + pool.shutdown() + + def _record_metrics(self) -> None: + """All-reduce the byte counts and record changed-fraction + wire size; the actor drains + update_weight_metrics onto the step log.""" + counts = torch.tensor( + [self.changed_bytes, self.total_bytes, self.wire_bytes], + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + dist.all_reduce(counts) + changed, total, wire = counts.tolist() + self.update_weight_metrics["perf/update_weights_density"] = changed / max(total, 1) + self.update_weight_metrics["perf/update_weights_wire_bytes"] = wire + + +def _atomic_write(path: str, data: bytes) -> None: + tmp = path + ".tmp" + with open(tmp, "wb") as f: + f.write(data) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 1ab48fb974..14698c4309 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -16,7 +16,6 @@ from slime.utils.distributed_utils import get_gloo_group, init_process_group from ..megatron_to_hf import convert_to_hf -from ..sglang import DeltaSpec from .common import all_gather_param, named_params_and_buffers @@ -60,6 +59,7 @@ def connect_rollout_engines( rollout_engine_lock: ActorHandle, engine_gpu_counts: Sequence[int] | None = None, engine_gpu_offsets: Sequence[int] | None = None, + all_engine_actors: Sequence[ActorHandle] | None = None, ) -> None: """ Create NCCL "slime-pp_{pp_rank}" if PP source (DP=TP=0). Lock prevents concurrent broadcasts. 
@@ -174,18 +174,12 @@ def _iter_non_expert_chunks(self) -> Iterator[list[tuple[str, torch.Tensor]]]: if buffer: yield buffer - def _iter_expert_chunks( - self, - params: Iterator[tuple[str, torch.Tensor]] | None = None, - ) -> Iterator[list[tuple[str, torch.Tensor]]]: + def _iter_expert_chunks(self) -> Iterator[list[tuple[str, torch.Tensor]]]: """ Yield one HF chunk per EP-weighted batch of expert params: TP gather + - buffer until threshold, then EP gather + HF convert. ``params`` lets - callers restrict the iter to a subset (used by delta-sync sub-passes); - defaults to all expert params on this rank. + buffer until threshold, then EP gather + HF convert. """ - if params is None: - params = ((n, p) for n, p in named_params_and_buffers(self.args, self.model) if ".experts." in n) + params = ((n, p) for n, p in named_params_and_buffers(self.args, self.model) if ".experts." in n) buffer_size = 0 batch: list[tuple[str, torch.Tensor]] = [] for name, param in params: @@ -247,12 +241,9 @@ def _update_bucket_weights_from_distributed( converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None, load_format: str | None = None, - delta: DeltaSpec | None = None, ) -> None: """ Lock → broadcast → clear → unlock → pbar++. Lock prevents NCCL deadlock. - Delta sync passes ``load_format="delta"`` + a ``DeltaSpec`` describing the - per-param decoding of the (__positions__, __values__) bucket tensors. """ # lock the rollout engines to prevent dead lock on broadcast. 
while not ray.get(self.rollout_engine_lock.acquire.remote()): @@ -265,7 +256,6 @@ def _update_bucket_weights_from_distributed( self.rollout_engines, converted_named_tensors, load_format=load_format, - delta=delta, ) ray.get(refs) @@ -339,11 +329,9 @@ def update_weights_from_distributed( rollout_engines: Sequence[ActorHandle], converted_named_tensors: Sequence[tuple[str, torch.Tensor]], load_format: str | None = None, - delta: DeltaSpec | None = None, ) -> list[ObjectRef]: """ Send metadata (Ray), broadcast tensors (NCCL rank 0 → engines). - Delta sync passes ``load_format="delta"`` + ``delta`` (DeltaSpec). """ refs = [ engine.update_weights_from_distributed.remote( @@ -353,7 +341,6 @@ def update_weights_from_distributed( group_name=group_name, weight_version=str(weight_version), load_format=load_format, - delta=delta, ) for engine in rollout_engines ] diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py deleted file mode 100644 index fbe24bbc1c..0000000000 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py +++ /dev/null @@ -1,864 +0,0 @@ -""" -Delta weight sync. - -For each sync, the sender bytewise-diffs the current weights against a -pinned-CPU snapshot of the last broadcast, packs the changed positions -and values, and ships them via one of two transports: - - - "nccl": each bucket flush goes out via NCCL broadcast (low-latency, - high-bandwidth, intra-datacenter). - - "disk": each bucket flush is written to a versioned shared-FS directory - as one safetensors file; one HTTP push per sync wakes the rollout - engines to read+apply (cross-datacenter, bandwidth-limited). - -Both transports share one wire layout (``__positions__`` uint8 byte blob + -``__values__`` param-dtype tensor + per-param decoding manifest) and one -receiver-side decoder. 
Three encodings differ only in how positions are -packed: - - indices : int32 absolute positions - deltas : uint16 gap-deltas (uint32 fallback per param) - deltas_zstd : ``deltas`` with the safetensors blob wrapped in zstd L1 - -The receiver overwrites changed positions with the trainer's exact bytes -(no arithmetic), so the apply is lossless and there is no drift to fight -with periodic re-syncs. The first ``update_weights`` call seeds the -snapshot without contacting the rollout engines — they're assumed to have -loaded the same HF checkpoint at init. -""" - -import itertools -import json -import logging -import os -import shutil -import threading -from argparse import Namespace -from collections.abc import Callable, Iterator, Mapping, Sequence -from concurrent.futures import ThreadPoolExecutor -from dataclasses import asdict, dataclass, field, replace -from queue import Queue - -import numpy as np -import ray -import torch -import torch.distributed as dist -from megatron.core import mpu -from ray.actor import ActorHandle -from safetensors.torch import save as st_save_bytes -from tqdm import tqdm - -from slime.utils.distributed_utils import get_gloo_group -from slime.utils.timer import Timer, timer - -from ..sglang import DeltaEncoding, DeltaParam, DeltaSpec -from .update_weight_from_distributed import UpdateWeightFromDistributed - -logger = logging.getLogger(__name__) - - -# ---------- compute + encode ----------------------------------------------- - - -@dataclass -class ParamDiff: - """ - One per-param compute output. ``values`` is a reference to the full-shape - current tensor (no copy); ``mask`` is a same-shape bool marking the - positions whose bytes differ from the snapshot. - """ - - name: str - values: torch.Tensor - mask: torch.Tensor - - -@dataclass -class EncodedChunk: - """ - One HF chunk after position+value encoding, before bucket merging. 
- - ``pos_bytes`` and ``val_tensor`` are the chunk-local concatenations across - all params; per-param byte/element offsets live on ``params``. - """ - - pos_bytes: bytes - val_tensor: torch.Tensor - params: list[DeltaParam] - nnz: int - - @classmethod - def empty(cls) -> "EncodedChunk": - return cls(pos_bytes=b"", val_tensor=torch.empty(0, dtype=torch.bfloat16), params=[], nnz=0) - - -def _checksum(positions: torch.Tensor, values: torch.Tensor) -> int: - """ - Wire-corruption check via ``torch.hash_tensor`` (XOR-reduce over uint64 bitcast). - Sender computes pre-flush, receiver computes post-recv; mismatch indicates - corruption between encode and apply. One reduction + one ``.item()`` sync per arg. - """ - p = int(torch.hash_tensor(positions).item()) if positions.numel() else 0 - v = int(torch.hash_tensor(values).item()) if values.numel() else 0 - return p ^ (v << 1) - - -def _bytewise_diff_mask(current: torch.Tensor, snapshot: torch.Tensor) -> torch.Tensor: - """ - Per-element bool mask: True where current and snapshot bytes differ. Dtype-agnostic via view-as-integer. - """ - es = current.element_size() - int_dtype = {1: torch.uint8, 2: torch.int16, 4: torch.int32, 8: torch.int64}.get(es) - if int_dtype is None: - raise ValueError(f"unsupported element size {es}") - return current.view(int_dtype) != snapshot.view(int_dtype) - - -def _sparse_boundaries( - diffs: list[ParamDiff], -) -> tuple[torch.Tensor, list[int], torch.Tensor, list[int]]: - """ - One concat → one nonzero → one searchsorted → one ``tolist()``: collapses - per-param host syncs to one per chunk. Returns ``(big_val, bounds, big_idx, cum)``. 
- """ - device = diffs[0].values.device - sizes = [d.values.numel() for d in diffs] - cum = list(itertools.accumulate(sizes)) - cum_t = torch.tensor(cum, dtype=torch.int64, device=device) - - big_values = torch.cat([d.values.contiguous().view(-1) for d in diffs], dim=0) - big_mask = torch.cat([d.mask.contiguous().view(-1) for d in diffs], dim=0) - big_idx = big_mask.nonzero(as_tuple=False).view(-1) - big_val = big_values[big_idx] - bounds = torch.searchsorted(big_idx, cum_t).tolist() - return big_val, bounds, big_idx, cum - - -def encode_indices(diffs: list[ParamDiff]) -> EncodedChunk: - """ - int32 absolute positions, per-param. Position blob is uint8 bytes; pos_width=4 for all params. - """ - if not diffs: - return EncodedChunk.empty() - big_val, bounds, big_idx, cum = _sparse_boundaries(diffs) - pos_pieces: list[torch.Tensor] = [] - val_pieces: list[torch.Tensor] = [] - params: list[DeltaParam] = [] - pos_byte_off = val_off = 0 - prev_b = 0 - prev_param_start = 0 - for i, d in enumerate(diffs): - b = bounds[i] - nnz = b - prev_b - if nnz > 0: - local_idx = (big_idx[prev_b:b] - prev_param_start).to(torch.int32) - pos_pieces.append(local_idx) - val_pieces.append(big_val[prev_b:b]) - params.append( - DeltaParam( - name=d.name, - dtype=str(d.values.dtype).replace("torch.", ""), - shape=list(d.values.shape), - pos_start=pos_byte_off, - pos_end=pos_byte_off + nnz * 4, - pos_width=4, - val_start=val_off, - val_end=val_off + nnz, - ) - ) - pos_byte_off += nnz * 4 - val_off += nnz - prev_b = b - prev_param_start = cum[i] - if not params: - return EncodedChunk.empty() - positions = torch.cat(pos_pieces, dim=0) - values = torch.cat(val_pieces, dim=0) - return EncodedChunk( - pos_bytes=positions.cpu().numpy().tobytes(), - val_tensor=values, - params=params, - nnz=val_off, - ) - - -def encode_deltas(diffs: list[ParamDiff]) -> EncodedChunk: - """ - Gap-encode sorted positions: store ``idx[k] - idx[k-1] - 1`` with idx[-1] := -1 - so the first delta equals the first index. 
Per-param downcast to uint16 if the max - gap fits, otherwise uint32. At ~2% Bernoulli density on bf16 weights, max gap ≈ 300 - — uint16 fits; the fallback covers pathological inputs without correctness risk. - Receiver inverts: ``idx = cumsum(delta + 1) - 1``. - """ - if not diffs: - return EncodedChunk.empty() - big_val, bounds, big_idx, cum = _sparse_boundaries(diffs) - - kept: list[tuple[ParamDiff, int]] = [] # (diff, nnz) for non-empty params - per_param_deltas: list[torch.Tensor] = [] - val_pieces: list[torch.Tensor] = [] - prev_b = 0 - prev_param_start = 0 - for i, d in enumerate(diffs): - b = bounds[i] - nnz = b - prev_b - if nnz > 0: - local_idx = big_idx[prev_b:b] - prev_param_start # int64, sorted - prev = torch.cat( - [ - torch.tensor([-1], dtype=local_idx.dtype, device=local_idx.device), - local_idx[:-1], - ] - ) - per_param_deltas.append(local_idx - prev - 1) - val_pieces.append(big_val[prev_b:b]) - kept.append((d, nnz)) - prev_b = b - prev_param_start = cum[i] - - if not kept: - return EncodedChunk.empty() - - # One CPU sync for per-param width selection. 
- max_per_param = torch.stack([d.max() for d in per_param_deltas]).cpu().tolist() - pos_byte_pieces: list[bytes] = [] - pos_byte_off = val_off = 0 - params: list[DeltaParam] = [] - for (d, nnz), deltas, max_d in zip(kept, per_param_deltas, max_per_param, strict=True): - width = 2 if int(max_d) <= 65535 else 4 - np_dtype = np.uint16 if width == 2 else np.uint32 - b_chunk = deltas.cpu().numpy().astype(np_dtype, copy=False).tobytes() - pos_byte_pieces.append(b_chunk) - params.append( - DeltaParam( - name=d.name, - dtype=str(d.values.dtype).replace("torch.", ""), - shape=list(d.values.shape), - pos_start=pos_byte_off, - pos_end=pos_byte_off + len(b_chunk), - pos_width=width, - val_start=val_off, - val_end=val_off + nnz, - ) - ) - pos_byte_off += len(b_chunk) - val_off += nnz - - values = torch.cat(val_pieces, dim=0) - return EncodedChunk( - pos_bytes=b"".join(pos_byte_pieces), - val_tensor=values, - params=params, - nnz=val_off, - ) - - -# ---------- snapshot state ------------------------------------------------- - - -class DeltaState: - """ - Pinned-CPU snapshot of every HF tensor we've broadcast, plus the H2D/D2H - side streams that pipeline next-chunk snapshot transfer behind the current - chunk's compute. - """ - - def __init__(self) -> None: - self.snapshot: dict[str, torch.Tensor] = {} - self.d2h_stream: torch.cuda.Stream | None = None - self.h2d_stream: torch.cuda.Stream | None = None - self.snapshot_dirty = False - - def prefetch_snapshot( - self, named_tensors: list[tuple[str, torch.Tensor]] - ) -> tuple[list[torch.Tensor], torch.cuda.Event]: - """ - Start an async H2D copy of the snapshot tensors for ``named_tensors`` on a side stream. 
- """ - if self.h2d_stream is None: - self.h2d_stream = torch.cuda.Stream() - prev_gpu: list[torch.Tensor] = [] - with torch.cuda.stream(self.h2d_stream): - for name, tensor in named_tensors: - if name not in self.snapshot: - raise KeyError(f"missing snapshot for {name!r}; first update_weights call seeds the snapshot") - prev_gpu.append(self.snapshot[name].to(device=tensor.device, non_blocking=True)) - event = self.h2d_stream.record_event() - return prev_gpu, event - - def compute_diffs( - self, - named_tensors: list[tuple[str, torch.Tensor]], - prefetched: tuple[list[torch.Tensor], torch.cuda.Event], - ) -> list[ParamDiff]: - """ - Wait for the prefetched H2D copy, then per-param bytewise diff against the snapshot. - """ - prev_gpu, event = prefetched - event.wait() - return [ - ParamDiff(name=name, values=current, mask=_bytewise_diff_mask(current, prev)) - for (name, current), prev in zip(named_tensors, prev_gpu, strict=True) - ] - - def update_snapshot_async(self, named_tensors: list[tuple[str, torch.Tensor]]) -> None: - """ - Enqueue a D2H copy of ``named_tensors`` into the pinned-CPU snapshot on a - side stream. Non-blocking; call ``flush_snapshot`` before the next sync. - """ - if self.d2h_stream is None: - self.d2h_stream = torch.cuda.Stream() - event = torch.cuda.current_stream().record_event() - with torch.cuda.stream(self.d2h_stream): - self.d2h_stream.wait_event(event) - for name, tensor in named_tensors: - if name not in self.snapshot: - self.snapshot[name] = torch.empty_like(tensor, device=torch.device("cpu"), pin_memory=True) - self.snapshot[name].copy_(tensor.detach(), non_blocking=True) - self.snapshot_dirty = True - - def flush_snapshot(self) -> None: - """ - Block until all enqueued D2H snapshot copies have landed. 
- """ - if self.snapshot_dirty: - if self.d2h_stream is not None: - self.d2h_stream.synchronize() - else: - torch.cuda.synchronize() - self.snapshot_dirty = False - - -# ---------- bucket --------------------------------------------------------- - - -@dataclass -class DeltaBucket: - """ - Accumulates encoded chunks for one flush. Per-param offsets are rebased - into the bucket's growing position blob + value tensor on ``add``. - """ - - pos_pieces: list[bytes] = field(default_factory=list) - val_pieces: list[torch.Tensor] = field(default_factory=list) - params: list[DeltaParam] = field(default_factory=list) - pos_total: int = 0 - val_total: int = 0 - byte_size: int = 0 - - @property - def has_updates(self) -> bool: - return bool(self.pos_pieces) - - def should_flush_before_add(self, chunk: EncodedChunk, byte_limit: int) -> bool: - """True iff adding ``chunk`` would push the bucket past ``byte_limit``.""" - chunk_bytes = len(chunk.pos_bytes) + chunk.val_tensor.numel() * chunk.val_tensor.element_size() - return self.has_updates and self.byte_size + chunk_bytes > byte_limit - - def add(self, chunk: EncodedChunk) -> None: - """Append ``chunk``, rebasing each param's byte/element offsets into the bucket.""" - for p in chunk.params: - self.params.append( - replace( - p, - pos_start=p.pos_start + self.pos_total, - pos_end=p.pos_end + self.pos_total, - val_start=p.val_start + self.val_total, - val_end=p.val_end + self.val_total, - ) - ) - self.pos_pieces.append(chunk.pos_bytes) - self.val_pieces.append(chunk.val_tensor) - self.pos_total += len(chunk.pos_bytes) - self.val_total += chunk.val_tensor.numel() - self.byte_size += len(chunk.pos_bytes) + chunk.val_tensor.numel() * chunk.val_tensor.element_size() - - def merged_positions_cpu(self) -> torch.Tensor: - """One CPU uint8 tensor with the bucket's positions blob.""" - merged = b"".join(self.pos_pieces) - if not merged: - return torch.empty(0, dtype=torch.uint8) - return torch.from_numpy(np.frombuffer(merged, 
dtype=np.uint8).copy()) - - def merged_values(self) -> torch.Tensor: - """One GPU tensor with the bucket's values, concatenated across chunks.""" - if not self.val_pieces: - return torch.empty(0, dtype=torch.bfloat16) - return torch.cat(self.val_pieces, dim=0) - - def clear(self) -> None: - """Reset to empty so the bucket can be reused for the next flush.""" - self.pos_pieces.clear() - self.val_pieces.clear() - self.params.clear() - self.pos_total = 0 - self.val_total = 0 - self.byte_size = 0 - - -# ---------- async safetensors writer (disk transport only) ----------------- - - -class AsyncSafetensorsWriter: - """ - Background thread that drains a queue of file writes. Producers do GPU→CPU - on the default stream and enqueue; the writer does the slow disk I/O - (and optional zstd compress) off the critical path. End-of-sync ``drain()`` - blocks until all enqueued writes have landed. - """ - - def __init__(self, compress_with_zstd: bool, zstd_level: int = 1) -> None: - self._queue: Queue = Queue() - self._error: BaseException | None = None - self._compress_with_zstd = compress_with_zstd - self._zstd_level = zstd_level - if compress_with_zstd: - # Lazy import — non-disk users don't pay the dep. 
- import zstandard - - self._zstd = zstandard - self._lock = threading.Lock() - self.bytes_pre_compress = 0 - self.bytes_post_compress = 0 - self._thread = threading.Thread(target=self._run, name="delta-disk-writer", daemon=True) - self._thread.start() - - def enqueue( - self, - path: str, - tensors: dict[str, torch.Tensor], - metadata: dict[str, str], - ) -> None: - """Hand a (path, tensors, metadata) tuple to the writer thread.""" - if self._error is not None: - raise RuntimeError(f"writer thread already failed: {self._error!r}") - self._queue.put((path, tensors, metadata)) - - def drain(self) -> None: - """Block until every queued write has landed; re-raise any writer-thread error.""" - self._queue.join() - if self._error is not None: - raise RuntimeError(f"writer thread failed: {self._error!r}") from self._error - - def reset_counters(self) -> None: - """Zero the byte counters at the start of a sync.""" - with self._lock: - self.bytes_pre_compress = 0 - self.bytes_post_compress = 0 - - def _run(self) -> None: - """Writer-thread loop: safetensors-encode → (optional zstd) → atomic replace.""" - cctx = self._zstd.ZstdCompressor(level=self._zstd_level, threads=-1) if self._compress_with_zstd else None - while True: - path, tensors, metadata = self._queue.get() - try: - if self._error is None: - blob = st_save_bytes(tensors, metadata=metadata) - pre = len(blob) - if cctx is not None: - blob = cctx.compress(blob) - post = len(blob) - tmp = path + ".tmp" - with open(tmp, "wb") as f: - f.write(blob) - f.flush() - os.fsync(f.fileno()) - os.replace(tmp, path) - with self._lock: - self.bytes_pre_compress += pre - self.bytes_post_compress += post - except BaseException as e: # noqa: BLE001 - self._error = e - finally: - self._queue.task_done() - - -# ---------- main class ----------------------------------------------------- - - -class UpdateWeightFromDistributedDelta(UpdateWeightFromDistributed): - """ - Selective delta sync. 
``--update-weight-transport`` picks the per-flush carrier: - "nccl" broadcasts each bucket; "disk" writes each bucket as a safetensors file under - ``--update-weight-disk-dir`` and pushes once at end-of-sync. - """ - - def __init__( - self, - args: Namespace, - model: Sequence[torch.nn.Module], - weights_getter: Callable[[], Mapping[str, torch.Tensor]], - *, - model_name: str, - quantization_config: dict[str, int | str | list[str]] | None, - ) -> None: - super().__init__( - args, - model, - weights_getter, - model_name=model_name, - quantization_config=quantization_config, - ) - self.transport = args.update_weight_transport - self.encoding = DeltaEncoding(args.update_weight_encoding) - self.delta_state = DeltaState() - self._snapshot_seeded = False - # DELTAS_ZSTD shares the gap encoder; zstd is applied at file-write time. - self._encode = encode_indices if self.encoding is DeltaEncoding.INDICES else encode_deltas - - self.writer: AsyncSafetensorsWriter | None = None - self.delta_dir: str | None = None - self._pre_push_hook: Callable | None = None - # Disk transport: each pass boundary publishes its accumulated files - # (the only globally-synced flush points, since ``_publish_batch`` - # contains collectives). ``_pre_push_hook`` may return a Future, in - # which case the receiver RPC is deferred behind it via - # ``_rpc_executor`` so the main encode thread continues immediately. - # ``_pending_publishes`` holds the resulting Future[list[ObjectRef]] - # on rank 0; ``_finalize_sync`` awaits them at end of sync. 
- self._pending_files: list[str] = [] - self._pending_publishes: list = [] - self._published_any: bool = False - self._rpc_executor: ThreadPoolExecutor | None = None - if self.transport == "disk": - self.delta_dir = args.update_weight_disk_dir - os.makedirs(self.delta_dir, exist_ok=True) - self.writer = AsyncSafetensorsWriter( - compress_with_zstd=(self.encoding == DeltaEncoding.DELTAS_ZSTD), - ) - self._rpc_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="delta-publish-rpc") - if getattr(args, "custom_delta_pre_push_path", None): - from slime.utils.misc import load_function - - self._pre_push_hook = load_function(args.custom_delta_pre_push_path) - - def connect_rollout_engines( - self, - rollout_engines: Sequence[ActorHandle], - rollout_engine_lock: ActorHandle, - engine_gpu_counts: Sequence[int] | None = None, - engine_gpu_offsets: Sequence[int] | None = None, - ) -> None: - """ - NCCL transport: delegate to parent (group creation). Disk transport: just - record the engines + PP-src flag (no NCCL group needed). - """ - if self.transport == "nccl": - super().connect_rollout_engines( - rollout_engines, - rollout_engine_lock, - engine_gpu_counts=engine_gpu_counts, - engine_gpu_offsets=engine_gpu_offsets, - ) - return - self.rollout_engines = rollout_engines - self.rollout_engine_lock = rollout_engine_lock - self._engine_gpu_counts = engine_gpu_counts - self._is_pp_src_rank = ( - mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0 - ) - pp_rank = mpu.get_pipeline_model_parallel_rank() - self._group_name = f"slime-pp_{pp_rank}" - - def disconnect_rollout_engines(self) -> None: - if self.transport == "nccl": - super().disconnect_rollout_engines() - - @torch.no_grad() - def update_weights(self) -> None: - """ - First call: seed the CPU snapshot from current model state, no engine RPCs. - Subsequent calls: pause → diff/encode → finalize → resume. 
``delta_encode`` - covers the sender's per-param TP/EP gather + diff + sparse encode + per-publish - commit/RPC handoff; ``delta_finalize`` covers the tail wait for the last - batch's receiver-apply. Their sum is the sync latency the user observes. - """ - if not self._snapshot_seeded: - self._seed_snapshot() - self._snapshot_seeded = True - return - - self.weight_version += 1 - if self.transport == "disk": - self._version_dir = os.path.join(self.delta_dir, f"weight_v{self.weight_version:06d}") - if self._is_pp_src_rank: - os.makedirs(self._version_dir, exist_ok=True) - - if dist.get_rank() == 0: - ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) - ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) - - self.density_nnz = self.density_numel = self.wire_bytes = self._flush_idx = 0 - self._pending_files.clear() - self._pending_publishes.clear() - self._published_any = False - if self.writer is not None: - self.writer.reset_counters() - pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None - - with timer("delta_encode"): - self._send_weights(pbar) - if self.writer is not None: - self.writer.drain() - self.delta_state.flush_snapshot() - dist.barrier(group=get_gloo_group()) - - with timer("delta_finalize"): - self._finalize_sync() - - self._record_metrics() - - def _seed_snapshot(self) -> None: - """ - Populate the snapshot from current model state (TP/EP gather + HF - convert on PP-src ranks, D2H pinned copy). Cost is one full pass over - params — ~50s blocking on 355B at init. 
- """ - for chunk_iter in (self._iter_non_expert_chunks(), self._iter_expert_chunks()): - for hf_chunk in chunk_iter: - if hf_chunk: - self.delta_state.update_snapshot_async(hf_chunk) - dist.barrier(group=get_gloo_group()) - self.delta_state.flush_snapshot() - - def _send_weights(self, pbar: tqdm | None) -> None: - """ - Non-expert pass then expert pass, each followed by a barrier + (disk-only) - publish. The expert pass is split into ``_EXPERT_SUBPASSES`` sub-passes so - receiver apply for an earlier batch overlaps with later expert encoding, - instead of bottlenecking at end-of-sync. Megatron splits MoE layers - uniformly across PP ranks, so a per-rank slice of the expert param list - keeps the publish count identical on every rank (no barrier desync). - """ - from .common import named_params_and_buffers - - bucket = DeltaBucket() - self._pipeline_pass(self._iter_non_expert_chunks(), bucket, pbar) - self._flush_and_publish(bucket, pbar) - - expert_params = [(n, p) for n, p in named_params_and_buffers(self.args, self.model) if ".experts." in n] - n = len(expert_params) - for i in range(self._EXPERT_SUBPASSES): - lo = i * n // self._EXPERT_SUBPASSES - hi = (i + 1) * n // self._EXPERT_SUBPASSES - self._pipeline_pass(self._iter_expert_chunks(iter(expert_params[lo:hi])), bucket, pbar) - self._flush_and_publish(bucket, pbar) - - _EXPERT_SUBPASSES = 4 - - def _flush_and_publish(self, bucket: DeltaBucket, pbar: tqdm | None) -> None: - """ - End-of-sub-pass: drain the in-flight bucket, barrier all PP ranks, then - (disk-only) fire one publish RPC for everything since the last call. 
- """ - if bucket.has_updates: - self._flush_bucket(bucket, pbar) - dist.barrier(group=get_gloo_group()) - if self.transport == "disk": - self._publish_batch() - - def _pipeline_pass( - self, - chunk_iter: Iterator[list[tuple[str, torch.Tensor]]], - bucket: DeltaBucket, - pbar: tqdm | None, - ) -> None: - """ - 1-step H2D snapshot prefetch lookahead: chunk N+1's snapshot transfer - overlaps chunk N's compute+encode on the default stream. - """ - pending_chunk: list[tuple[str, torch.Tensor]] | None = None - pending_prefetch: tuple[list[torch.Tensor], torch.cuda.Event] | None = None - for hf_chunk in chunk_iter: - if not hf_chunk: - continue - next_prefetch = self.delta_state.prefetch_snapshot(hf_chunk) - if pending_prefetch is not None: - self._enqueue_chunk(pending_chunk, pending_prefetch, bucket, pbar) - pending_chunk, pending_prefetch = hf_chunk, next_prefetch - if pending_prefetch is not None: - self._enqueue_chunk(pending_chunk, pending_prefetch, bucket, pbar) - - def _enqueue_chunk( - self, - hf_chunk: list[tuple[str, torch.Tensor]], - prefetched: tuple[list[torch.Tensor], torch.cuda.Event], - bucket: DeltaBucket, - pbar: tqdm | None, - ) -> None: - """ - compute diffs → snapshot new prev → encode → bucket.add (flushing if full). - """ - diffs = self.delta_state.compute_diffs(hf_chunk, prefetched=prefetched) - self.delta_state.update_snapshot_async(hf_chunk) - chunk = self._encode(diffs) - self.density_numel += sum(d.values.numel() for d in diffs) - self.density_nnz += chunk.nnz - self.wire_bytes += len(chunk.pos_bytes) + chunk.val_tensor.numel() * chunk.val_tensor.element_size() - if not chunk.params: - return - if bucket.should_flush_before_add(chunk, self.args.update_weight_buffer_size): - self._flush_bucket(bucket, pbar) - bucket.add(chunk) - - def _flush_bucket(self, bucket: DeltaBucket, pbar: tqdm | None) -> None: - """ - NCCL: broadcast (__positions__, __values__) with a DeltaSpec. - Disk: enqueue one safetensors file with the same payload + metadata. 
- Both paths embed a checksum the receiver verifies before apply. - """ - if not bucket.has_updates: - return - positions_cpu = bucket.merged_positions_cpu() - values_gpu = bucket.merged_values() - params = list(bucket.params) - bucket.clear() - - # GPU-resident checksum: positions go to the device the values already live on - # (NCCL needs the same move anyway; disk gets it for free at the reduction). - positions_gpu = positions_cpu.to(values_gpu.device, non_blocking=True) - checksum = _checksum(positions_gpu, values_gpu) - - if self.transport == "nccl": - spec = DeltaSpec(encoding=self.encoding, params=params, checksum=checksum) - self._update_bucket_weights_from_distributed( - [("__positions__", positions_gpu), ("__values__", values_gpu)], - pbar=pbar, - load_format="delta", - delta=spec, - ) - else: # disk - tensors = {"__positions__": positions_cpu, "__values__": values_gpu.cpu()} - metadata = { - "encoding": self.encoding.value, - "params": json.dumps([asdict(p) for p in params]), - "current_version": str(self.weight_version), - "checksum": str(checksum), - } - filename = f"rank{dist.get_rank():04d}_flush{self._flush_idx:06d}.safetensors" - path = os.path.join(self._version_dir, filename) - self.writer.enqueue(path, tensors, metadata) - self._pending_files.append(filename) - if pbar is not None: - pbar.update(1) - self._flush_idx += 1 - - def _publish_batch(self) -> None: - """ - Drain pending fsyncs, invoke the pre-push hook (may return a Future for an - async durability step on shared FS), then defer rank 0's - ``update_weights_from_disk`` RPC behind that Future via ``_rpc_executor``. - Each deferred dispatch lands in ``_pending_publishes`` as a - Future[list[ObjectRef]]; ``_finalize_sync`` awaits both layers. Safe to call - with empty ``_pending_files``: the all_gather still synchronizes and rank 0 - skips the dispatch when no rank produced files. 
- """ - self.writer.drain() - dist.barrier(group=get_gloo_group()) - - commit_future = None - if self._pre_push_hook is not None: - commit_future = self._pre_push_hook(self.args, self._version_dir, list(self.rollout_engines)) - dist.barrier(group=get_gloo_group()) - - # Collect every rank's batch filenames at rank 0; payload is ~KB, gather is cheap. - all_files: list[list[str]] = [None] * dist.get_world_size() # type: ignore[list-item] - dist.all_gather_object(all_files, list(self._pending_files), group=get_gloo_group()) - flat = [f for sub in all_files for f in sub] - self._pending_files.clear() - - if dist.get_rank() == 0 and flat: - version_dir = self._version_dir - engines = list(self.rollout_engines) - weight_version = str(self.weight_version) - self._published_any = True - - def _fire_when_committed() -> list: - if commit_future is not None: - commit_future.result() - return [ - engine.update_weights_from_disk.remote( - model_path=version_dir, - files=flat, - load_format="delta", - weight_version=weight_version, - ) - for engine in engines - ] - - self._pending_publishes.append(self._rpc_executor.submit(_fire_when_committed)) - - def _finalize_sync(self) -> None: - """ - Per-transport end-of-sync. NCCL: each flush already broadcasted; just resume. - Disk: publish the trailing files, wait for all streamed applies to land, then - cleanup + resume. - """ - if self.transport == "nccl": - if dist.get_rank() == 0: - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) - return - - if self._pending_files: - self._publish_batch() - if dist.get_rank() == 0: - # Each entry is a Future returning a list of ObjectRefs. Awaiting the - # Futures unblocks the (commit-then-RPC) chain; ray.get waits for the - # receivers' apply to finish. 
- object_refs = [ref for fut in self._pending_publishes for ref in fut.result()] - ray.get(object_refs) - self._pending_publishes.clear() - if not self._published_any: - # No delta files needed publishing this sync (e.g. all-zero diff). - # Engines never saw the new version via update_weights_from_disk, so - # bump it explicitly to keep their recorded version in sync with ours. - weight_version = str(self.weight_version) - ray.get([engine.set_weight_version.remote(weight_version) for engine in self.rollout_engines]) - if not self.args.update_weight_delta_keep_files: - shutil.rmtree(self._version_dir, ignore_errors=True) - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) - - def _record_metrics(self) -> None: - """ - Allreduce density/byte counters across PP-src ranks; stash on - ``update_weight_metrics`` for the actor to drain into the next step log. - Wall-clock timings come from the slime ``Timer`` (``delta_encode`` / - ``delta_finalize`` blocks above + the outer ``update_weights`` decorator). 
- """ - pre_bytes = self.writer.bytes_pre_compress if self.writer is not None else 0 - post_bytes = self.writer.bytes_post_compress if self.writer is not None else 0 - counts = torch.tensor( - [self.density_nnz, self.density_numel, self.wire_bytes, pre_bytes, post_bytes], - dtype=torch.int64, - device=torch.cuda.current_device(), - ) - dist.all_reduce(counts) - nnz, numel, wire_bytes, pre_bytes, post_bytes = counts.tolist() - - density = nnz / max(numel, 1) - compression_ratio = (pre_bytes / post_bytes) if post_bytes > 0 else 1.0 - - m = self.update_weight_metrics - m["perf/update_weights_density"] = density - m["perf/update_weights_wire_bytes"] = wire_bytes - m["perf/update_weights_flushes_per_rank"] = float(self._flush_idx) - if self.transport == "disk": - m["perf/update_weights_disk_bytes_pre_compress"] = pre_bytes - m["perf/update_weights_disk_bytes_post_compress"] = post_bytes - m["perf/update_weights_compression_ratio"] = compression_ratio - - if dist.get_rank() == 0: - t = Timer().log_dict() - logger.info( - "[delta sync v=%s] transport=%s enc=%s density=%.3f%% " "encode=%.2fs finalize=%.2fs flushes/rank=%d", - self.weight_version, - self.transport, - self.encoding.value, - 100.0 * density, - t.get("delta_encode", 0.0), - t.get("delta_finalize", 0.0), - self._flush_idx, - ) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 724e05355b..723c0c4b9a 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -65,6 +65,7 @@ def connect_rollout_engines( rollout_engine_lock: ActorHandle, engine_gpu_counts: Sequence[int] | None = None, engine_gpu_offsets: Sequence[int] | None = None, + all_engine_actors: Sequence[ActorHandle] | None = None, ) -> None: """ Split colocated/distributed engines. 
Global source rank (DP=TP=PP=0) creates NCCL diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 15c4dd7231..04f723fcef 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -1,8 +1,10 @@ import dataclasses +import glob import ipaddress import logging import multiprocessing import os +import threading import time from urllib.parse import quote @@ -167,6 +169,15 @@ def _format_v6_uri(addr): else: self._init_normal(server_args_dict) + # Disk-delta sync: each per-host actor materializes its local checkpoint in the background + # while the engine launches and serves; the first sync joins this thread. + self._base_init_thread = None + if self.args.update_weight_mode == "delta" and self.args.update_weight_transport == "disk": + self._base_init_thread = threading.Thread( + target=self._init_local_checkpoint, name="delta-base-init", daemon=True + ) + self._base_init_thread.start() + def _init_external(self, expect_server_args, external_engine_need_check_fields): logger.info(f"Use external SGLang engine (rank={self.rank}, expect_server_args={expect_server_args})") @@ -379,6 +390,48 @@ def resume_memory_occupation(self, tags: list[str] = None): def check_weights(self, action: str): return self._make_request("weights_checker", {"action": action}) + def _init_local_checkpoint(self): + """Background thread: copy the base into the host-local checkpoint, then drop the source.""" + from slime.utils.disk_delta import init_local_checkpoint + + init_local_checkpoint(self.args.update_weight_local_checkpoint_dir, self.args.hf_checkpoint) + self._drop_hf_cache() + + def _drop_hf_cache(self): + # sglang loads the HF checkpoint once at init and never re-reads it (weight updates read the + # local base), so dropping it from the page cache keeps the local base resident instead. 
+ from slime.utils.disk_delta import drop_page_cache + + for path in glob.glob(os.path.join(self.args.hf_checkpoint, "*")): + if os.path.isfile(path): + drop_page_cache(path) + + def _ensure_base_ready(self): + """Join the one-time base-materialization thread on the first sync, then re-drop the HF + source (sglang's init-load may have re-cached it after the thread dropped it).""" + if self._base_init_thread is None: + return + self._base_init_thread.join() + self._base_init_thread = None + self._drop_hf_cache() + + def sync_weights(self, target_version: int): + """Bring this host's local checkpoint to ``target_version`` by applying the published + deltas (raises on any error). Plain file work; the sglang server is untouched until the + subsequent reload.""" + from slime.utils.disk_delta import apply_deltas + + self._ensure_base_ready() + if self.args.custom_delta_pre_read_path: + from slime.utils.misc import load_function + + load_function(self.args.custom_delta_pre_read_path)(self.args.update_weight_disk_dir, target_version) + apply_deltas( + self.args.update_weight_local_checkpoint_dir, + self.args.update_weight_disk_dir, + target_version, + ) + def update_weights_from_disk( self, model_path: str, @@ -437,7 +490,6 @@ def update_weights_from_distributed( flush_cache=False, weight_version: str | None = None, load_format: str | None = None, - delta=None, ): payload = { "names": names, @@ -450,19 +502,6 @@ def update_weights_from_distributed( payload["weight_version"] = weight_version if load_format is not None: payload["load_format"] = load_format - if delta is not None: - # DeltaSpec → JSON string. Receiver reconstructs via DeltaEncoding(...) + - # DeltaParam(**p); avoids depending on FastAPI's nested-dataclass coercion. 
- import json - from dataclasses import asdict - - payload["delta"] = json.dumps( - { - "encoding": delta.encoding.value, - "params": [asdict(p) for p in delta.params], - "checksum": delta.checksum, - } - ) return self._make_request( "update_weights_from_distributed", payload, diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index b44b36f14e..0b8ccd22fd 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -495,7 +495,8 @@ def get_updatable_engines_and_lock(self): gpu_counts = srv.engine_gpu_counts if srv else [] gpu_offsets = srv.engine_gpu_offsets if srv else [] num_new = srv.num_new_engines if srv else 0 - return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets + all_engine_actors = srv.all_engines if srv else [] + return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets, all_engine_actors def get_num_rollout_per_epoch(self): assert self.args.rollout_global_dataset diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 6b87ef13bc..31552864f0 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -3,7 +3,6 @@ import json import logging import os -import warnings from typing import Any import yaml @@ -146,8 +145,8 @@ def add_train_arguments(parser): default="full", help=( "Weight sync strategy. 'full' (default) broadcasts every parameter " - "every sync. 'delta' detects byte-level changes against a pinned-CPU " - "snapshot of the previous broadcast and ships only the changed positions + values." + "every sync. 'delta' diffs each sync against a pinned-CPU snapshot of the " + "previous one and ships only the changed bytes (disk transport only)." ), ) parser.add_argument( @@ -157,9 +156,8 @@ def add_train_arguments(parser): help=( "Carrier for weight sync. In full mode, 'nccl' broadcasts chunks and " "'disk' writes a complete HF checkpoint under --update-weight-disk-dir " - "before engines reload it. 
In delta mode, 'nccl' broadcasts sparse deltas; " - "'disk' writes sparse safetensors under --update-weight-disk-dir and pushes " - "once at end-of-sync." + "before engines reload it. Delta mode is 'disk' only: each host applies the " + "published deltas into its local checkpoint and reloads via update_weights_from_disk." ), ) parser.add_argument( @@ -169,7 +167,7 @@ def add_train_arguments(parser): help=( "Filesystem directory for disk-backed weight sync. In --update-weight-mode=full, " "one complete HF checkpoint directory is written per sync. In delta mode, " - "one sparse-delta directory is written per sync." + "one delta directory (changed tensors only) is written per sync." ), ) parser.add_argument( @@ -182,32 +180,32 @@ def add_train_arguments(parser): ), ) parser.add_argument( - "--update-weight-encoding", - choices=["indices", "deltas", "deltas_zstd"], - default="indices", + "--update-weight-delta-encoding", + choices=["xor", "overwrite"], + default="xor", help=( - "Position encoding for partial flushes. 'indices': int32 absolute " - "positions (largest, lowest compute). 'deltas': uint16 gap-deltas " - "with uint32 fallback (smaller). 'deltas_zstd': 'deltas' with the " - "safetensors blob wrapped in zstd L1 (smallest, heaviest compute — " - "best for shared-FS bandwidth ≤ ~300 MB/s)." + "On-disk delta encoding for --update-weight-mode=delta --update-weight-transport=disk. " + "'xor' (default): new ^ old — smallest wire and fastest, but an involution that must be " + "applied exactly once against the correct base (applying it twice reverts). 'overwrite': " + "changed positions + new absolute values — larger, but idempotent (re-applicable any " + "number of times). Both are byte-level and dtype-blind; the engine reads the choice from " + "each version's index metadata." 
), ) parser.add_argument( - "--update-weight-delta-dir", - type=str, - default=None, + "--update-weight-delta-checksum", + choices=["xxh3-128", "blake3", "adler32"], + default="xxh3-128", help=( - "Deprecated alias for --update-weight-disk-dir and will be removed in a future " - "release. Prefer the transport-level directory flag for both full and delta disk sync." + "Per-tensor integrity checksum for disk delta apply. The checksum is not the " + "apply bottleneck (the apply is decompress + XOR bound), so this is a digest-" + "property choice, not a speed one. 'xxh3-128' (default): widest fast non-" + "cryptographic digest, negligible accidental-corruption collisions. 'blake3': " + "cryptographic digest, for untrusted storage. 'adler32': 32-bit, for interop " + "with systems that expect it. The engine reads the choice from each version's " + "index metadata." ), ) - parser.add_argument( - "--update-weight-delta-keep-files", - action="store_true", - default=False, - help="Skip post-apply cleanup of per-sync version directories. Useful for debugging.", - ) parser.add_argument( "--custom-delta-pre-push-path", type=str, @@ -219,6 +217,28 @@ def add_train_arguments(parser): "Called from every trainer rank; the hook gates itself." ), ) + parser.add_argument( + "--custom-delta-pre-read-path", + type=str, + default=None, + help=( + "Path to a custom function called on each rollout host before it reads the " + "published delta directory (shared-FS visibility, e.g. a volume reload). " + "Signature: ``def hook(delta_dir: str, target_version: int) -> None``." + ), + ) + parser.add_argument( + "--update-weight-local-checkpoint-dir", + type=str, + default=None, + help=( + "Rollout-host-local directory (NVMe) holding a full HF checkpoint that " + "disk-delta sync patches in place. Each host materializes it from " + "--hf-checkpoint at engine start, applies each version's delta there, and " + "the engines reload from it. 
Required for --update-weight-mode=delta " + "--update-weight-transport=disk." + ), + ) parser.add_argument( "--custom-model-provider-path", type=str, @@ -1682,53 +1702,34 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: def _resolve_update_weight_disk_dir(args) -> None: - """Normalize disk-sync directory args. - - ``--update-weight-delta-dir`` is kept only as a compatibility alias. New - code should use ``--update-weight-disk-dir`` because the directory belongs - to the transport, not to the delta encoding mode. - """ - disk_dir = args.update_weight_disk_dir - delta_dir = args.update_weight_delta_dir - if disk_dir and delta_dir and disk_dir != delta_dir: + """Disk-backed sync (full or delta) needs a directory the trainer writes and the rollout + engines read — a filesystem shared between them.""" + if args.update_weight_transport == "disk" and not args.update_weight_disk_dir: raise ValueError( - "--update-weight-delta-dir is deprecated alias for --update-weight-disk-dir; " - "please set only one of them or set both to the same path." + "--update-weight-transport=disk requires --update-weight-disk-dir to point at " + "a filesystem shared between the trainer and the rollout engines." ) - if delta_dir: - warnings.warn( - "--update-weight-delta-dir is deprecated and will be removed in a future release; " - "use --update-weight-disk-dir instead.", - UserWarning, - stacklevel=2, - ) - - disk_dir = disk_dir or delta_dir - if args.update_weight_transport == "disk": - if not disk_dir: - raise ValueError( - "--update-weight-transport=disk requires --update-weight-disk-dir to point at " - "a filesystem shared between the trainer and the rollout engines." 
- ) - args.update_weight_disk_dir = disk_dir - args.update_weight_delta_dir = disk_dir - def _validate_update_weight_args(args) -> None: _resolve_update_weight_disk_dir(args) if args.update_weight_mode == "delta": - if args.update_weight_transport not in ("nccl", "disk"): + if args.update_weight_transport != "disk": raise ValueError( - "--update-weight-mode=delta supports only --update-weight-transport=nccl or disk, " + "--update-weight-mode=delta requires --update-weight-transport=disk, " f"got {args.update_weight_transport!r}." ) if args.colocate: raise ValueError( "--update-weight-mode=delta is not supported with --colocate. Colocate transfers " "weights via CUDA IPC (only a handle crosses processes), so the delta bookkeeping " - "(snapshot + diff + sparse encode) is pure overhead." + "(snapshot + diff + encode) is pure overhead." + ) + if not args.update_weight_local_checkpoint_dir: + raise ValueError( + "--update-weight-mode=delta requires --update-weight-local-checkpoint-dir " + "(a rollout-host-local NVMe directory)." ) diff --git a/slime/utils/disk_delta.py b/slime/utils/disk_delta.py new file mode 100644 index 0000000000..3ba8cef12f --- /dev/null +++ b/slime/utils/disk_delta.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import fcntl +import glob +import io +import json +import logging +import mmap +import os +import shutil +import struct +import threading +import zlib +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager + +import numpy as np +import zstandard + +logger = logging.getLogger(__name__) + +# The delta phases (XOR/scatter, zstd, checksum) are memory-bandwidth bound and release the GIL, +# so a thread pool over tensors recovers the bandwidth one thread leaves idle. 
+NUM_WORKERS = min(32, (os.cpu_count() or 8)) + +SYNC_DIR = ".delta_sync" # per-checkpoint dir holding the applied-version marker and the apply lock + + +def overwrite_encode(new: np.ndarray, changed_mask: np.ndarray) -> np.ndarray: + """The 'overwrite' delta: changed-position count (u4), positions (u4 each), then new values. + Idempotent to apply, unlike xor (an involution); the trainer picks the encoding per the docs.""" + pos = np.flatnonzero(changed_mask).astype(" None: + self._value = zlib.adler32(data, self._value) + + def hexdigest(self) -> str: + return f"{self._value:08x}" + + +def _new_hasher(algorithm: str): + if algorithm == "xxh3-128": + import xxhash + + return xxhash.xxh3_128() + if algorithm == "blake3": + import blake3 + + return blake3.blake3() + if algorithm == "adler32": + return _Adler32() + raise KeyError(f"unknown checksum algorithm {algorithm!r}") + + +def checksum(algorithm: str, buf) -> str: + hasher = _new_hasher(algorithm) + hasher.update(buf) + return hasher.hexdigest() + + +@contextmanager +def _apply_lock(local_ckpt_dir: str): + sync = os.path.join(local_ckpt_dir, SYNC_DIR) + os.makedirs(sync, exist_ok=True) + with open(os.path.join(sync, "lock"), "w") as f: + fcntl.flock(f, fcntl.LOCK_EX) + try: + yield + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + +def _read_applied_version(local_ckpt_dir: str) -> str | None: + try: + with open(os.path.join(local_ckpt_dir, SYNC_DIR, "state.json")) as f: + return json.load(f)["version"] + except FileNotFoundError: + return None + + +def _write_applied_version(local_ckpt_dir: str, version: str) -> None: + path = os.path.join(local_ckpt_dir, SYNC_DIR, "state.json") + tmp = path + ".tmp" + with open(tmp, "w") as f: + json.dump({"version": version}, f) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) + + +def drop_page_cache(path: str) -> None: + """Evict a file from the page cache (POSIX_FADV_DONTNEED).""" + try: + fd = os.open(path, os.O_RDONLY) + try: + os.posix_fadvise(fd, 0, 0, 
os.POSIX_FADV_DONTNEED) + finally: + os.close(fd) + except OSError: + pass + + +def init_local_checkpoint(local_ckpt_dir: str, base_dir: str) -> None: + """Copy the base HF checkpoint into local_ckpt_dir once if absent (run at engine start). Each + later delta is applied on top of this copy in place.""" + with _apply_lock(local_ckpt_dir): + if _read_applied_version(local_ckpt_dir) is not None: + return + logger.info("Materializing base checkpoint %s -> %s", base_dir, local_ckpt_dir) + os.makedirs(local_ckpt_dir, exist_ok=True) + for entry in os.scandir(base_dir): + if entry.is_file(): + shutil.copy2(entry.path, os.path.join(local_ckpt_dir, entry.name)) + drop_page_cache(entry.path) # don't let the source evict the local copy we keep resident + _write_applied_version(local_ckpt_dir, "000000") + + +def _tensor_locations(ckpt_dir: str) -> dict[str, tuple[str, int, int]]: + """Map each tensor name to (file, byte offset, nbytes) by reading every safetensors header.""" + locations: dict[str, tuple[str, int, int]] = {} + for path in glob.glob(os.path.join(ckpt_dir, "*.safetensors")): + with open(path, "rb") as f: + (header_len,) = struct.unpack(" uint8 bytes`` that seeks straight to the + tensor — for reading many tensors without rescanning every header. KeyError if absent.""" + locations = _tensor_locations(ckpt_dir) + + def read(name: str) -> np.ndarray: + path, offset, nbytes = locations[name] + with open(path, "rb") as f: + f.seek(offset) + return np.frombuffer(f.read(nbytes), dtype=np.uint8) + + return read + + +def _apply_version(local_ckpt_dir: str, version_dir: str) -> None: + """Apply one version's delta in place: decompress + apply + checksum each tensor across a thread + pool (each writes a distinct mmap region, so the writes don't conflict). 
Any mismatch raises.""" + with open(os.path.join(version_dir, "model.safetensors.index.json")) as f: + meta = json.load(f)["metadata"] + applied = _read_applied_version(local_ckpt_dir) + if applied == meta["version"]: + return + if applied != meta["base_version"]: + raise RuntimeError(f"out-of-order delta: local at {applied}, delta builds on {meta['base_version']}") + if meta["compression_format"] != "zstd": + raise NotImplementedError(f"compression {meta['compression_format']!r} not supported") + encoding = meta["delta_encoding"] + algorithm = meta["checksum_format"] + locations = _tensor_locations(local_ckpt_dir) + open_mmaps: dict[str, tuple] = {} + mismatches: list[str] = [] + lock = threading.Lock() + file_bytes: list[bytes] = [] # keep alive: items hold zero-copy views into these + items: list[tuple] = [] # (name, compressed_view, path, offset, nbytes, want_checksum) + try: + for delta_file in sorted(glob.glob(os.path.join(version_dir, "*.safetensors"))): + with open(delta_file, "rb") as f: + blob = f.read() + file_bytes.append(blob) + (header_len,) = struct.unpack(" None: + name, compressed, path, offset, nbytes, want = item + region = np.ndarray((nbytes,), dtype=np.uint8, buffer=open_mmaps[path][1], offset=offset) + hasher = _new_hasher(algorithm) + reader = zstandard.ZstdDecompressor().stream_reader(io.BytesIO(bytes(compressed))) + pos = 0 + while pos < nbytes: # 2 MB chunks stay L2-resident across decompress -> XOR -> checksum + block = reader.read(min(2 << 20, nbytes - pos)) + if not block: + break + chunk = np.frombuffer(block, dtype=np.uint8) + region[pos : pos + chunk.size] ^= chunk + hasher.update(region[pos : pos + chunk.size]) + pos += chunk.size + if hasher.hexdigest() != want: + with lock: + mismatches.append(name) + + def apply_overwrite(item) -> None: + name, compressed, path, offset, nbytes, want = item + delta = np.frombuffer(zstandard.ZstdDecompressor().decompress(bytes(compressed)), dtype=np.uint8) + region = np.ndarray((nbytes,), 
dtype=np.uint8, buffer=open_mmaps[path][1], offset=offset) + count = int.from_bytes(delta[:4].tobytes(), "little") + positions = np.frombuffer(delta[4 : 4 + 4 * count].tobytes(), dtype=" None: + """Apply the delta chain in order to bring the local checkpoint up to target_version, in place. + A per-tensor checksum guards every write and any mismatch raises (fail loud, never serve bad + weights). Serialized per host by the lock (co-located actors collapse to one apply).""" + with _apply_lock(local_ckpt_dir): + applied = _read_applied_version(local_ckpt_dir) + if applied is None: + raise RuntimeError("local checkpoint not materialized") + for version in range(int(applied) + 1, target_version + 1): + _apply_version(local_ckpt_dir, os.path.join(delta_root, f"weight_v{version:06d}")) From bf498c51159eea86a3b7752b280ae3c7296f3708 Mon Sep 17 00:00:00 2001 From: Nan Date: Tue, 16 Jun 2026 04:35:02 +0000 Subject: [PATCH 2/8] docker: regenerate sglang patch without the delta-sync receiver --- docker/patch/latest/sglang.patch | 555 +++---------------------------- 1 file changed, 51 insertions(+), 504 deletions(-) diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index 191c20ad4a..f418ac4739 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -1,5 +1,5 @@ diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py -index a7bf9904a20..b0cb56aaece 100644 +index a7bf990..b0cb56a 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -32,6 +32,7 @@ class KVArgs: @@ -11,7 +11,7 @@ index a7bf9904a20..b0cb56aaece 100644 aux_data_lens: List[int] aux_item_lens: List[int] diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py -index e9efdcdd9ee..70265a424f5 100644 +index e9efdcd..70265a4 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ 
b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server @@ -183,7 +183,7 @@ index e9efdcdd9ee..70265a424f5 100644 return diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py -index b21aee9f7c2..87f0a6fa668 100644 +index b21aee9..87f0a6f 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -39,7 +39,10 @@ from sglang.srt.disaggregation.common.utils import ( @@ -322,7 +322,7 @@ index b21aee9f7c2..87f0a6fa668 100644 # Only the last chunk we need to send the aux data ret = self.send_aux( diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py -index ce1afdac3ad..de8fd054f70 100644 +index ce1afda..de8fd05 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -21,6 +21,8 @@ from __future__ import annotations @@ -440,7 +440,7 @@ index ce1afdac3ad..de8fd054f70 100644 release_kv_cache(req, self.tree_cache) # unlock the tree req.finished_reason = FINISH_LENGTH(length=0) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py -index e1d7d9c8db3..a44685777ea 100644 +index e1d7d9c..a446857 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations @@ -646,7 +646,7 @@ index e1d7d9c8db3..a44685777ea 100644 ######################### diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py -index 88bf1947684..4ede8eb9078 100644 +index 88bf194..4ede8eb 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -71,6 +71,7 @@ from sglang.srt.managers.io_struct import ( @@ -679,7 +679,7 @@ index 88bf1947684..4ede8eb9078 100644 """Get weights by parameter name.""" obj = 
GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index d7368383d89..2c881d95bd5 100644 +index d736838..2c881d9 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -127,6 +127,7 @@ from sglang.srt.managers.io_struct import ( @@ -747,7 +747,7 @@ index d7368383d89..2c881d95bd5 100644 @auth_level(AuthLevel.ADMIN_OPTIONAL) async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py -index 435c30a5cfd..864a0f567a6 100644 +index 435c30a..864a0f5 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -299,6 +299,7 @@ class Envs: @@ -759,7 +759,7 @@ index 435c30a5cfd..864a0f567a6 100644 SGLANG_DISAGGREGATION_FORCE_QUERY_PREFILL_DP_RANK = EnvBool(False) # Extra slots in req_to_token_pool for decode workers (only effective when diff --git a/python/sglang/srt/layers/attention/dsa/dsa_indexer.py b/python/sglang/srt/layers/attention/dsa/dsa_indexer.py -index 85fcd4b9ec7..a49161f6154 100644 +index 85fcd4b..a49161f 100644 --- a/python/sglang/srt/layers/attention/dsa/dsa_indexer.py +++ b/python/sglang/srt/layers/attention/dsa/dsa_indexer.py @@ -2,6 +2,7 @@ from __future__ import annotations @@ -860,7 +860,7 @@ index 85fcd4b9ec7..a49161f6154 100644 if enable_dual_stream: current_stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py -index 59ca3f9cce6..9c2d00fcd7c 100644 +index 59ca3f9..9c2d00f 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -793,6 +793,7 @@ class FusedMoE(torch.nn.Module): @@ -880,7 +880,7 @@ index 59ca3f9cce6..9c2d00fcd7c 100644 else loaded_weight ) diff --git 
a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py -index 28a9d567a5e..e60a0bcfde0 100644 +index 28a9d56..e60a0bc 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -927,6 +927,10 @@ class CompressedTensorsLinearMethod(LinearMethodBase): @@ -906,7 +906,7 @@ index 28a9d567a5e..e60a0bcfde0 100644 self, layer: torch.nn.Module, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py -index 58562bb23db..c3dc1ceb0d2 100644 +index 58562bb..c3dc1ce 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py @@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( @@ -1020,69 +1020,10 @@ index 58562bb23db..c3dc1ceb0d2 100644 is_k_full=self.is_k_full, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py -index 987ec512122..e098565729b 100644 +index 987ec51..55a51e0 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py -@@ -1442,6 +1442,8 @@ class PauseContinueBroadcast: - class UpdateWeightFromDiskReqInput(BaseReq): - # The model path with the new weights - model_path: str -+ # Required iff ``load_format == "delta"``: basenames under ``model_path`` to apply. 
-+ files: Optional[List[str]] = None - # The format to load the weights - load_format: Optional[str] = None - # Whether to abort all requests before updating weights -@@ -1472,6 +1474,40 @@ class UpdateWeightFromDiskReqOutput(BaseReq): - num_paused_requests: Optional[int] = 0 - - -+class DeltaEncoding(str, Enum): -+ """Position encoding for delta weight updates.""" -+ -+ # int32 absolute nonzero offsets. -+ INDICES = "indices" -+ # uint16 gap-deltas between consecutive sorted positions; uint32 per-param fallback. -+ DELTAS = "deltas" -+ # ``deltas`` wrapped in zstd L1. -+ DELTAS_ZSTD = "deltas_zstd" -+ -+ -+@dataclass -+class DeltaParam: -+ """Per-param slice into the shared (positions, values) bucket.""" -+ -+ name: str -+ dtype: str -+ shape: List[int] -+ pos_start: int -+ pos_end: int -+ pos_width: int -+ val_start: int -+ val_end: int -+ -+ -+@dataclass -+class DeltaSpec: -+ """Decoding manifest for one delta bucket. ``checksum`` is verified on apply.""" -+ -+ encoding: DeltaEncoding -+ params: List[DeltaParam] -+ checksum: int = 0 -+ -+ - @dataclass - class UpdateWeightsFromDistributedReqInput(BaseReq): - names: List[str] -@@ -1487,6 +1523,8 @@ class UpdateWeightsFromDistributedReqInput(BaseReq): - weight_version: Optional[str] = None - # Optional format specification for loading - load_format: Optional[str] = None -+ # JSON-encoded DeltaSpec; required iff load_format == "delta". 
-+ delta: Optional[str] = None - # Whether to call torch.cuda.empty_cache() during flush - torch_empty_cache: bool = False - -@@ -1673,6 +1711,18 @@ class ResumeMemoryOccupationReqOutput(BaseReq): +@@ -1673,6 +1673,18 @@ class ResumeMemoryOccupationReqOutput(BaseReq): pass @@ -1101,7 +1042,7 @@ index 987ec512122..e098565729b 100644 @dataclass class CheckWeightsReqInput(BaseReq): action: str = "checksum" -@@ -2058,7 +2108,7 @@ class GetLoadsReqInput(BaseReq): +@@ -2058,7 +2070,7 @@ class GetLoadsReqInput(BaseReq): """Request for /v1/loads endpoint.""" VALID_SECTIONS = frozenset( @@ -1110,7 +1051,7 @@ index 987ec512122..e098565729b 100644 ) include: List[str] = field(default_factory=lambda: ["all"]) -@@ -2128,6 +2178,9 @@ class GetLoadsReqOutput(BaseReq): +@@ -2128,6 +2140,9 @@ class GetLoadsReqOutput(BaseReq): lora: Optional[LoRAMetrics] = None disaggregation: Optional[DisaggregationMetrics] = None queues: Optional[QueueMetrics] = None @@ -1121,7 +1062,7 @@ index 987ec512122..e098565729b 100644 @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py -index 42ea8431091..c369b070b57 100755 +index 42ea843..c369b07 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -943,6 +943,7 @@ class Req(ReqDllmMixin): @@ -1159,7 +1100,7 @@ index 42ea8431091..c369b070b57 100755 ): # Even the last remaining request cannot fit in memory. 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index 8e32640fc6a..98966842506 100644 +index 8e32640..9896684 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -124,6 +124,7 @@ from sglang.srt.managers.io_struct import ( @@ -1206,7 +1147,7 @@ index 8e32640fc6a..98966842506 100644 def _pause_engine(self) -> Tuple[List[Req], int]: diff --git a/python/sglang/srt/managers/scheduler_components/load_inquirer.py b/python/sglang/srt/managers/scheduler_components/load_inquirer.py -index 3f10d7edaff..712322a95af 100644 +index 3f10d7e..712322a 100644 --- a/python/sglang/srt/managers/scheduler_components/load_inquirer.py +++ b/python/sglang/srt/managers/scheduler_components/load_inquirer.py @@ -202,6 +202,88 @@ class SchedulerLoadInquirer: @@ -1305,7 +1246,7 @@ index 3f10d7edaff..712322a95af 100644 + inflight=inflight, ) diff --git a/python/sglang/srt/managers/scheduler_components/output_streamer.py b/python/sglang/srt/managers/scheduler_components/output_streamer.py -index cac80715856..2574fcfb55c 100644 +index cac8071..2574fcf 100644 --- a/python/sglang/srt/managers/scheduler_components/output_streamer.py +++ b/python/sglang/srt/managers/scheduler_components/output_streamer.py @@ -481,7 +481,7 @@ class _GenerationStreamAccumulator: @@ -1318,7 +1259,7 @@ index cac80715856..2574fcfb55c 100644 dp_ranks = [dp_rank] * len(self.rids) if self.rids else None return BatchTokenIDOutput( diff --git a/python/sglang/srt/managers/scheduler_components/profiler_manager.py b/python/sglang/srt/managers/scheduler_components/profiler_manager.py -index 31df519f9e8..cdcf41cd8bc 100644 +index 31df519..cdcf41c 100644 --- a/python/sglang/srt/managers/scheduler_components/profiler_manager.py +++ b/python/sglang/srt/managers/scheduler_components/profiler_manager.py @@ -377,7 +377,7 @@ class SchedulerProfilerManager: @@ -1331,7 +1272,7 @@ index 31df519f9e8..cdcf41cd8bc 100644 if self.profile_in_progress: 
# force trace flush diff --git a/python/sglang/srt/managers/scheduler_components/weight_updater.py b/python/sglang/srt/managers/scheduler_components/weight_updater.py -index 77bf823b081..9ab3abe5618 100644 +index 77bf823..9ab3abe 100644 --- a/python/sglang/srt/managers/scheduler_components/weight_updater.py +++ b/python/sglang/srt/managers/scheduler_components/weight_updater.py @@ -16,6 +16,7 @@ from sglang.srt.constants import ( @@ -1421,7 +1362,7 @@ index 77bf823b081..9ab3abe5618 100644 return ResumeMemoryOccupationReqOutput() diff --git a/python/sglang/srt/managers/tokenizer_control_mixin.py b/python/sglang/srt/managers/tokenizer_control_mixin.py -index c9939a1fc93..ee25e5e70e0 100644 +index c9939a1..ee25e5e 100644 --- a/python/sglang/srt/managers/tokenizer_control_mixin.py +++ b/python/sglang/srt/managers/tokenizer_control_mixin.py @@ -48,6 +48,8 @@ from sglang.srt.managers.io_struct import ( @@ -1459,7 +1400,7 @@ index c9939a1fc93..ee25e5e70e0 100644 self: TokenizerManager, obj: CheckWeightsReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 357e3c4675a..1f6dc90e471 100644 +index 357e3c4..71319d7 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1641,7 +1641,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): @@ -1480,24 +1421,6 @@ index 357e3c4675a..1f6dc90e471 100644 self.is_pause_cond.notify_all() async def update_weights_from_disk( -@@ -1704,7 +1704,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): - self.model_update_result = asyncio.Future() - if self.server_args.dp_size == 1: - result = await self.model_update_result -- if result.success: -+ if result.success and obj.load_format != "delta": - self._update_model_path_info(obj.model_path, obj.load_format) - return result.success, result.message, result.num_paused_requests - else: # self.server_args.dp_size 
> 1 -@@ -1712,7 +1712,7 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): - result = await self.model_update_result - - all_success = all([r.success for r in result]) -- if all_success is True: -+ if all_success is True and obj.load_format != "delta": - self._update_model_path_info(obj.model_path, obj.load_format) - all_message = [r.message for r in result] - all_message = " | ".join(all_message) @@ -2343,25 +2343,23 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): priority = getattr(state.obj, "priority", None) if priority is not None: @@ -1532,7 +1455,7 @@ index 357e3c4675a..1f6dc90e471 100644 if state.finished: # Get detailed cache breakdown if available diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py -index bd9184408ed..71bbe8f400f 100644 +index bd91844..b7cb6d4 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,6 +29,7 @@ from sglang.srt.managers.io_struct import ( @@ -1543,31 +1466,22 @@ index bd9184408ed..71bbe8f400f 100644 SendWeightsToRemoteInstanceReqInput, UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, -@@ -98,6 +99,7 @@ class BaseTpWorker(ABC): - recv_req.model_path, - recv_req.load_format, - recapture_cuda_graph=recv_req.recapture_cuda_graph, -+ files=recv_req.files, +@@ -155,6 +156,13 @@ class BaseTpWorker(ABC): ) return success, message -@@ -152,6 +154,14 @@ class BaseTpWorker(ABC): - recv_req.shapes, - recv_req.group_name, - recv_req.load_format, -+ recv_req.delta, -+ ) -+ return success, message -+ + def post_process_weights(self, recv_req: PostProcessWeightsReqInput): + success, message = self.model_runner.post_process_weights( + restore_weights_before_load=recv_req.restore_weights_before_load, + post_process_quantization=recv_req.post_process_quantization, - ) - return success, message ++ ) ++ return success, message ++ + def update_weights_from_tensor(self, recv_req: 
UpdateWeightsFromTensorReqInput): + monkey_patch_torch_reductions() diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py -index 353a02ee0be..7e3e3f58cb9 100644 +index 353a02e..7e3e3f5 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -1009,9 +1009,7 @@ class HiRadixCache(RadixCache): @@ -1590,7 +1504,7 @@ index 353a02ee0be..7e3e3f58cb9 100644 def _evict_regular(self, node: TreeNode): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py -index 8efe9aae94e..79e9885c92f 100644 +index 8efe9aa..79e9885 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -2244,9 +2244,12 @@ class DSATokenToKVPool(MLATokenToKVPool): @@ -1610,7 +1524,7 @@ index 8efe9aae94e..79e9885c92f 100644 self.index_k_with_scale_buffer = [ torch.zeros( diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py -index bd6adb6e398..5ea935f76e9 100644 +index bd6adb6..5ea935f 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -467,6 +467,9 @@ class RadixCache(KVCacheEventMixin, BasePrefixCache): @@ -1635,37 +1549,18 @@ index bd6adb6e398..5ea935f76e9 100644 return DecLockRefResult(delta=delta) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 3b30eb0e1f7..d715bc6893d 100644 +index 3b30eb0..0e33834 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py -@@ -20,7 +20,9 @@ import datetime +@@ -20,6 +20,7 @@ import datetime import gc import hashlib import inspect +import json import logging -+import math import os import socket - import threading -@@ -28,7 +30,7 @@ import time - from collections import defaultdict - from dataclasses import dataclass, replace - from 
pathlib import Path --from typing import Any, Callable, List, Optional, Tuple, Union -+from typing import Any, Callable, Dict, List, Optional, Tuple, Union - - import torch - import torch.distributed as dist -@@ -137,6 +139,7 @@ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model - from sglang.srt.layers.utils.cp_utils import is_mla_prefill_cp_enabled - from sglang.srt.lora.lora_manager import LoRAManager - from sglang.srt.lora.lora_registry import LoRARef -+from sglang.srt.managers.io_struct import DeltaEncoding, DeltaParam, DeltaSpec - from sglang.srt.managers.schedule_batch import sanity_check_mm_pad_shift_value - from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator - from sglang.srt.mem_cache.memory_pool import ReqToTokenPool -@@ -548,7 +551,10 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -548,7 +549,10 @@ class ModelRunner(ModelRunnerKVCacheMixin): self.forward_stream = torch.get_device_module(self.device).Stream() # CPU offload @@ -1677,7 +1572,7 @@ index 3b30eb0e1f7..d715bc6893d 100644 self._weight_checker = WeightChecker(model_runner=self) -@@ -796,7 +802,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -796,7 +800,8 @@ class ModelRunner(ModelRunnerKVCacheMixin): self.maybe_init_ngram_embedding() # Init routed experts capturer @@ -1687,202 +1582,16 @@ index 3b30eb0e1f7..d715bc6893d 100644 self.init_indexer_capturer() -@@ -1657,8 +1664,14 @@ class ModelRunner(ModelRunnerKVCacheMixin): - load_format: str, +@@ -1658,7 +1663,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): weight_name_filter: Optional[Callable[[str], bool]] = None, recapture_cuda_graph: bool = False, -+ files: Optional[List[str]] = None, ) -> tuple[bool, str]: - """Update engine weights in-place from the disk.""" + """Update engine weights in-place from disk.""" -+ if load_format == "delta": -+ if not files: -+ return False, "load_format='delta' requires non-empty `files`" -+ return self._apply_delta([os.path.join(model_path, f) for f in files]) 
-+ logger.info( f"Update engine weights online from disk begin. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id, empty_cache=False):.2f} GB" -@@ -1888,6 +1901,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): - shapes, - group_name, - load_format: Optional[str] = None, -+ delta: Optional[str] = None, - ): - """ - Update specific parameter in the model weights online -@@ -1908,6 +1922,18 @@ class ModelRunner(ModelRunnerKVCacheMixin): - return self._update_bucketed_weights_from_distributed( - names, dtypes, shapes, group_name - ) -+ if load_format == "delta": -+ if delta is None: -+ return False, "load_format='delta' requires a DeltaSpec in the request" -+ spec_dict = json.loads(delta) -+ spec = DeltaSpec( -+ encoding=DeltaEncoding(spec_dict["encoding"]), -+ params=[DeltaParam(**p) for p in spec_dict["params"]], -+ checksum=int(spec_dict["checksum"]), -+ ) -+ return self._apply_delta_from_distributed( -+ names, dtypes, shapes, group_name, spec -+ ) - try: - weights = [] - handles = [] -@@ -1971,6 +1997,151 @@ class ModelRunner(ModelRunnerKVCacheMixin): - logger.error(error_msg) - return False, error_msg - -+ def _decode_delta_one_param( -+ self, -+ encoding: DeltaEncoding, -+ positions: torch.Tensor, -+ values: torch.Tensor, -+ p: DeltaParam, -+ ) -> torch.Tensor: -+ """Decode one param's sparse delta into a NaN-masked full tensor.""" -+ numel = math.prod(p.shape) -+ param_dtype = _resolve_torch_dtype(p.dtype) -+ flat = torch.full((numel,), float("nan"), dtype=param_dtype, device=self.device) -+ val_slice = values[p.val_start : p.val_end] -+ if val_slice.numel() == 0: -+ return flat.view(tuple(p.shape)) -+ -+ pos_bytes = positions[p.pos_start : p.pos_end] -+ if encoding is DeltaEncoding.INDICES: -+ width = 4 -+ elif encoding in (DeltaEncoding.DELTAS, DeltaEncoding.DELTAS_ZSTD): -+ width = p.pos_width -+ else: -+ raise ValueError(f"unsupported delta encoding: {encoding!r}") -+ -+ n_elems = pos_bytes.numel() // width -+ b = pos_bytes.view(n_elems, 
width).to(torch.int64) -+ if width == 2: -+ unpacked = b[:, 0] | (b[:, 1] << 8) -+ else: -+ unpacked = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16) | (b[:, 3] << 24) -+ -+ if encoding is DeltaEncoding.INDICES: -+ idx = unpacked -+ else: -+ idx = (unpacked + 1).cumsum(dim=0) - 1 -+ -+ flat.index_copy_(0, idx, val_slice.to(param_dtype)) -+ return flat.view(tuple(p.shape)) -+ -+ def _apply_delta_payload( -+ self, -+ encoding: DeltaEncoding, -+ params: List[DeltaParam], -+ positions: torch.Tensor, -+ values: torch.Tensor, -+ expected_checksum: int, -+ ) -> None: -+ actual_checksum = _delta_checksum(positions, values) -+ if actual_checksum != expected_checksum: -+ raise RuntimeError( -+ f"delta checksum mismatch: expected={expected_checksum} got={actual_checksum}" -+ ) -+ -+ chunk_byte_cap = self.server_args.update_weight_delta_chunk_bytes -+ with _delta_apply_context(self.model): -+ chunk: List[Tuple[str, torch.Tensor]] = [] -+ chunk_bytes = 0 -+ for p in params: -+ t = self._decode_delta_one_param(encoding, positions, values, p) -+ tensor_bytes = t.numel() * t.element_size() -+ if chunk_bytes + tensor_bytes > chunk_byte_cap and chunk: -+ self.model.load_weights(chunk) -+ chunk = [] -+ chunk_bytes = 0 -+ chunk.append((p.name, t)) -+ chunk_bytes += tensor_bytes -+ if chunk: -+ self.model.load_weights(chunk) -+ -+ def _decode_and_apply_blob(self, blob: bytes) -> None: -+ from safetensors.torch import load as st_load -+ -+ hdr_len = int.from_bytes(blob[:8], "little") -+ meta = json.loads(blob[8 : 8 + hdr_len]).get("__metadata__", {}) -+ encoding = DeltaEncoding(meta["encoding"]) -+ params = [DeltaParam(**p) for p in json.loads(meta["params"])] -+ expected_checksum = int(meta["checksum"]) -+ -+ tensors = st_load(blob) -+ positions = tensors["__positions__"].to(self.device, non_blocking=True) -+ values = tensors["__values__"].to(self.device, non_blocking=True) -+ self._apply_delta_payload( -+ encoding, params, positions, values, expected_checksum -+ ) -+ -+ def 
_apply_delta_from_distributed( -+ self, -+ names: List[str], -+ dtypes: List[str], -+ shapes: List[List[int]], -+ group_name: str, -+ delta: DeltaSpec, -+ ) -> tuple[bool, str]: -+ try: -+ recv: Dict[str, torch.Tensor] = {} -+ handles = [] -+ for name, dtype, shape in zip(names, dtypes, shapes): -+ target_dtype = _resolve_torch_dtype(dtype) -+ t = torch.empty(shape, dtype=target_dtype, device=self.device) -+ handles.append( -+ torch.distributed.broadcast( -+ t, -+ src=0, -+ group=self._model_update_group[group_name], -+ async_op=True, -+ ) -+ ) -+ recv[name] = t -+ for handle in handles: -+ handle.wait() -+ -+ self._apply_delta_payload( -+ delta.encoding, -+ delta.params, -+ recv["__positions__"], -+ recv["__values__"], -+ delta.checksum, -+ ) -+ return True, "ok" -+ except Exception as e: -+ error_msg = f"Failed to apply delta from distributed: {e}." -+ logger.error(error_msg) -+ return False, error_msg -+ -+ def _apply_delta(self, paths: List[str]) -> tuple[bool, str]: -+ import concurrent.futures -+ -+ n_files = len(paths) -+ workers = min(n_files, self.server_args.update_weight_delta_read_workers) -+ -+ def _read_and_decompress(path: str) -> bytes: -+ with open(path, "rb") as fh: -+ return _maybe_zstd_decompress(fh.read()) -+ -+ try: -+ for i in range(0, n_files, workers): -+ with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool: -+ batch = list(pool.map(_read_and_decompress, paths[i : i + workers])) -+ for blob in batch: -+ self._decode_and_apply_blob(blob) -+ return True, f"Applied {n_files} delta file(s)" -+ except Exception as e: -+ error_msg = f"Failed to apply delta update from disk: {e}." 
-+ logger.error(error_msg) -+ return False, error_msg -+ - def update_weights_from_tensor( - self, - named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]], -@@ -3468,11 +3639,17 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3468,11 +3473,17 @@ class ModelRunner(ModelRunnerKVCacheMixin): output.expert_distribution_metrics = recorder_outputs.get("metrics") no_copy_to_cpu = not self.server_args.disable_overlap_schedule @@ -1901,7 +1610,7 @@ index 3b30eb0e1f7..d715bc6893d 100644 no_copy_to_cpu=no_copy_to_cpu, ) -@@ -3480,7 +3657,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3480,7 +3491,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): output.indexer_topk_output = indexer_capturer.on_forward_end( forward_batch=forward_batch, can_run_graph=output.can_run_graph, @@ -1910,7 +1619,7 @@ index 3b30eb0e1f7..d715bc6893d 100644 no_copy_to_cpu=no_copy_to_cpu, ) -@@ -3718,6 +3895,39 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3718,6 +3729,39 @@ class ModelRunner(ModelRunnerKVCacheMixin): logger.error(f"IPC weight update failed: {e}") return False, str(e) @@ -1950,120 +1659,10 @@ index 3b30eb0e1f7..d715bc6893d 100644 def prealloc_symmetric_memory_pool(self): # PyTorch mempools never de-fragment memory in OOM scenarios, so we need to pre-allocate a large chunk of memory to limit fragmentation. 
if ( -@@ -3767,6 +3977,123 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -3767,6 +3811,13 @@ class ModelRunner(ModelRunnerKVCacheMixin): return output -+def _param_storage_index(model): -+ import bisect -+ -+ starts: List[int] = [] -+ ends: List[int] = [] -+ owners: List[torch.Tensor] = [] -+ seen: set = set() -+ for tensors in (model.named_parameters(), model.named_buffers()): -+ for _, t in tensors: -+ if t.is_meta: -+ continue -+ try: -+ ptr = t.data_ptr() -+ except RuntimeError: -+ continue -+ if ptr == 0 or ptr in seen: -+ continue -+ seen.add(ptr) -+ sz = t.numel() * t.element_size() -+ starts.append(ptr) -+ ends.append(ptr + sz) -+ owners.append(t) -+ -+ order = sorted(range(len(starts)), key=lambda i: starts[i]) -+ starts = [starts[i] for i in order] -+ ends = [ends[i] for i in order] -+ owners = [owners[i] for i in order] -+ -+ def find_parent(dst): -+ try: -+ ptr = dst.data_ptr() -+ except RuntimeError: -+ return None -+ idx = bisect.bisect_right(starts, ptr) - 1 -+ if 0 <= idx < len(starts) and starts[idx] <= ptr < ends[idx]: -+ return owners[idx] -+ return None -+ -+ return find_parent -+ -+ -+@contextlib.contextmanager -+def _delta_apply_context(model): -+ is_param_target = _param_storage_index(model) -+ original_copy_ = torch.Tensor.copy_ -+ original_fill_ = torch.Tensor.fill_ -+ -+ def patched_copy_(self, src, *args, **kwargs): -+ if is_param_target(self) is not None: -+ src_aligned = ( -+ src.to(device=self.device, dtype=self.dtype) -+ if src.dtype != self.dtype -+ else src -+ ) -+ mask = ~torch.isnan(src_aligned) -+ self[mask] = src_aligned[mask] -+ return self -+ return original_copy_(self, src, *args, **kwargs) -+ -+ def patched_fill_(self, value): -+ if is_param_target(self) is not None: -+ try: -+ if math.isnan(value): -+ return self -+ except TypeError: -+ pass -+ return original_fill_(self, value) -+ return original_fill_(self, value) -+ -+ original_post_load = getattr(model, "post_load_weights", None) -+ if original_post_load is not 
None: -+ -+ def wrapped_post_load(*args, **kwargs): -+ current_copy = torch.Tensor.copy_ -+ current_fill = torch.Tensor.fill_ -+ torch.Tensor.copy_ = original_copy_ -+ torch.Tensor.fill_ = original_fill_ -+ try: -+ return original_post_load(*args, **kwargs) -+ finally: -+ torch.Tensor.copy_ = current_copy -+ torch.Tensor.fill_ = current_fill -+ -+ model.post_load_weights = wrapped_post_load -+ -+ torch.Tensor.copy_ = patched_copy_ -+ torch.Tensor.fill_ = patched_fill_ -+ try: -+ yield -+ finally: -+ torch.Tensor.copy_ = original_copy_ -+ torch.Tensor.fill_ = original_fill_ -+ if original_post_load is not None: -+ model.post_load_weights = original_post_load -+ -+ -+def _delta_checksum(positions: torch.Tensor, values: torch.Tensor) -> int: -+ p = int(torch.hash_tensor(positions).item()) if positions.numel() else 0 -+ v = int(torch.hash_tensor(values).item()) if values.numel() else 0 -+ return p ^ (v << 1) -+ -+ -+def _maybe_zstd_decompress(blob: bytes) -> bytes: -+ if blob.startswith(b"\x28\xb5\x2f\xfd"): -+ import zstandard -+ -+ return zstandard.ZstdDecompressor().decompress(blob) -+ return blob -+ -+ +def _resolve_torch_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype @@ -2075,7 +1674,7 @@ index 3b30eb0e1f7..d715bc6893d 100644 params_dict = dict(model.named_parameters()) for name, tensor in named_tensors: diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py -index 2f0074924db..8d62df83c74 100644 +index 2f00749..8d62df8 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -45,6 +45,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): @@ -2186,7 +1785,7 @@ index 2f0074924db..8d62df83c74 100644 continue diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py -index 3ffe4dde7fd..9869f11623e 100644 +index 3ffe4dd..9869f11 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ 
b/python/sglang/srt/models/qwen3_vl.py @@ -1034,9 +1034,14 @@ class Qwen3LLMModel(Qwen3Model): @@ -2208,7 +1807,7 @@ index 3ffe4dde7fd..9869f11623e 100644 positions, hidden_states, diff --git a/python/sglang/srt/multimodal/processors/glm4v.py b/python/sglang/srt/multimodal/processors/glm4v.py -index db684259d2f..17d2cb6958a 100644 +index db68425..17d2cb6 100644 --- a/python/sglang/srt/multimodal/processors/glm4v.py +++ b/python/sglang/srt/multimodal/processors/glm4v.py @@ -1,7 +1,13 @@ @@ -2276,7 +1875,7 @@ index db684259d2f..17d2cb6958a 100644 image_grid_thw = None video_grid_thw = None diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py -index b8774ebade5..fa01537b201 100644 +index b8774eb..fa01537 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -678,7 +678,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor): @@ -2289,7 +1888,7 @@ index b8774ebade5..fa01537b201 100644 image_data=image_data, video_data=request_obj.video_data, diff --git a/python/sglang/srt/observability/req_time_stats.py b/python/sglang/srt/observability/req_time_stats.py -index 2de10730c94..d3ce2c62d21 100644 +index 2de1073..d3ce2c6 100644 --- a/python/sglang/srt/observability/req_time_stats.py +++ b/python/sglang/srt/observability/req_time_stats.py @@ -23,7 +23,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -2552,40 +2151,8 @@ index 2de10730c94..d3ce2c62d21 100644 return meta_data def format_duration(self, duration: float) -> str: -diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py -index 6c77ff64f92..706161ab928 100644 ---- a/python/sglang/srt/server_args.py -+++ b/python/sglang/srt/server_args.py -@@ -854,6 +854,8 @@ class ServerArgs: - weight_loader_prefetch_checkpoints: bool = False - weight_loader_prefetch_num_threads: int = 4 - weight_loader_drop_cache_after_load: bool = False -+ 
update_weight_delta_chunk_bytes: int = 512 * 1024 * 1024 -+ update_weight_delta_read_workers: int = 4 - remote_instance_weight_loader_seed_instance_ip: Optional[str] = None - remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None - remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None -@@ -7119,6 +7121,18 @@ class ServerArgs: - action="store_true", - help="Call posix_fadvise(DONTNEED) on each safetensors shard after loading it.", - ) -+ parser.add_argument( -+ "--update-weight-delta-chunk-bytes", -+ type=int, -+ default=ServerArgs.update_weight_delta_chunk_bytes, -+ help="Maximum bytes per delta weight chunk when applying delta updates.", -+ ) -+ parser.add_argument( -+ "--update-weight-delta-read-workers", -+ type=int, -+ default=ServerArgs.update_weight_delta_read_workers, -+ help="Number of worker threads used to read delta weight files.", -+ ) - parser.add_argument( - "--remote-instance-weight-loader-seed-instance-ip", - type=str, diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py -index 96c7286af76..9e3e2bd7142 100644 +index 96c7286..9e3e2bd 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -458,8 +458,12 @@ class EAGLEDraftCudaGraphRunner: @@ -2604,7 +2171,7 @@ index 96c7286af76..9e3e2bd7142 100644 buffers.hidden_states is not None and forward_batch.spec_info.hidden_states is not None diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py -index 6bf5d6182af..70de75f20be 100644 +index 6bf5d61..61b9603 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -530,6 +530,21 @@ class EagleDraftWorker(BaseDraftWorker): @@ -2640,28 +2207,8 @@ index 6bf5d6182af..70de75f20be 100644 # Organize the results if ( 
self.topk == 1 -@@ -1480,6 +1499,7 @@ class EAGLEWorkerV2(BaseSpecWorker): - recv_req.model_path, - recv_req.load_format, - recapture_cuda_graph=recv_req.recapture_cuda_graph, -+ files=recv_req.files, - ) - if not success: - return success, message -diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py -index 04b3841a23d..9aaf6b30673 100644 ---- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py -+++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py -@@ -856,6 +856,7 @@ class MultiLayerEagleWorkerV2(BaseSpecWorker): - recv_req.model_path, - recv_req.load_format, - recapture_cuda_graph=recv_req.recapture_cuda_graph, -+ files=recv_req.files, - ) - if not success: - return success, message diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py -index 4556d06b16f..9c28114f85d 100644 +index 4556d06..9c28114 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2399,6 +2399,7 @@ class SafeUnpickler(pickle.Unpickler): From 35d416f5708a687693dc5d7a84e511a0b43ee371 Mon Sep 17 00:00:00 2001 From: Nan Date: Tue, 16 Jun 2026 05:16:18 +0000 Subject: [PATCH 3/8] delta: simplify engine sync path, record update-weight time sync_local_checkpoint (was sync_weights) materializes the base lazily via the idempotent init_local_checkpoint instead of a background thread; record per-sync update time in update_weight_metrics; state the pre-read/pre-push hooks' purpose (non-POSIX filesystem coherence). 
--- .../update_weight_from_disk_delta.py | 29 ++++++++--- slime/backends/sglang_utils/sglang_engine.py | 50 +++---------------- slime/utils/arguments.py | 11 ++-- 3 files changed, 36 insertions(+), 54 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py index 0a56920724..a732ef8b5b 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py @@ -5,6 +5,7 @@ import os import queue import shutil +import time from argparse import Namespace from collections import deque from collections.abc import Callable, Mapping, Sequence @@ -91,9 +92,11 @@ def update_weights(self) -> None: ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) + t0 = time.perf_counter() self._publish() + publish_s = time.perf_counter() - t0 self._reload_engines() - self._record_metrics() + self._record_metrics(publish_s, time.perf_counter() - t0) def _capture_baseline(self) -> None: """Capture the baseline snapshot the first delta diffs against (no publish), and clear any @@ -172,7 +175,7 @@ def _reload_engines(self) -> None: self._commit_hook(self.args, self._version_dir, list(self.rollout_engines)) dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: - ray.get([actor.sync_weights.remote(self.weight_version) for actor in self.all_engine_actors]) + ray.get([actor.sync_local_checkpoint.remote(self.weight_version) for actor in self.all_engine_actors]) ray.get( [ engine.update_weights_from_disk.remote( @@ -268,9 +271,9 @@ def collect(fut): finally: pool.shutdown() - def _record_metrics(self) -> None: - """All-reduce the byte counts and record changed-fraction + wire size; the actor drains - update_weight_metrics onto the step log.""" + def _record_metrics(self, publish_s: float, total_s: float) -> 
None: + """All-reduce the byte counts and record changed-fraction / wire size / sync time; the + actor drains update_weight_metrics onto the step log.""" counts = torch.tensor( [self.changed_bytes, self.total_bytes, self.wire_bytes], dtype=torch.int64, @@ -278,8 +281,20 @@ def _record_metrics(self) -> None: ) dist.all_reduce(counts) changed, total, wire = counts.tolist() - self.update_weight_metrics["perf/update_weights_density"] = changed / max(total, 1) - self.update_weight_metrics["perf/update_weights_wire_bytes"] = wire + m = self.update_weight_metrics + m["perf/update_weights_density"] = changed / max(total, 1) + m["perf/update_weights_wire_bytes"] = wire + m["perf/update_weights_total_s"] = total_s + if dist.get_rank() == 0: + logger.info( + "[disk delta v=%s] update %.1fs (publish %.1fs reload %.1fs) | density=%.2f%% wire=%.2f GB", + self.weight_version, + total_s, + publish_s, + total_s - publish_s, + 100.0 * changed / max(total, 1), + wire / 1e9, + ) def _atomic_write(path: str, data: bytes) -> None: diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 04f723fcef..d87f6614c5 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -1,10 +1,8 @@ import dataclasses -import glob import ipaddress import logging import multiprocessing import os -import threading import time from urllib.parse import quote @@ -169,15 +167,6 @@ def _format_v6_uri(addr): else: self._init_normal(server_args_dict) - # Disk-delta sync: each per-host actor materializes its local checkpoint in the background - # while the engine launches and serves; the first sync joins this thread. 
- self._base_init_thread = None - if self.args.update_weight_mode == "delta" and self.args.update_weight_transport == "disk": - self._base_init_thread = threading.Thread( - target=self._init_local_checkpoint, name="delta-base-init", daemon=True - ) - self._base_init_thread.start() - def _init_external(self, expect_server_args, external_engine_need_check_fields): logger.info(f"Use external SGLang engine (rank={self.rank}, expect_server_args={expect_server_args})") @@ -390,38 +379,15 @@ def resume_memory_occupation(self, tags: list[str] = None): def check_weights(self, action: str): return self._make_request("weights_checker", {"action": action}) - def _init_local_checkpoint(self): - """Background thread: copy the base into the host-local checkpoint, then drop the source.""" - from slime.utils.disk_delta import init_local_checkpoint - - init_local_checkpoint(self.args.update_weight_local_checkpoint_dir, self.args.hf_checkpoint) - self._drop_hf_cache() - - def _drop_hf_cache(self): - # sglang loads the HF checkpoint once at init and never re-reads it (weight updates read the - # local base), so dropping it from the page cache keeps the local base resident instead. - from slime.utils.disk_delta import drop_page_cache - - for path in glob.glob(os.path.join(self.args.hf_checkpoint, "*")): - if os.path.isfile(path): - drop_page_cache(path) - - def _ensure_base_ready(self): - """Join the one-time base-materialization thread on the first sync, then re-drop the HF - source (sglang's init-load may have re-cached it after the thread dropped it).""" - if self._base_init_thread is None: - return - self._base_init_thread.join() - self._base_init_thread = None - self._drop_hf_cache() - - def sync_weights(self, target_version: int): - """Bring this host's local checkpoint to ``target_version`` by applying the published - deltas (raises on any error). 
Plain file work; the sglang server is untouched until the - subsequent reload.""" - from slime.utils.disk_delta import apply_deltas + def sync_local_checkpoint(self, target_version: int): + """Apply the published deltas into this host's local checkpoint up to target_version; the + engine reloads it afterwards. Assumes this actor shares the checkpoint filesystem with the + sglang it drives (true for slime-launched engines).""" + from slime.utils.disk_delta import apply_deltas, init_local_checkpoint - self._ensure_base_ready() + init_local_checkpoint(self.args.update_weight_local_checkpoint_dir, self.args.hf_checkpoint) # idempotent + # non-POSIX filesystems lack cross-host read-after-write consistency, so the trainer's + # just-written delta isn't visible on this mount until the hook refreshes it. if self.args.custom_delta_pre_read_path: from slime.utils.misc import load_function diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 31552864f0..f8f4672d70 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -211,10 +211,10 @@ def add_train_arguments(parser): type=str, default=None, help=( - "Path to a custom function called by --update-weight-transport=disk after each " - "trainer rank's files are durably on local disk, before rank 0 fires the engine " - "RPCs. Signature: ``def hook(args, version_dir: str, rollout_engines) -> None``. " - "Called from every trainer rank; the hook gates itself." + "Path to a custom function called on each trainer rank after its delta files " + "are written, before the engines read them — to publish the writes on a " + "non-POSIX filesystem (no cross-host visibility without an explicit sync). " + "Signature: ``def hook(args, version_dir: str, rollout_engines) -> None``; the hook gates itself." 
), ) parser.add_argument( @@ -223,7 +223,8 @@ def add_train_arguments(parser): default=None, help=( "Path to a custom function called on each rollout host before it reads the " - "published delta directory (shared-FS visibility, e.g. a volume reload). " + "published delta directory — refreshes the mount so the just-published version " + "is visible on a non-POSIX filesystem (no cross-host read-after-write consistency). " "Signature: ``def hook(delta_dir: str, target_version: int) -> None``." ), ) From e8a23fab55ec390ec9de6324487426740c17a3a6 Mon Sep 17 00:00:00 2001 From: Nan Date: Tue, 16 Jun 2026 05:19:11 +0000 Subject: [PATCH 4/8] delta: drop redundant update-weight timing The actor's update_weights is already @timer-wrapped (perf/update_weights_time), so the per-sync total/publish/reload breakdown was duplicate instrumentation. Keep only the delta-specific metrics (density, wire bytes). --- .../update_weight_from_disk_delta.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py index a732ef8b5b..ce2d6947ce 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_disk_delta.py @@ -5,7 +5,6 @@ import os import queue import shutil -import time from argparse import Namespace from collections import deque from collections.abc import Callable, Mapping, Sequence @@ -92,11 +91,9 @@ def update_weights(self) -> None: ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) - t0 = time.perf_counter() self._publish() - publish_s = time.perf_counter() - t0 self._reload_engines() - self._record_metrics(publish_s, time.perf_counter() - t0) + self._record_metrics() def _capture_baseline(self) -> None: """Capture the baseline snapshot the 
first delta diffs against (no publish), and clear any @@ -271,9 +268,9 @@ def collect(fut): finally: pool.shutdown() - def _record_metrics(self, publish_s: float, total_s: float) -> None: - """All-reduce the byte counts and record changed-fraction / wire size / sync time; the - actor drains update_weight_metrics onto the step log.""" + def _record_metrics(self) -> None: + """All-reduce the byte counts and record changed-fraction / wire size; the actor drains + update_weight_metrics onto the step log.""" counts = torch.tensor( [self.changed_bytes, self.total_bytes, self.wire_bytes], dtype=torch.int64, @@ -284,14 +281,10 @@ def _record_metrics(self, publish_s: float, total_s: float) -> None: m = self.update_weight_metrics m["perf/update_weights_density"] = changed / max(total, 1) m["perf/update_weights_wire_bytes"] = wire - m["perf/update_weights_total_s"] = total_s if dist.get_rank() == 0: logger.info( - "[disk delta v=%s] update %.1fs (publish %.1fs reload %.1fs) | density=%.2f%% wire=%.2f GB", + "[disk delta v=%s] density=%.2f%% wire=%.2f GB", self.weight_version, - total_s, - publish_s, - total_s - publish_s, 100.0 * changed / max(total, 1), wire / 1e9, ) From 0b9ac54c1b1bc878fe22fb4a7c9a1b093690380c Mon Sep 17 00:00:00 2001 From: Nan Date: Tue, 16 Jun 2026 05:24:29 +0000 Subject: [PATCH 5/8] delta: update arg-validation tests to match disk-only delta design The delta scaffold reworked the update-weight args: delta requires --update-weight-transport=disk (was nccl-or-disk), needs --update-weight-local-checkpoint-dir, and the --update-weight-delta-dir compatibility alias is gone (the directory belongs to the transport, not the encoding). Drop the alias resolve/backfill/conflict tests, point the transport and colocate tests at the disk path, and cover the local-checkpoint requirement. 
--- tests/test_megatron_argument_validation.py | 74 +++++++--------------- 1 file changed, 22 insertions(+), 52 deletions(-) diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index 1f435cb577..a8a9a12d7f 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -173,57 +173,12 @@ def test_update_weight_disk_dir_required_for_disk_transport(monkeypatch): args = types.SimpleNamespace( update_weight_transport="disk", update_weight_disk_dir=None, - update_weight_delta_dir=None, ) with pytest.raises(ValueError, match="update-weight-disk-dir"): module._resolve_update_weight_disk_dir(args) -@pytest.mark.unit -def test_update_weight_disk_dir_normalizes_delta_alias(monkeypatch): - module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( - update_weight_transport="disk", - update_weight_disk_dir=None, - update_weight_delta_dir="/shared/delta", - ) - - with pytest.warns(UserWarning, match="will be removed in a future release"): - module._resolve_update_weight_disk_dir(args) - - assert args.update_weight_disk_dir == "/shared/delta" - assert args.update_weight_delta_dir == "/shared/delta" - - -@pytest.mark.unit -def test_update_weight_disk_dir_backfills_legacy_delta_field(monkeypatch): - module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( - update_weight_transport="disk", - update_weight_disk_dir="/shared/updates", - update_weight_delta_dir=None, - ) - - module._resolve_update_weight_disk_dir(args) - - assert args.update_weight_disk_dir == "/shared/updates" - assert args.update_weight_delta_dir == "/shared/updates" - - -@pytest.mark.unit -def test_update_weight_disk_dir_rejects_conflicting_alias(monkeypatch): - module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( - update_weight_transport="disk", - update_weight_disk_dir="/shared/full", - update_weight_delta_dir="/shared/delta", - ) - - with 
pytest.raises(ValueError, match="deprecated alias"): - module._resolve_update_weight_disk_dir(args) - - def make_slime_validate_args(**overrides): values = dict( eval_config=None, @@ -353,13 +308,28 @@ def test_slime_validate_args_preserves_zero_rollout_gpus_without_colocate(monkey @pytest.mark.unit -def test_update_weight_delta_rejects_colocate(monkeypatch): +def test_update_weight_delta_requires_disk_transport(monkeypatch): module = load_slime_arguments_module(monkeypatch) args = types.SimpleNamespace( update_weight_mode="delta", update_weight_transport="nccl", update_weight_disk_dir=None, - update_weight_delta_dir=None, + update_weight_local_checkpoint_dir="/local/ckpt", + colocate=False, + ) + + with pytest.raises(ValueError, match="requires --update-weight-transport=disk"): + module._validate_update_weight_args(args) + + +@pytest.mark.unit +def test_update_weight_delta_rejects_colocate(monkeypatch): + module = load_slime_arguments_module(monkeypatch) + args = types.SimpleNamespace( + update_weight_mode="delta", + update_weight_transport="disk", + update_weight_disk_dir="/shared/delta", + update_weight_local_checkpoint_dir="/local/ckpt", colocate=True, ) @@ -368,17 +338,17 @@ def test_update_weight_delta_rejects_colocate(monkeypatch): @pytest.mark.unit -def test_update_weight_delta_rejects_unknown_transport(monkeypatch): +def test_update_weight_delta_requires_local_checkpoint_dir(monkeypatch): module = load_slime_arguments_module(monkeypatch) args = types.SimpleNamespace( update_weight_mode="delta", - update_weight_transport="tensor", - update_weight_disk_dir=None, - update_weight_delta_dir=None, + update_weight_transport="disk", + update_weight_disk_dir="/shared/delta", + update_weight_local_checkpoint_dir=None, colocate=False, ) - with pytest.raises(ValueError, match="supports only --update-weight-transport=nccl or disk"): + with pytest.raises(ValueError, match="requires --update-weight-local-checkpoint-dir"): module._validate_update_weight_args(args) From 
f78ec133aac69d2aed05fd32b57a0ac802f422c3 Mon Sep 17 00:00:00 2001 From: Nan Date: Tue, 16 Jun 2026 05:26:26 +0000 Subject: [PATCH 6/8] delta: inline the disk-dir check into _validate_update_weight_args MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With the delta-dir alias gone, _resolve_update_weight_disk_dir no longer normalizes anything — it's a single transport-level check, so fold it into _validate_update_weight_args. --- slime/utils/arguments.py | 9 ++------- tests/test_megatron_argument_validation.py | 3 ++- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index f8f4672d70..152618c0e5 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1702,19 +1702,14 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: return eval_datasets -def _resolve_update_weight_disk_dir(args) -> None: - """Disk-backed sync (full or delta) needs a directory the trainer writes and the rollout - engines read — a filesystem shared between them.""" +def _validate_update_weight_args(args) -> None: + # disk-backed sync (full or delta) writes on the trainer and reads on the engines: needs a shared dir if args.update_weight_transport == "disk" and not args.update_weight_disk_dir: raise ValueError( "--update-weight-transport=disk requires --update-weight-disk-dir to point at " "a filesystem shared between the trainer and the rollout engines." 
) - -def _validate_update_weight_args(args) -> None: - _resolve_update_weight_disk_dir(args) - if args.update_weight_mode == "delta": if args.update_weight_transport != "disk": raise ValueError( diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index a8a9a12d7f..4b6372b31c 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -171,12 +171,13 @@ def test_allgather_cp_ignores_cp_size_one(monkeypatch): def test_update_weight_disk_dir_required_for_disk_transport(monkeypatch): module = load_slime_arguments_module(monkeypatch) args = types.SimpleNamespace( + update_weight_mode="full", update_weight_transport="disk", update_weight_disk_dir=None, ) with pytest.raises(ValueError, match="update-weight-disk-dir"): - module._resolve_update_weight_disk_dir(args) + module._validate_update_weight_args(args) def make_slime_validate_args(**overrides): From 39be7a67c2867ea3e710b25dea355ba5e9549cfc Mon Sep 17 00:00:00 2001 From: Nan Date: Tue, 16 Jun 2026 05:29:30 +0000 Subject: [PATCH 7/8] delta: inline update-weight validation into slime_validate_args slime_validate_args validates everything else inline; the extracted _validate_update_weight_args was the lone exception. Fold it in and test it the same way as the other slime_validate_args checks (make_slime_validate_args). 
--- slime/utils/arguments.py | 51 ++++++++++------------ tests/test_megatron_argument_validation.py | 25 ++++------- 2 files changed, 32 insertions(+), 44 deletions(-) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 152618c0e5..10739d7ee2 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1702,33 +1702,6 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: return eval_datasets -def _validate_update_weight_args(args) -> None: - # disk-backed sync (full or delta) writes on the trainer and reads on the engines: needs a shared dir - if args.update_weight_transport == "disk" and not args.update_weight_disk_dir: - raise ValueError( - "--update-weight-transport=disk requires --update-weight-disk-dir to point at " - "a filesystem shared between the trainer and the rollout engines." - ) - - if args.update_weight_mode == "delta": - if args.update_weight_transport != "disk": - raise ValueError( - "--update-weight-mode=delta requires --update-weight-transport=disk, " - f"got {args.update_weight_transport!r}." - ) - if args.colocate: - raise ValueError( - "--update-weight-mode=delta is not supported with --colocate. Colocate transfers " - "weights via CUDA IPC (only a handle crosses processes), so the delta bookkeeping " - "(snapshot + diff + encode) is pure overhead." - ) - if not args.update_weight_local_checkpoint_dir: - raise ValueError( - "--update-weight-mode=delta requires --update-weight-local-checkpoint-dir " - "(a rollout-host-local NVMe directory)." 
- ) - - def slime_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) @@ -1997,4 +1970,26 @@ def slime_validate_args(args): if args.only_train_params_name_list and args.freeze_params_name_list: raise ValueError("You can only specify ONE of: --only-train-params-name-list, or --freeze-params-name-list.") - _validate_update_weight_args(args) + # disk-backed sync (full or delta) writes on the trainer and reads on the engines: needs a shared dir + if args.update_weight_transport == "disk" and not args.update_weight_disk_dir: + raise ValueError( + "--update-weight-transport=disk requires --update-weight-disk-dir to point at " + "a filesystem shared between the trainer and the rollout engines." + ) + if args.update_weight_mode == "delta": + if args.update_weight_transport != "disk": + raise ValueError( + "--update-weight-mode=delta requires --update-weight-transport=disk, " + f"got {args.update_weight_transport!r}." + ) + if args.colocate: + raise ValueError( + "--update-weight-mode=delta is not supported with --colocate. Colocate transfers " + "weights via CUDA IPC (only a handle crosses processes), so the delta bookkeeping " + "(snapshot + diff + encode) is pure overhead." + ) + if not args.update_weight_local_checkpoint_dir: + raise ValueError( + "--update-weight-mode=delta requires --update-weight-local-checkpoint-dir " + "(a rollout-host-local NVMe directory)." 
+ ) diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index 4b6372b31c..683b55ad4d 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -170,14 +170,10 @@ def test_allgather_cp_ignores_cp_size_one(monkeypatch): @pytest.mark.unit def test_update_weight_disk_dir_required_for_disk_transport(monkeypatch): module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( - update_weight_mode="full", - update_weight_transport="disk", - update_weight_disk_dir=None, - ) + args = make_slime_validate_args(update_weight_transport="disk", update_weight_disk_dir=None) with pytest.raises(ValueError, match="update-weight-disk-dir"): - module._validate_update_weight_args(args) + module.slime_validate_args(args) def make_slime_validate_args(**overrides): @@ -258,7 +254,7 @@ def make_slime_validate_args(**overrides): freeze_params_name_list=None, update_weight_transport="nccl", update_weight_disk_dir=None, - update_weight_delta_dir=None, + update_weight_local_checkpoint_dir=None, update_weight_mode="full", ) values.update(overrides) @@ -311,22 +307,20 @@ def test_slime_validate_args_preserves_zero_rollout_gpus_without_colocate(monkey @pytest.mark.unit def test_update_weight_delta_requires_disk_transport(monkeypatch): module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( + args = make_slime_validate_args( update_weight_mode="delta", update_weight_transport="nccl", - update_weight_disk_dir=None, update_weight_local_checkpoint_dir="/local/ckpt", - colocate=False, ) with pytest.raises(ValueError, match="requires --update-weight-transport=disk"): - module._validate_update_weight_args(args) + module.slime_validate_args(args) @pytest.mark.unit def test_update_weight_delta_rejects_colocate(monkeypatch): module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( + args = make_slime_validate_args( 
update_weight_mode="delta", update_weight_transport="disk", update_weight_disk_dir="/shared/delta", @@ -335,22 +329,21 @@ def test_update_weight_delta_rejects_colocate(monkeypatch): ) with pytest.raises(ValueError, match="not supported with --colocate"): - module._validate_update_weight_args(args) + module.slime_validate_args(args) @pytest.mark.unit def test_update_weight_delta_requires_local_checkpoint_dir(monkeypatch): module = load_slime_arguments_module(monkeypatch) - args = types.SimpleNamespace( + args = make_slime_validate_args( update_weight_mode="delta", update_weight_transport="disk", update_weight_disk_dir="/shared/delta", update_weight_local_checkpoint_dir=None, - colocate=False, ) with pytest.raises(ValueError, match="requires --update-weight-local-checkpoint-dir"): - module._validate_update_weight_args(args) + module.slime_validate_args(args) if __name__ == "__main__": From fd7c00d6f467cdcd2d0fb26094f2a1cbf8214762 Mon Sep 17 00:00:00 2001 From: Nan Date: Tue, 16 Jun 2026 16:32:38 +0000 Subject: [PATCH 8/8] delta: prefetch the host-local base during engine init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Materialize the host-local checkpoint in a daemon thread at engine init so the one-time base copy overlaps sglang launch and the first rollout (which serves from init-loaded weights) instead of blocking the first delta reload. The first sync_local_checkpoint's init_local_checkpoint is idempotent and flock-guarded, so it either finds the copy done or blocks on the same lock — no join needed. 
--- slime/backends/sglang_utils/sglang_engine.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index d87f6614c5..f30f18e03e 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -3,6 +3,7 @@ import logging import multiprocessing import os +import threading import time from urllib.parse import quote @@ -167,6 +168,19 @@ def _format_v6_uri(addr): else: self._init_normal(server_args_dict) + # Warm the host-local base off the actor's main thread: sglang serves the first rollout from + # its init-loaded weights, so the materialize (a full base copy) only has to finish before + # the first delta reload. init_local_checkpoint is idempotent and flock-guarded, so the first + # sync_local_checkpoint either finds it done or blocks on the same lock — no join needed. + if self.args.update_weight_mode == "delta" and self.args.update_weight_transport == "disk": + from slime.utils.disk_delta import init_local_checkpoint + + threading.Thread( + target=init_local_checkpoint, + args=(self.args.update_weight_local_checkpoint_dir, self.args.hf_checkpoint), + daemon=True, + ).start() + def _init_external(self, expect_server_args, external_engine_need_check_fields): logger.info(f"Use external SGLang engine (rank={self.rank}, expect_server_args={expect_server_args})")