Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
ca6c5a8
Add OPSD example: config, divergence losses, utils + tests
PKUWZP May 26, 2026
cfc2768
Add OPSD frozen teacher with CPU logit cache + tests
PKUWZP May 26, 2026
c9b333a
Add OPSD trainer, hybrid-engine rollout, and end-to-end entry point
PKUWZP May 26, 2026
f6cfd68
Add OPSD vLLM rollout scaffold, Qwen2/Qwen3 weight bridges, and README
PKUWZP May 26, 2026
dedfe73
feat(rollout): OPSD rollout engine with graph capture, vLLM backend
delock Jun 21, 2026
10ef325
Use ROLLOUT_VISIBLE_DEVICE env var for vLLM GPU placement; rename vll…
delock Jul 1, 2026
5716ecd
Fix formatting and CPU unit-test checks for OPSD rollout
PKUWZP Jul 2, 2026
837c241
Remove Microsoft Corporation copyright line from OPSD file headers
PKUWZP Jul 2, 2026
734230c
Remove vLLM rollout, move trainer/losses/utils/benchmarks to DeepSpee…
delock Jul 3, 2026
2a46e40
Move static_cache.py to deepspeed/utils/
delock Jul 3, 2026
b260126
Use accelerator abstraction for CUDA graph capture in hybrid engine r…
delock Jul 3, 2026
365f6a0
Remove capture_error_mode parameter from accelerator API
delock Jul 3, 2026
89c9bf1
Trim RolloutConfig to engine-only fields (engine, use_graph_capture)
delock Jul 3, 2026
eb19237
Clean up vLLM references in rollout/base.py docstrings
delock Jul 3, 2026
d626b0b
Replace remaining torch.cuda stream calls with accelerator abstraction
delock Jul 3, 2026
f711569
Remove remaining rlhf/data.py and rlhf/teacher.py
delock Jul 3, 2026
0cf471e
Remove 'tags' from .gitignore (editor artifact)
delock Jul 3, 2026
0df8944
Fix yapf formatting in test_rollout_interface.py
delock Jul 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions deepspeed/runtime/rollout/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Rollout engines for on-policy generation during RL/distillation training.

Provides:
- :class:`RolloutEngine` — abstract base class
- :class:`RolloutRequest`, :class:`RolloutBatch`, :class:`SamplingConfig` — dataclasses
- :class:`HybridEngineRollout` — concrete implementation using DeepSpeed hybrid engine
- :func:`build_rollout` — factory that selects the engine from config
"""

from deepspeed.runtime.rollout.base import (
RolloutBatch,
RolloutConfig,
RolloutEngine,
RolloutRequest,
SamplingConfig,
)
from deepspeed.runtime.rollout.hybrid_engine_rollout import HybridEngineRollout

__all__ = [
"HybridEngineRollout",
"RolloutBatch",
"RolloutConfig",
"RolloutEngine",
"RolloutRequest",
"SamplingConfig",
"build_rollout",
]


def build_rollout(rollout_cfg, student_engine=None, tokenizer=None, **kwargs):
"""Factory: construct the rollout engine specified by ``rollout_cfg.engine``.

Args:
rollout_cfg: :class:`RolloutConfig` (or any object with an ``engine``
attribute set to ``"hybrid_engine"``).
student_engine: DeepSpeed engine wrapping the student model.
tokenizer: HuggingFace tokenizer.
"""
engine_name = rollout_cfg.engine
if engine_name == "hybrid_engine":
if student_engine is None or tokenizer is None:
raise ValueError("hybrid_engine rollout needs both student_engine and tokenizer")
return HybridEngineRollout(engine=student_engine, tokenizer=tokenizer, cfg=rollout_cfg)

raise ValueError(f"Unknown rollout engine {engine_name!r}; choose from 'hybrid_engine'")
107 changes: 107 additions & 0 deletions deepspeed/runtime/rollout/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Rollout engine interface.

The trainer talks to its rollout engine through three small dataclasses
(``RolloutRequest`` in / ``RolloutBatch`` out / ``SamplingConfig``) and one
ABC. This keeps engine-specific concerns out of the trainer loop.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch


@dataclass
class RolloutConfig:
"""Configuration for the rollout engine."""
engine: str = "hybrid_engine"

# Use CUDA graph capture for decode acceleration.
use_graph_capture: bool = False


@dataclass
class SamplingConfig:
"""Sampling knobs that the trainer passes to ``generate`` each step."""

max_new_tokens: int
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
n_samples_per_prompt: int = 1


@dataclass
class RolloutRequest:
"""Input to ``RolloutEngine.generate``.

Prompts arrive *left-padded* (i.e. real tokens at the right edge) so that
causal generation appends naturally after them.
"""

prompt_ids: torch.Tensor # [B, T_p] left-padded with pad_token_id
prompt_attention_mask: torch.Tensor # [B, T_p], 1 on real prompt tokens

def __post_init__(self) -> None:
if self.prompt_ids.dim() != 2:
raise ValueError(f"prompt_ids must be 2-D [B, T_p]; got {tuple(self.prompt_ids.shape)}")
if self.prompt_attention_mask.shape != self.prompt_ids.shape:
raise ValueError(f"prompt_attention_mask shape {tuple(self.prompt_attention_mask.shape)} "
f"does not match prompt_ids {tuple(self.prompt_ids.shape)}")


@dataclass
class RolloutBatch:
"""Output of ``RolloutEngine.generate``.

``input_ids`` holds the *concatenation* of (left-padded) prompt and
response, right-padded to the longest sequence in the batch.
"""

input_ids: torch.Tensor # [B', T_p + T_r]; B' = B * n_samples_per_prompt
attention_mask: torch.Tensor # [B', T_p + T_r]
response_start_idx: torch.Tensor # [B'] int

def __post_init__(self) -> None:
if self.input_ids.dim() != 2:
raise ValueError(f"input_ids must be 2-D; got {tuple(self.input_ids.shape)}")
if self.attention_mask.shape != self.input_ids.shape:
raise ValueError(f"attention_mask shape {tuple(self.attention_mask.shape)} does not "
f"match input_ids {tuple(self.input_ids.shape)}")
B = self.input_ids.shape[0]
if self.response_start_idx.shape != (B, ):
raise ValueError(f"response_start_idx must be 1-D of length {B}; got "
f"{tuple(self.response_start_idx.shape)}")

@property
def batch_size(self) -> int:
return int(self.input_ids.shape[0])

@property
def seq_len(self) -> int:
return int(self.input_ids.shape[1])


class RolloutEngine(ABC):
"""Abstract base for rollout engines."""

name: str = "base"

@abstractmethod
def generate(self, request: RolloutRequest, sampling: SamplingConfig) -> RolloutBatch:
"""Run generation, return prompt+response in one tensor."""

@abstractmethod
def sync_weights(self, step: int) -> None:
"""Push updated weights into the rollout backend.

No-op when the rollout engine is co-located with the training engine
(e.g. hybrid engine shares weights directly).
"""

def shutdown(self) -> None:
"""Release any backend resources. Default no-op."""
return None
242 changes: 242 additions & 0 deletions deepspeed/runtime/rollout/hybrid_engine_rollout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Rollout engine backed by DeepSpeed's hybrid engine.

Two generation paths:
1. **model.generate()** (default): delegates to HuggingFace generate.
Supports sampling (temperature, top_p) and greedy.
2. **graph capture + DeepSpeedStaticCache**: only for greedy (temperature=0).
Pre-allocates a StaticCache, captures the decode forward pass with a
CUDA graph, and replays it for each decode step. Eliminates kernel
launch overhead.
"""

from dataclasses import dataclass

import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.rollout.base import RolloutBatch, RolloutEngine, RolloutRequest, SamplingConfig


@dataclass
class HybridEngineRolloutConfig:
"""Configuration for HybridEngineRollout."""
use_graph_capture: bool = False


class HybridEngineRollout(RolloutEngine):
"""Rollout engine using DeepSpeed hybrid engine.

Args:
engine: DeepSpeed engine wrapping the model.
tokenizer: HuggingFace tokenizer (must have pad_token_id or eos_token_id).
cfg: Optional HybridEngineRolloutConfig.
"""

def __init__(self, engine, tokenizer, cfg=None):
self.engine = engine
self.tokenizer = tokenizer
self.use_graph_capture = getattr(cfg, 'use_graph_capture', False) if cfg else False

@torch.no_grad()
def generate(self, request: RolloutRequest, sampling: SamplingConfig) -> RolloutBatch:
device = request.prompt_ids.device
B = request.prompt_ids.shape[0]
n = sampling.n_samples_per_prompt
total = B * n
prompt_len = request.prompt_ids.shape[1]
max_new_tokens = sampling.max_new_tokens
pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id

module = self.engine.module

# Expand prompts for n samples per prompt
if n > 1:
prompt_ids = request.prompt_ids.repeat_interleave(n, dim=0)
prompt_attn = request.prompt_attention_mask.repeat_interleave(n, dim=0)
else:
prompt_ids = request.prompt_ids
prompt_attn = request.prompt_attention_mask

is_greedy = sampling.temperature <= 0.0

if self.use_graph_capture and is_greedy:
output_ids = self._generate_graph(prompt_ids, prompt_attn, max_new_tokens, pad_token_id, module, device)
else:
temperature = max(sampling.temperature, 1e-8)
do_sample = not is_greedy
output_ids = module.generate(
prompt_ids,
attention_mask=prompt_attn,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature if do_sample else 1.0,
top_p=sampling.top_p if do_sample else 1.0,
pad_token_id=pad_token_id,
)

# Build attention mask: pad positions (both left padding from prompt
# and right padding from EOS / shorter sequences) are 0.
full_len = output_ids.shape[1]
response_start = prompt_len
attention_mask = (output_ids != pad_token_id).long()
for i in range(total):
prompt_valid = request.prompt_attention_mask[i // n if B > 1 else 0]
attention_mask[i, :prompt_len] = prompt_valid

return RolloutBatch(
input_ids=output_ids,
attention_mask=attention_mask,
response_start_idx=torch.full((total, ), response_start, dtype=torch.long, device=device),
)

# ------------------------------------------------------------------
# Graph capture decode loop (greedy only)
# ------------------------------------------------------------------

def _generate_graph(self, prompt_ids, prompt_attn, max_new_tokens, pad_token_id, module, device):
"""Greedy decode with DeepSpeedStaticCache + CUDA graph capture."""
from transformers import StaticCache
from deepspeed.utils.static_cache import DeepSpeedStaticCache

batch_size = prompt_ids.shape[0]
prompt_len = prompt_ids.shape[1]
max_len = prompt_len + max_new_tokens
eos_token_id = self.tokenizer.eos_token_id
model_dtype = next(module.parameters()).dtype

# --- Prefill with HF StaticCache (correct attention semantics) ---
prefill_cache = StaticCache(
config=module.config,
batch_size=batch_size,
max_cache_len=max_len,
device=device,
dtype=model_dtype,
)
prefill_attn = torch.ones(batch_size, prompt_len, dtype=torch.long, device=device)
prefill_attn[:, :prompt_len] = prompt_attn
prefill_out = module(
prompt_ids,
attention_mask=prefill_attn,
past_key_values=prefill_cache,
use_cache=True,
cache_position=torch.arange(prompt_len, device=device),
)
next_token = prefill_out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

# --- Copy prefill KV into DeepSpeedStaticCache ---
write_pos = torch.tensor(prompt_len - 1, dtype=torch.long, device=device)
ds_cache = DeepSpeedStaticCache(
module.config,
batch_size=batch_size,
max_cache_len=max_len,
device=device,
dtype=model_dtype,
)
ds_cache.set_write_position(write_pos)
# Trigger lazy init then copy real data
for layer_idx in range(len(ds_cache.layers)):
ds_layer = ds_cache.layers[layer_idx]
hf_layer = prefill_cache.layers[layer_idx]
if not ds_layer.is_initialized:
ds_layer.lazy_initialization(hf_layer.keys, hf_layer.values)
ds_layer.keys[:, :, :prompt_len, :].copy_(hf_layer.keys[:, :, :prompt_len, :])
ds_layer.values[:, :, :prompt_len, :].copy_(hf_layer.values[:, :, :prompt_len, :])

output_ids = [prompt_ids, next_token]

# --- Static buffers for graph capture ---
static_token = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
static_attn = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)
static_attn[:, :prompt_len] = prompt_attn
static_attn[:, prompt_len] = 1 # first decode position
static_pos = torch.tensor(prompt_len, dtype=torch.long, device=device)
static_cache_pos = static_pos.unsqueeze(0) # [1] for cache_position
static_pos_ids = static_pos.reshape(1, 1).expand(batch_size, 1) # [batch, 1]

write_pos.fill_(prompt_len)

# Remove forward hooks (they synchronize — illegal during graph capture)
saved_pre = dict(module._forward_pre_hooks)
saved_post = dict(module._forward_hooks)
module._forward_pre_hooks.clear()
module._forward_hooks.clear()

try:
# Warmup on side stream
static_token.copy_(next_token)
s = get_accelerator().Stream()
s.wait_stream(get_accelerator().current_stream())
with get_accelerator().stream(s):
for _ in range(3):
out = module(
static_token,
attention_mask=static_attn,
past_key_values=ds_cache,
use_cache=True,
cache_position=static_cache_pos,
position_ids=static_pos_ids,
)
get_accelerator().current_stream().wait_stream(s)

# Capture
graph = get_accelerator().create_graph()
with get_accelerator().capture_to_graph(graph):
out = module(
static_token,
attention_mask=static_attn,
past_key_values=ds_cache,
use_cache=True,
cache_position=static_cache_pos,
position_ids=static_pos_ids,
)
static_logits = out.logits
finally:
module._forward_pre_hooks.update(saved_pre)
module._forward_hooks.update(saved_post)

# --- Decode loop ---
eos_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
for step in range(max_new_tokens - 1):
if eos_mask.all():
output_ids.append(torch.full((batch_size, 1), pad_token_id, dtype=torch.long, device=device))
continue

# Update static inputs
static_token.copy_(next_token)
pos = prompt_len + step
write_pos.fill_(pos)
static_cache_pos.fill_(pos)
static_pos_ids.fill_(pos)
static_attn[:, pos] = 1

# Replay
get_accelerator().replay_graph(graph)
next_token = static_logits[:, -1, :].argmax(dim=-1, keepdim=True)
output_ids.append(next_token)
eos_mask |= (next_token.squeeze(1) == eos_token_id)

return torch.cat(output_ids, dim=1)

@staticmethod
def _sample_top_p(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0) -> torch.Tensor:
"""Sample from logits with temperature and nucleus (top-p) filtering."""
logits = logits / temperature
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
mask = (cumulative_probs - torch.softmax(sorted_logits, dim=-1)) >= top_p
sorted_logits[mask] = -float('inf')
probs = torch.softmax(sorted_logits, dim=-1)
sampled = torch.multinomial(probs, 1)
tokens = sorted_indices.gather(1, sampled)
else:
probs = torch.softmax(logits, dim=-1)
tokens = torch.multinomial(probs, 1)
return tokens

def sync_weights(self, step: int) -> None: # noqa: ARG002
"""No-op: hybrid engine reads model weights live."""
return None
Loading
Loading