
Add On-Policy Distillation (OPSD) Trainer backend in DeepSpeed #8027

Open
PKUWZP wants to merge 8 commits into deepspeedai:master from PKUWZP:zhipwang_opd_pr

@PKUWZP PKUWZP commented May 26, 2026

Collaborator

Summary

This PR adds a DeepSpeed-native on-policy distillation trainer. It also incorporates the PR from @delock (PKUWZP#1), which abstracts the OPSD example into DeepSpeed submodules: a rollout engine abstraction layer, two rollout implementations (HybridEngine and vLLM), and an OPSD trainer.

On-policy distillation: a small student generates rollouts, a frozen large teacher scores them, and the student is updated by a per-token divergence (forward-KL / reverse-KL / JSD) between the two distributions on the student's own samples. Each step has three phases — student rollout → teacher forward + CPU logit cache → student forward + streamed divergence + backward — so the full [B, T, V] teacher tensor never co-resides with the student logits on the training device.
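
To make the three divergences concrete, here is a minimal pure-Python sketch for a single token position. It is not the code in this PR: the function names are hypothetical, and it assumes the common distillation convention that forward-KL means KL(teacher || student) and reverse-KL means KL(student || teacher), which the description above does not spell out.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl(log_p, log_q):
    """KL(p || q) at one token position, given log-probabilities."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def token_divergence(student_logits, teacher_logits, kind="forward_kl"):
    """Per-token divergence between student and teacher distributions."""
    log_s = log_softmax(student_logits)
    log_t = log_softmax(teacher_logits)
    if kind == "forward_kl":  # assumed convention: KL(teacher || student)
        return kl(log_t, log_s)
    if kind == "reverse_kl":  # assumed convention: KL(student || teacher)
        return kl(log_s, log_t)
    if kind == "jsd":  # symmetric; mixture m = (student + teacher) / 2
        log_m = [math.log(0.5 * (math.exp(a) + math.exp(b))) for a, b in zip(log_s, log_t)]
        return 0.5 * kl(log_s, log_m) + 0.5 * kl(log_t, log_m)
    raise ValueError(f"unknown divergence kind: {kind!r}")


# Toy 3-token vocabulary; values are illustrative only.
student = [2.0, 0.5, -1.0]
teacher = [0.1, 1.5, 0.3]
for name in ("forward_kl", "reverse_kl", "jsd"):
    print(name, round(token_divergence(student, teacher, name), 4))
```

In a real trainer the same quantity would be computed over every response position of the student's own rollout and averaged under a response mask.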

Key Design Decisions

  • Rollout engine abstracted as RolloutEngine ABC, created via build_rollout() factory
  • HybridEngineRollout runs in-process, reusing model weights — no cross-process weight transfer needed
  • VLLMRollout runs as a subprocess to avoid vLLM's new_group() deadlocking with the DeepSpeed launcher
  • HybridEngineRollout supports continuous batching and graph capture
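
The ABC-plus-factory design in the list above can be sketched as follows. The names RolloutEngine, RolloutRequest, RolloutBatch and build_rollout() come from this PR, but every field, method signature, the registry, and the toy EchoRollout backend are assumptions made for illustration; the actual interfaces may differ.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Type


@dataclass
class RolloutRequest:  # fields are hypothetical
    prompt_ids: List[List[int]]
    max_new_tokens: int = 4


@dataclass
class RolloutBatch:  # fields are hypothetical
    sequences: List[List[int]]  # prompt ids followed by generated ids
    prompt_lens: List[int]


class RolloutEngine(ABC):
    """Backend-agnostic contract: the trainer only ever calls generate()."""

    @abstractmethod
    def generate(self, request: RolloutRequest) -> RolloutBatch:
        ...

    def sync_weights(self) -> None:
        """No-op by default; an out-of-process backend would override this."""


_REGISTRY: Dict[str, Type[RolloutEngine]] = {}


def register(name: str):
    """Class decorator that records a backend under a string key."""
    def deco(cls):
        _REGISTRY[name] = cls
        return cls
    return deco


@register("echo")
class EchoRollout(RolloutEngine):
    """Toy backend: 'generates' by repeating the last prompt token."""

    def generate(self, request: RolloutRequest) -> RolloutBatch:
        seqs = [p + [p[-1]] * request.max_new_tokens for p in request.prompt_ids]
        return RolloutBatch(seqs, [len(p) for p in request.prompt_ids])


def build_rollout(backend: str, **kwargs) -> RolloutEngine:
    """Factory: map a config string to a concrete RolloutEngine."""
    try:
        cls = _REGISTRY[backend]
    except KeyError:
        raise ValueError(f"unknown rollout backend {backend!r}; known: {sorted(_REGISTRY)}") from None
    return cls(**kwargs)


engine = build_rollout("echo")
print(engine.generate(RolloutRequest(prompt_ids=[[1, 2], [3]], max_new_tokens=2)))
```

Keeping the trainer behind this interface is what would let an in-process backend and a subprocess backend be swapped by configuration alone.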

Modules:

  • deepspeed/runtime/rlhf/trainer/opsd.py — chunked / streamed forward-KL, reverse-KL, JSD with sequence-axis chunking
  • deepspeed/runtime/rlhf/trainer/teacher.py — frozen teacher wrapper + TeacherLogitCache (host-resident, chunk fetch)
  • deepspeed/runtime/rlhf/rollout/hybrid_engine_rollout.py — HybridEngine rollout backends
  • deepspeed/runtime/rlhf/rollout/vllm_rollout.py — vLLM rollout backends
  • benchmarks/opsd — Benchmark scripts for OPSD training
  • tests/unit/runtime/rollout — testing scripts for OPSD rollout backends (vLLM + HybridEngine)

Validated end-to-end: On 2× H200 with Qwen2.5-0.5B-Instruct student + Qwen2.5-1.5B-Instruct teacher via the hybrid-engine path; loss finite for 5 steps. See README for the smoke recipe.

Follow-up items: those documented in the README, and SGLang rollout integration.

Test plan

  • cd examples/opsd && python -m pytest tests/ -v → 87/87 passing on CPU
  • deepspeed --num_gpus 2 main.py --config configs/smoke_hybrid.json end-to-end on 2× H200 → 5 finite-loss steps
  • pre-commit run --files <all changed files> → green (yapf, flake8, check-torchdist, check-license, check-torchcuda, codespell)
  • vLLM rollout end-to-end
  • Larger-scale training run (out of scope for the initial PR)

@PKUWZP PKUWZP requested a review from tohtana as a code owner May 26, 2026 07:31
@PKUWZP PKUWZP changed the title from "Add On-Policy Distillation (OPSD) example app" to "[Draft] Add On-Policy Distillation (OPSD) Trainer in DeepSpeed" May 26, 2026

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 6384396b48

Comment thread examples/opsd/scripts/train_opsd_vllm.sh Outdated
Comment thread examples/opsd/main.py Outdated
Comment thread examples/opsd/main.py Outdated
@sfc-gh-truwase

Collaborator

@PKUWZP, this needs to be split between examples and engine logic.

The examples should go to the examples repo: https://github.com/deepspeedai/DeepSpeedexamples

The engine should go somewhere under deepspeed. Since there is already an existing hybrid_engine, we should figure out if it makes sense to replace or consolidate.

@PKUWZP PKUWZP changed the title from "[Draft] Add On-Policy Distillation (OPSD) Trainer in DeepSpeed" to "Add On-Policy Distillation (OPSD) Trainer backend in DeepSpeed" Jul 1, 2026
@PKUWZP PKUWZP requested a review from delock July 2, 2026 00:00
PKUWZP and others added 6 commits July 1, 2026 14:43
First slice of the on-policy distillation example app under examples/opsd/.
This commit lands the framework-agnostic foundation: the OPSDConfig dataclass
hierarchy, chunked / streamed forward-KL / reverse-KL / JSD losses with
sequence-axis chunking to bound peak memory, response-mask + shift helpers,
and a 24-case CPU-only test suite covering identity, masking, chunk
equivalence, gradient flow, and numerical edge cases.

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
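
The "response-mask + shift helpers" mentioned in the commit above can be illustrated with a small pure-Python sketch. The helper names and signatures here are hypothetical, not the ones in the PR, and the sketch assumes a single sequence whose prompt occupies the first prompt_len positions.

```python
def shift_for_next_token(logits, input_ids):
    """Align logits with the tokens they predict: row t scores token t + 1."""
    return logits[:-1], input_ids[1:]


def response_mask(seq_len, prompt_len, pad_positions=()):
    """Mask over shifted targets: 1 where the target token is a real response token.

    Target index j corresponds to original token index j + 1, so the first
    response token (index prompt_len) lands at target index prompt_len - 1.
    """
    mask = []
    for j in range(seq_len - 1):
        tok = j + 1
        mask.append(1 if tok >= prompt_len and tok not in pad_positions else 0)
    return mask


def masked_mean(per_token, mask):
    """Average per-token losses over masked-in positions only."""
    total = sum(mask)
    return sum(v * m for v, m in zip(per_token, mask)) / total if total else 0.0


# 5-token sequence: 2 prompt tokens, 3 response tokens, last one is padding.
mask = response_mask(seq_len=5, prompt_len=2, pad_positions={4})
print(mask)
print(masked_mean([10.0, 1.0, 3.0, 99.0], mask))
```

Without the shift, each logit row would be compared against the token it was conditioned on rather than the token it predicts; without the mask, prompt and padding positions would dilute the distillation signal.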
Adds the two-phase teacher path:

  * TeacherWrapper loads a HuggingFace causal LM, freezes it, and runs
    forward-only. Two modes: load + pin on GPU (offload_to_cpu=false), or
    wrap with deepspeed.initialize using a ZeRO-3 + offload_param=cpu
    config (offload_to_cpu=true). Avoids deepspeed.zero.Init() around
    from_pretrained because HF's loader partitions params to zero-width
    shards before the checkpoint can fill them.

  * TeacherLogitCache stages the [B, T, V] teacher logits to (pinned) host
    memory in bf16, and exposes chunk_to_device() so the student-side loss
    can pull sequence slices back on demand. This is the memory-economising
    half of the two-phase update.

CPU-only tests cover the cache shape / dtype / round-trip / chunk-bounds
behaviour and verify the streamed-via-cache loss matches the direct
chunked loss bit-for-bit.

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
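
The chunk-equivalence property described in the commit above can be sketched in pure Python. This is an illustration only: HostLogitCache and streamed_kl are hypothetical stand-ins, not the PR's TeacherLogitCache or its chunk_to_device() method; the sketch handles one [T, V] sequence rather than a [B, T, V] batch, and it assumes the forward-KL direction KL(teacher || student).

```python
import math


def _log_softmax(row):
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def _token_kl(teacher_row, student_row):
    """KL(teacher || student) at one position (assumed forward-KL direction)."""
    log_t, log_s = _log_softmax(teacher_row), _log_softmax(student_row)
    return sum(math.exp(a) * (a - b) for a, b in zip(log_t, log_s))


class HostLogitCache:
    """Stand-in for a host-resident cache that hands out sequence slices."""

    def __init__(self, logits):
        self._rows = [list(r) for r in logits]

    def __len__(self):
        return len(self._rows)

    def chunk(self, start, end):
        if not 0 <= start < end <= len(self._rows):
            raise IndexError(f"chunk [{start}, {end}) out of bounds")
        return self._rows[start:end]


def streamed_kl(cache, student_logits, mask, chunk_size):
    """Masked mean KL, touching only chunk_size teacher rows at a time."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    total, count = 0.0, 0
    for start in range(0, len(cache), chunk_size):
        end = min(start + chunk_size, len(cache))
        for off, t_row in enumerate(cache.chunk(start, end)):
            pos = start + off
            if mask[pos]:
                total += _token_kl(t_row, student_logits[pos])
                count += 1
    return total / count if count else 0.0


teacher = [[0.1, 1.5, 0.3], [2.0, -1.0, 0.0], [0.5, 0.5, 0.5], [1.0, 2.0, 3.0]]
student = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [3.0, 2.0, 1.0]]
cache = HostLogitCache(teacher)
print(streamed_kl(cache, student, [0, 1, 1, 1], chunk_size=2))
```

Because this sketch accumulates positions in the same order regardless of chunk size, the result does not depend on chunk_size; only the number of teacher rows held at once changes.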
Lands the fully-runnable hybrid-engine training path: a backend-agnostic
RolloutEngine ABC with RolloutRequest / RolloutBatch / SamplingConfig
dataclasses, a HybridEngineRollout implementation that uses DeepSpeed's
accelerated decode when an inference policy exists and otherwise falls
back to GatheredParameters + the raw HF generate (covers Qwen-family and
other models not in DeepSpeed's inference container list), a left-padded
prompt dataset + collator, a three-phase trainer loop (rollout -> teacher
forward + cache -> student forward + streamed KL + backward + step), the
argparse + deepspeed.initialize entry point, base DeepSpeed ZeRO-3 +
hybrid_engine JSON configs, a 5-step smoke config and launcher script,
and a 20-prompt math toy dataset for the smoke run.

Smoke-validated end-to-end on 2x H200 with Qwen2.5-0.5B-Instruct student
and Qwen2.5-1.5B-Instruct teacher; loss finite for 5 steps. Rollout
interface contract is covered by tests/test_rollout_interface.py.

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
Lands the second-stage rollout path, weight-sync infrastructure, and the
example app's README. Includes:

  * VLLMRollout that constructs vllm.LLM on training rank 0 and broadcasts
    generated token ids to peer ranks, with disjoint-GPU (subprocess) and
    shared (in-process) topology paths. Weight sync gathers ZeRO-3 params
    cooperatively then pushes to vLLM via LLM.collective_rpc("load_weights").

  * WeightBridge ABC with COLUMN / ROW / VOCAB / REPLICATED parallel kinds
    and an even-slice per-rank slicer; Qwen2WeightBridge with the full
    per-parameter table for Qwen2 / Qwen2.5; Qwen3WeightBridge adding the
    per-head q_norm / k_norm tensors as REPLICATED.

  * vLLM-side prompt+response stitching factored into stitch_rollout() so
    its index math is unit-testable without a live vLLM.

  * CPU-only tests: tests/test_weight_bridge.py covers parallel-kind
    dispatch, per-rank shape/gather round-trips across tp_size in {1,2,4},
    indivisibility / invalid-rank guards, and the registry;
    tests/test_vllm_stitch.py covers prompt/response stitching for the
    common shapes including variable response lengths and left-padded
    prompts.

  * configs + launch scripts for both production and smoke vLLM runs.
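
As an illustration of the even-slice per-rank slicer and the round-trip tests listed above, here is a minimal sketch on plain nested lists. It is not the PR's WeightBridge: the function names are hypothetical, and it assumes the usual [out_features, in_features] weight layout in which COLUMN and VOCAB shard dimension 0, ROW shards dimension 1, and REPLICATED copies the full tensor to every rank.

```python
from enum import Enum


class ParallelKind(Enum):
    COLUMN = "column"
    ROW = "row"
    VOCAB = "vocab"
    REPLICATED = "replicated"


def _check_rank(tp_rank, tp_size):
    if tp_size < 1 or not 0 <= tp_rank < tp_size:
        raise ValueError(f"invalid tp_rank {tp_rank} for tp_size {tp_size}")


def _even_bounds(size, tp_rank, tp_size):
    """Half-open [lo, hi) slice for this rank; refuses uneven splits."""
    _check_rank(tp_rank, tp_size)
    if size % tp_size:
        raise ValueError(f"dimension {size} is not divisible by tp_size {tp_size}")
    step = size // tp_size
    return tp_rank * step, (tp_rank + 1) * step


def slice_for_rank(weight, kind, tp_rank, tp_size):
    """Return this rank's shard of a 2-D weight given as a list of rows."""
    if kind is ParallelKind.REPLICATED:
        _check_rank(tp_rank, tp_size)
        return [list(r) for r in weight]
    if kind in (ParallelKind.COLUMN, ParallelKind.VOCAB):  # shard dim 0
        lo, hi = _even_bounds(len(weight), tp_rank, tp_size)
        return [list(r) for r in weight[lo:hi]]
    if kind is ParallelKind.ROW:  # shard dim 1
        lo, hi = _even_bounds(len(weight[0]), tp_rank, tp_size)
        return [list(r[lo:hi]) for r in weight]
    raise ValueError(f"unknown parallel kind: {kind!r}")


def gather(shards, kind):
    """Inverse of slice_for_rank across all ranks, in rank order."""
    if kind is ParallelKind.REPLICATED:
        return [list(r) for r in shards[0]]
    if kind in (ParallelKind.COLUMN, ParallelKind.VOCAB):
        return [list(r) for s in shards for r in s]
    return [sum((list(s[i]) for s in shards), []) for i in range(len(shards[0]))]


w = [[r * 4 + c for c in range(4)] for r in range(4)]
print(slice_for_rank(w, ParallelKind.ROW, tp_rank=0, tp_size=2))
```

Refusing uneven splits up front, rather than silently truncating, is what makes a slice-then-gather round trip a meaningful test.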

**Known blocker called out in README and module docstring:** vLLM's worker
init calls new_group() on the global process group, which deadlocks when
launched under the standard `deepspeed --num_gpus N` launcher (rank 0
calls vLLM, other ranks never participate in vLLM's collective). The
documented fix is the TRL/OpenRLHF separate-server pattern; this PR lands
the scaffolding so that work can begin against a green codebase.

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
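
The prompt+response stitching mentioned in the commit above can be sketched without a live vLLM. The function name stitch_rollout() comes from the commit message, but the signature, the return values, and the padding layout below are assumptions: prompts are taken to be left-padded to a common length, responses are right-padded to the longest response, and pad_id is assumed never to occur as a real prompt token.

```python
def stitch_rollout(prompts, responses, pad_id):
    """Join left-padded prompts with variable-length responses.

    Returns (input_ids, attention_mask, response_mask), each a list of
    equal-length rows of size prompt_len + max_response_len.
    """
    if len(prompts) != len(responses):
        raise ValueError("prompts and responses must have the same batch size")
    if not prompts:
        return [], [], []
    prompt_len = len(prompts[0])
    if any(len(p) != prompt_len for p in prompts):
        raise ValueError("prompts must be padded to a common length")
    max_resp = max(len(r) for r in responses)

    input_ids, attention_mask, response_mask = [], [], []
    for p, r in zip(prompts, responses):
        pad = max_resp - len(r)
        input_ids.append(list(p) + list(r) + [pad_id] * pad)
        # Prompt pads sit on the left, response pads on the right.
        attention_mask.append([int(t != pad_id) for t in p] + [1] * len(r) + [0] * pad)
        # Only generated tokens contribute to the distillation loss.
        response_mask.append([0] * prompt_len + [1] * len(r) + [0] * pad)
    return input_ids, attention_mask, response_mask


ids, attn, resp = stitch_rollout(
    prompts=[[0, 0, 5, 6], [0, 7, 8, 9]],
    responses=[[10, 11], [12]],
    pad_id=0,
)
for row in zip(ids, attn, resp):
    print(row)
```

With this layout every response starts at the same column, so the boundary between prompt and response is a single index for the whole batch.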
HybridEngineRollout:
- model.generate() path for sampling (temperature>0)
- Graph capture + DeepSpeedStaticCache path for greedy (temperature=0, 3x faster)
- DeepSpeedStaticCache: CUDA-graph-compatible KV cache with external write_position
- RolloutRequest/RolloutBatch/RolloutConfig dataclasses

VLLMRollout:
- Weight sync via gdr/http backends
- vllm_python config for interpreter selection
- vLLM compat sitecustomize shim
- Only sync requires_grad params to vLLM

OPSD trainer/config:
- Move trainer to deepspeed/runtime/rlhf/trainer/
- Move config to deepspeed/runtime/rlhf/config.py
- Force weight_sync_interval=1 for on-policy correctness

Tests:
- CPU unit tests for HybridEngineRollout and VLLMRollout
- Graph capture verified: HF StaticCache == DeepSpeedStaticCache == graph (100 steps, 0 diff)

Verified on Qwen2.5-0.5B-Instruct / RTX 5090:
- model.generate(): 90 tok/s (batch=1)
- graph capture: 270 tok/s (3x speedup)
- OPSD smoke test: 3 training steps pass end-to-end

Signed-off-by: Guokai Ma <guokai.ma@intel.com>
Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
…m_dtype to engine_dtype

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
@PKUWZP PKUWZP force-pushed the zhipwang_opd_pr branch from 381ca38 to 19929c7 Compare July 2, 2026 00:44
PKUWZP added 2 commits July 1, 2026 14:53
Formatting hooks:
- Add the standard license header to the opsd benchmark scripts
  (bench_14b_rollout.py, bench_autotp_gc.py, bench_vllm_tp2.py).
- Mark CUDA-specific benchmark calls with #ignore-cuda and drop the
  unused `sys` import flagged by flake8.
- yapf: collapse the rlhf __init__ imports that now fit in 119 cols.

cpu-torch-latest unit tests:
- Align tests/unit/runtime/rollout/test_hybrid_engine_rollout.py with the
  current cfg-based HybridEngineRollout constructor: the tests previously
  passed continuous_batching_size / use_graph_capture kwargs and expected
  _generate_continuous_batching / _generate_graph_capture_cb dispatch that
  the implementation does not provide. Tests now build a
  HybridEngineRolloutConfig and assert the real generate() dispatch
  (module.generate by default, _generate_graph for greedy graph capture).
- Update the two benchmark constructor calls to the cfg-based API.

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
Keep the SPDX Apache-2.0 identifier and DeepSpeed Team attribution, which
are what the check-license hook requires.

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
@PKUWZP PKUWZP force-pushed the zhipwang_opd_pr branch from 19929c7 to 837c241 Compare July 2, 2026 00:55
@@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0

@sfc-gh-truwase sfc-gh-truwase Jul 2, 2026

Collaborator

Please move the benchmarks to the examples repo. That will reduce the dependencies for building DS as the benchmarks increase in complexity.


def _gdr_available() -> bool:
    try:
        return torch.cuda.is_available() and torch.cuda.nccl.version() is not None  #ignore-cuda

@sfc-gh-truwase sfc-gh-truwase Jul 2, 2026

Collaborator

These CUDA references are breaking the accelerator abstraction.

}})
init_thread.start()

from vllm.distributed.utils import StatelessProcessGroup

Collaborator

Why are we making vllm a dependency? This breaks the lightweight nature of DS and introduces unnecessary complexity.
