Add On-Policy Distillation (OPSD) Trainer backend in DeepSpeed#8027
PKUWZP wants to merge 8 commits into
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 6384396b48
@PKUWZP, this needs to be split between examples and engine logic. The examples should go to the examples repo: https://github.com/deepspeedai/DeepSpeedexamples The engine should go somewhere under deepspeed. Since there is already an existing hybrid_engine, we should figure out whether it makes sense to replace it or consolidate the two.
First slice of the on-policy distillation example app under examples/opsd/. This commit lands the framework-agnostic foundation: the OPSDConfig dataclass hierarchy, chunked / streamed forward-KL / reverse-KL / JSD losses with sequence-axis chunking to bound peak memory, response-mask + shift helpers, and a 24-case CPU-only test suite covering identity, masking, chunk equivalence, gradient flow, and numerical edge cases. Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
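As a rough illustration of the sequence-axis chunking this commit describes, here is a pure-Python sketch of a masked forward KL accumulated one chunk of positions at a time. The names `log_softmax`, `token_forward_kl`, `forward_kl_chunked`, and `chunk_size` are hypothetical and not taken from the PR, and the real implementation works on tensors rather than nested lists.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def token_forward_kl(teacher_logits, student_logits):
    """KL(teacher || student) for a single token position."""
    t = log_softmax(teacher_logits)
    s = log_softmax(student_logits)
    return sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t, s))


def forward_kl_chunked(teacher, student, mask, chunk_size):
    """Masked mean forward KL over a [T, V] sequence, one chunk at a time.

    Only `chunk_size` positions are processed per iteration; with tensors,
    that is what bounds peak memory. Masked-out positions contribute nothing.
    """
    total, count = 0.0, 0
    for start in range(0, len(mask), chunk_size):
        end = min(start + chunk_size, len(mask))
        chunk_sum = 0.0  # partial sum for this chunk only
        for pos in range(start, end):
            if mask[pos]:
                chunk_sum += token_forward_kl(teacher[pos], student[pos])
                count += 1
        total += chunk_sum
    return total / max(count, 1)
```

The chunk size should change only the memory profile, not the result (up to floating-point rounding), which is the kind of "chunk equivalence" property the test suite is said to cover.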
Adds the two-phase teacher path:
* TeacherWrapper loads a HuggingFace causal LM, freezes it, and runs
forward-only. Two modes: load + pin on GPU (offload_to_cpu=false), or
wrap with deepspeed.initialize using a ZeRO-3 + offload_param=cpu
config (offload_to_cpu=true). Avoids deepspeed.zero.Init() around
from_pretrained because HF's loader partitions params to zero-width
shards before the checkpoint can fill them.
* TeacherLogitCache stages the [B, T, V] teacher logits to (pinned) host
memory in bf16, and exposes chunk_to_device() so the student-side loss
can pull sequence slices back on demand. This is the memory-economising
half of the two-phase update.
CPU-only tests cover the cache shape / dtype / round-trip / chunk-bounds
behaviour and verify the streamed-via-cache loss matches the direct
chunked loss bit-for-bit.
Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
Lands the fully-runnable hybrid-engine training path: a backend-agnostic RolloutEngine ABC with RolloutRequest / RolloutBatch / SamplingConfig dataclasses, a HybridEngineRollout implementation that uses DeepSpeed's accelerated decode when an inference policy exists and otherwise falls back to GatheredParameters + the raw HF generate (covers Qwen-family and other models not in DeepSpeed's inference container list), a left-padded prompt dataset + collator, a three-phase trainer loop (rollout -> teacher forward + cache -> student forward + streamed KL + backward + step), the argparse + deepspeed.initialize entry point, base DeepSpeed ZeRO-3 + hybrid_engine JSON configs, a 5-step smoke config and launcher script, and a 20-prompt math toy dataset for the smoke run. Smoke-validated end-to-end on 2x H200 with Qwen2.5-0.5B-Instruct student and Qwen2.5-1.5B-Instruct teacher; loss finite for 5 steps. Rollout interface contract is covered by tests/test_rollout_interface.py. Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
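To make the "left-padded prompt dataset + collator" piece concrete, a minimal sketch of left-padding follows. The function name `left_pad_collate` and its signature are assumptions for illustration and are not the PR's actual collator, which presumably returns tensors.

```python
def left_pad_collate(prompts, pad_id):
    """Left-pad variable-length prompts (lists of token ids) to a common length.

    Returns (input_ids, attention_mask). Real tokens end at the last column,
    so generated tokens can be appended directly after every prompt.
    """
    max_len = max(len(p) for p in prompts)
    input_ids, attention_mask = [], []
    for p in prompts:
        pad = max_len - len(p)
        input_ids.append([pad_id] * pad + list(p))
        attention_mask.append([0] * pad + [1] * len(p))
    return input_ids, attention_mask
```

Left padding is the usual choice for decoder-only generation because every row's last prompt token sits in the same column, so the response always starts at the same index.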
Lands the second-stage rollout path, weight-sync infrastructure, and the
example app's README. Includes:
* VLLMRollout that constructs vllm.LLM on training rank 0 and broadcasts
generated token ids to peer ranks, with disjoint-GPU (subprocess) and
shared (in-process) topology paths. Weight sync gathers ZeRO-3 params
cooperatively then pushes to vLLM via LLM.collective_rpc("load_weights").
* WeightBridge ABC with COLUMN / ROW / VOCAB / REPLICATED parallel kinds
and an even-slice per-rank slicer; Qwen2WeightBridge with the full
per-parameter table for Qwen2 / Qwen2.5; Qwen3WeightBridge adding the
per-head q_norm / k_norm tensors as REPLICATED.
* vLLM-side prompt+response stitching factored into stitch_rollout() so
its index math is unit-testable without a live vLLM.
* CPU-only tests: tests/test_weight_bridge.py covers parallel-kind
dispatch, per-rank shape/gather round-trips across tp_size in {1,2,4},
indivisibility / invalid-rank guards, and the registry;
tests/test_vllm_stitch.py covers prompt/response stitching for the
common shapes including variable response lengths and left-padded
prompts.
* configs + launch scripts for both production and smoke vLLM runs.
**Known blocker called out in README and module docstring:** vLLM's worker
init calls new_group() on the global process group, which deadlocks when
launched under the standard `deepspeed --num_gpus N` launcher (rank 0
calls vLLM, other ranks never participate in vLLM's collective). The
documented fix is the TRL/OpenRLHF separate-server pattern; this PR lands
the scaffolding so that work can begin against a green codebase.
Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
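The even-slice per-rank slicer mentioned above can be sketched as follows. This is an illustration only: `ParallelKind`, `slice_for_rank`, and `gather` are hypothetical names, weights are plain lists of rows, and the mapping of COLUMN / VOCAB to dim 0 and ROW to dim 1 assumes the common Megatron-style convention for `[out_features, in_features]` weights rather than anything confirmed by the PR.

```python
from enum import Enum


class ParallelKind(Enum):
    COLUMN = "column"          # assumed: split output rows (dim 0)
    ROW = "row"                # assumed: split input columns (dim 1)
    VOCAB = "vocab"            # assumed: split vocabulary rows (dim 0)
    REPLICATED = "replicated"  # every rank holds the full tensor


def slice_for_rank(weight, kind, rank, tp_size):
    """Return this rank's even slice of a 2-D weight given as a list of rows."""
    if not 0 <= rank < tp_size:
        raise ValueError(f"rank {rank} out of range for tp_size {tp_size}")
    if kind is ParallelKind.REPLICATED:
        return [list(row) for row in weight]
    dim = 1 if kind is ParallelKind.ROW else 0
    size = len(weight[0]) if dim == 1 else len(weight)
    if size % tp_size != 0:
        raise ValueError(f"dim {dim} of size {size} not divisible by {tp_size}")
    step = size // tp_size
    lo, hi = rank * step, (rank + 1) * step
    if dim == 0:
        return [list(row) for row in weight[lo:hi]]
    return [list(row[lo:hi]) for row in weight]


def gather(shards, kind):
    """Inverse of slice_for_rank: reassemble the full weight from all shards."""
    if kind is ParallelKind.REPLICATED:
        return shards[0]
    if kind is ParallelKind.ROW:
        # Concatenate the column slices of each row across ranks.
        return [sum((s[i] for s in shards), []) for i in range(len(shards[0]))]
    # COLUMN / VOCAB: concatenate row blocks in rank order.
    return [row for s in shards for row in s]
```

A slice-then-gather round trip across several `tp_size` values, plus the indivisibility and invalid-rank guards, is the sort of check that can run on CPU without a live vLLM.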
HybridEngineRollout:
- model.generate() path for sampling (temperature>0)
- Graph capture + DeepSpeedStaticCache path for greedy (temperature=0, 3x faster)
- DeepSpeedStaticCache: CUDA-graph-compatible KV cache with external write_position
- RolloutRequest/RolloutBatch/RolloutConfig dataclasses

VLLMRollout:
- Weight sync via gdr/http backends
- vllm_python config for interpreter selection
- vLLM compat sitecustomize shim
- Only sync requires_grad params to vLLM

OPSD trainer/config:
- Move trainer to deepspeed/runtime/rlhf/trainer/
- Move config to deepspeed/runtime/rlhf/config.py
- Force weight_sync_interval=1 for on-policy correctness

Tests:
- CPU unit tests for HybridEngineRollout and VLLMRollout
- Graph capture verified: HF StaticCache == DeepSpeedStaticCache == graph (100 steps, 0 diff)

Verified on Qwen2.5-0.5B-Instruct / RTX 5090:
- model.generate(): 90 tok/s (batch=1)
- graph capture: 270 tok/s (3x speedup)
- OPSD smoke test: 3 training steps pass end-to-end

Signed-off-by: Guokai Ma <guokai.ma@intel.com>
Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
…m_dtype to engine_dtype Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
Formatting hooks:
- Add the standard license header to the opsd benchmark scripts (bench_14b_rollout.py, bench_autotp_gc.py, bench_vllm_tp2.py).
- Mark CUDA-specific benchmark calls with #ignore-cuda and drop the unused `sys` import flagged by flake8.
- yapf: collapse the rlhf __init__ imports that now fit in 119 cols.

cpu-torch-latest unit tests:
- Align tests/unit/runtime/rollout/test_hybrid_engine_rollout.py with the current cfg-based HybridEngineRollout constructor: the tests previously passed continuous_batching_size / use_graph_capture kwargs and expected _generate_continuous_batching / _generate_graph_capture_cb dispatch that the implementation does not provide. Tests now build a HybridEngineRolloutConfig and assert the real generate() dispatch (module.generate by default, _generate_graph for greedy graph capture).
- Update the two benchmark constructor calls to the cfg-based API.

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
Keep the SPDX Apache-2.0 identifier and DeepSpeed Team attribution, which are what the check-license hook requires. Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
@@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
Please move benchmarks to examples repo. That will reduce the dependencies for building DS as benchmarks increase in complexity.
def _gdr_available() -> bool:
    try:
        return torch.cuda.is_available() and torch.cuda.nccl.version() is not None  #ignore-cuda
These CUDA references break the accelerator abstraction.
}})
init_thread.start()

from vllm.distributed.utils import StatelessProcessGroup
Why are we making vllm a dependency? This breaks the lightweight nature of DS and introduces unnecessary complexity.
Summary
This PR adds a DeepSpeed-native on-policy distillation trainer. It also incorporates the PR from @delock (PKUWZP#1), which moves the OPSD example into DeepSpeed submodules: a rollout engine abstraction layer, two rollout implementations (HybridEngine and vLLM), and the OPSD trainer.
On-policy distillation: a small student generates rollouts, a frozen large teacher scores them, and the student is updated by a per-token divergence (forward-KL / reverse-KL / JSD) between the two distributions on the student's own samples. Each step has three phases (student rollout → teacher forward + CPU logit cache → student forward + streamed divergence + backward), so the full `[B, T, V]` teacher tensor never co-resides with the student logits on the training device.

Key Design Decisions
- `build_rollout()` factory
- `HybridEngineRollout` runs in-process, reusing model weights, so no cross-process weight transfer is needed
- `VLLMRollout` runs as a subprocess to avoid vLLM's `new_group()` deadlocking with the DeepSpeed launcher
- `HybridEngineRollout` supports continuous batching and graph capture

Modules:
- `deepspeed/runtime/rlhf/trainer/opsd.py`: chunked / streamed forward-KL, reverse-KL, JSD with sequence-axis chunking
- `deepspeed/runtime/rlhf/trainer/teacher.py`: frozen teacher wrapper + `TeacherLogitCache` (host-resident, chunk fetch)
- `deepspeed/runtime/rlhf/rollout/hybrid_engine_rollout.py`: HybridEngine rollout backends
- `deepspeed/runtime/rlhf/rollout/vllm_rollout.py`: vLLM rollout backends
- `benchmarks/opsd`: benchmark scripts for OPSD training
- `tests/unit/runtime/rollout`: tests for the OPSD rollout backends (vLLM + HybridEngine)

Validated end-to-end: on 2× H200 with Qwen2.5-0.5B-Instruct student + Qwen2.5-1.5B-Instruct teacher via the hybrid-engine path; loss finite for 5 steps. See README for the smoke recipe.
Follow-up items: those documented in the README file, plus SGLang rollout integration.
Test plan
- `cd examples/opsd && python -m pytest tests/ -v` → 87/87 passing on CPU
- `deepspeed --num_gpus 2 main.py --config configs/smoke_hybrid.json` end-to-end on 2× H200 → 5 finite-loss steps
- `pre-commit run --files <all changed files>` → green (yapf, flake8, check-torchdist, check-license, check-torchcuda, codespell)