diff --git a/deepspeed/runtime/rollout/__init__.py b/deepspeed/runtime/rollout/__init__.py new file mode 100644 index 000000000000..16f6fc595da6 --- /dev/null +++ b/deepspeed/runtime/rollout/__init__.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Rollout engines for on-policy generation during RL/distillation training. + +Provides: + - :class:`RolloutEngine` — abstract base class + - :class:`RolloutRequest`, :class:`RolloutBatch`, :class:`SamplingConfig` — dataclasses + - :class:`HybridEngineRollout` — concrete implementation using DeepSpeed hybrid engine + - :func:`build_rollout` — factory that selects the engine from config +""" + +from deepspeed.runtime.rollout.base import ( + RolloutBatch, + RolloutConfig, + RolloutEngine, + RolloutRequest, + SamplingConfig, +) +from deepspeed.runtime.rollout.hybrid_engine_rollout import HybridEngineRollout + +__all__ = [ + "HybridEngineRollout", + "RolloutBatch", + "RolloutConfig", + "RolloutEngine", + "RolloutRequest", + "SamplingConfig", + "build_rollout", +] + + +def build_rollout(rollout_cfg, student_engine=None, tokenizer=None, **kwargs): + """Factory: construct the rollout engine specified by ``rollout_cfg.engine``. + + Args: + rollout_cfg: :class:`RolloutConfig` (or any object with an ``engine`` + attribute set to ``"hybrid_engine"``). + student_engine: DeepSpeed engine wrapping the student model. + tokenizer: HuggingFace tokenizer. + """ + engine_name = rollout_cfg.engine + if engine_name == "hybrid_engine": + if student_engine is None or tokenizer is None: + raise ValueError("hybrid_engine rollout needs both student_engine and tokenizer") + return HybridEngineRollout(engine=student_engine, tokenizer=tokenizer, cfg=rollout_cfg) + + raise ValueError(f"Unknown rollout engine {engine_name!r}; choose from 'hybrid_engine'") diff --git a/deepspeed/runtime/rollout/base.py b/deepspeed/runtime/rollout/base.py new file mode 100644 index 000000000000..abff6c6ccb12 --- /dev/null +++ b/deepspeed/runtime/rollout/base.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Rollout engine interface. + +The trainer talks to its rollout engine through three small dataclasses +(``RolloutRequest`` in / ``RolloutBatch`` out / ``SamplingConfig``) and one +ABC. This keeps engine-specific concerns out of the trainer loop. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import torch + + +@dataclass +class RolloutConfig: + """Configuration for the rollout engine.""" + engine: str = "hybrid_engine" + + # Use CUDA graph capture for decode acceleration. + use_graph_capture: bool = False + + +@dataclass +class SamplingConfig: + """Sampling knobs that the trainer passes to ``generate`` each step.""" + + max_new_tokens: int + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + n_samples_per_prompt: int = 1 + + +@dataclass +class RolloutRequest: + """Input to ``RolloutEngine.generate``. + + Prompts arrive *left-padded* (i.e. real tokens at the right edge) so that + causal generation appends naturally after them. + """ + + prompt_ids: torch.Tensor # [B, T_p] left-padded with pad_token_id + prompt_attention_mask: torch.Tensor # [B, T_p], 1 on real prompt tokens + + def __post_init__(self) -> None: + if self.prompt_ids.dim() != 2: + raise ValueError(f"prompt_ids must be 2-D [B, T_p]; got {tuple(self.prompt_ids.shape)}") + if self.prompt_attention_mask.shape != self.prompt_ids.shape: + raise ValueError(f"prompt_attention_mask shape {tuple(self.prompt_attention_mask.shape)} " + f"does not match prompt_ids {tuple(self.prompt_ids.shape)}") + + +@dataclass +class RolloutBatch: + """Output of ``RolloutEngine.generate``. + + ``input_ids`` holds the *concatenation* of (left-padded) prompt and + response, right-padded to the longest sequence in the batch. + """ + + input_ids: torch.Tensor # [B', T_p + T_r]; B' = B * n_samples_per_prompt + attention_mask: torch.Tensor # [B', T_p + T_r] + response_start_idx: torch.Tensor # [B'] int + + def __post_init__(self) -> None: + if self.input_ids.dim() != 2: + raise ValueError(f"input_ids must be 2-D; got {tuple(self.input_ids.shape)}") + if self.attention_mask.shape != self.input_ids.shape: + raise ValueError(f"attention_mask shape {tuple(self.attention_mask.shape)} does not " + f"match input_ids {tuple(self.input_ids.shape)}") + B = self.input_ids.shape[0] + if self.response_start_idx.shape != (B, ): + raise ValueError(f"response_start_idx must be 1-D of length {B}; got " + f"{tuple(self.response_start_idx.shape)}") + + @property + def batch_size(self) -> int: + return int(self.input_ids.shape[0]) + + @property + def seq_len(self) -> int: + return int(self.input_ids.shape[1]) + + +class RolloutEngine(ABC): + """Abstract base for rollout engines.""" + + name: str = "base" + + @abstractmethod + def generate(self, request: RolloutRequest, sampling: SamplingConfig) -> RolloutBatch: + """Run generation, return prompt+response in one tensor.""" + + @abstractmethod + def sync_weights(self, step: int) -> None: + """Push updated weights into the rollout backend. + + No-op when the rollout engine is co-located with the training engine + (e.g. hybrid engine shares weights directly). + """ + + def shutdown(self) -> None: + """Release any backend resources. Default no-op.""" + return None diff --git a/deepspeed/runtime/rollout/hybrid_engine_rollout.py b/deepspeed/runtime/rollout/hybrid_engine_rollout.py new file mode 100644 index 000000000000..7e6279b8bf83 --- /dev/null +++ b/deepspeed/runtime/rollout/hybrid_engine_rollout.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Rollout engine backed by DeepSpeed's hybrid engine. + +Two generation paths: + 1. **model.generate()** (default): delegates to HuggingFace generate. + Supports sampling (temperature, top_p) and greedy. + 2. **graph capture + DeepSpeedStaticCache**: only for greedy (temperature=0). + Pre-allocates a StaticCache, captures the decode forward pass with a + CUDA graph, and replays it for each decode step. Eliminates kernel + launch overhead. +""" + +from dataclasses import dataclass + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.rollout.base import RolloutBatch, RolloutEngine, RolloutRequest, SamplingConfig + + +@dataclass +class HybridEngineRolloutConfig: + """Configuration for HybridEngineRollout.""" + use_graph_capture: bool = False + + +class HybridEngineRollout(RolloutEngine): + """Rollout engine using DeepSpeed hybrid engine. + + Args: + engine: DeepSpeed engine wrapping the model. + tokenizer: HuggingFace tokenizer (must have pad_token_id or eos_token_id). + cfg: Optional HybridEngineRolloutConfig. + """ + + def __init__(self, engine, tokenizer, cfg=None): + self.engine = engine + self.tokenizer = tokenizer + self.use_graph_capture = getattr(cfg, 'use_graph_capture', False) if cfg else False + + @torch.no_grad() + def generate(self, request: RolloutRequest, sampling: SamplingConfig) -> RolloutBatch: + device = request.prompt_ids.device + B = request.prompt_ids.shape[0] + n = sampling.n_samples_per_prompt + total = B * n + prompt_len = request.prompt_ids.shape[1] + max_new_tokens = sampling.max_new_tokens + pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id + + module = self.engine.module + + # Expand prompts for n samples per prompt + if n > 1: + prompt_ids = request.prompt_ids.repeat_interleave(n, dim=0) + prompt_attn = request.prompt_attention_mask.repeat_interleave(n, dim=0) + else: + prompt_ids = request.prompt_ids + prompt_attn = request.prompt_attention_mask + + is_greedy = sampling.temperature <= 0.0 + + if self.use_graph_capture and is_greedy: + output_ids = self._generate_graph(prompt_ids, prompt_attn, max_new_tokens, pad_token_id, module, device) + else: + temperature = max(sampling.temperature, 1e-8) + do_sample = not is_greedy + output_ids = module.generate( + prompt_ids, + attention_mask=prompt_attn, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature if do_sample else 1.0, + top_p=sampling.top_p if do_sample else 1.0, + pad_token_id=pad_token_id, + ) + + # Build attention mask: pad positions (both left padding from prompt + # and right padding from EOS / shorter sequences) are 0. + full_len = output_ids.shape[1] + response_start = prompt_len + attention_mask = (output_ids != pad_token_id).long() + for i in range(total): + prompt_valid = request.prompt_attention_mask[i // n if B > 1 else 0] + attention_mask[i, :prompt_len] = prompt_valid + + return RolloutBatch( + input_ids=output_ids, + attention_mask=attention_mask, + response_start_idx=torch.full((total, ), response_start, dtype=torch.long, device=device), + ) + + # ------------------------------------------------------------------ + # Graph capture decode loop (greedy only) + # ------------------------------------------------------------------ + + def _generate_graph(self, prompt_ids, prompt_attn, max_new_tokens, pad_token_id, module, device): + """Greedy decode with DeepSpeedStaticCache + CUDA graph capture.""" + from transformers import StaticCache + from deepspeed.utils.static_cache import DeepSpeedStaticCache + + batch_size = prompt_ids.shape[0] + prompt_len = prompt_ids.shape[1] + max_len = prompt_len + max_new_tokens + eos_token_id = self.tokenizer.eos_token_id + model_dtype = next(module.parameters()).dtype + + # --- Prefill with HF StaticCache (correct attention semantics) --- + prefill_cache = StaticCache( + config=module.config, + batch_size=batch_size, + max_cache_len=max_len, + device=device, + dtype=model_dtype, + ) + prefill_attn = torch.ones(batch_size, prompt_len, dtype=torch.long, device=device) + prefill_attn[:, :prompt_len] = prompt_attn + prefill_out = module( + prompt_ids, + attention_mask=prefill_attn, + past_key_values=prefill_cache, + use_cache=True, + cache_position=torch.arange(prompt_len, device=device), + ) + next_token = prefill_out.logits[:, -1, :].argmax(dim=-1, keepdim=True) + + # --- Copy prefill KV into DeepSpeedStaticCache --- + write_pos = torch.tensor(prompt_len - 1, dtype=torch.long, device=device) + ds_cache = DeepSpeedStaticCache( + module.config, + batch_size=batch_size, + max_cache_len=max_len, + device=device, + dtype=model_dtype, + ) + ds_cache.set_write_position(write_pos) + # Trigger lazy init then copy real data + for layer_idx in range(len(ds_cache.layers)): + ds_layer = ds_cache.layers[layer_idx] + hf_layer = prefill_cache.layers[layer_idx] + if not ds_layer.is_initialized: + ds_layer.lazy_initialization(hf_layer.keys, hf_layer.values) + ds_layer.keys[:, :, :prompt_len, :].copy_(hf_layer.keys[:, :, :prompt_len, :]) + ds_layer.values[:, :, :prompt_len, :].copy_(hf_layer.values[:, :, :prompt_len, :]) + + output_ids = [prompt_ids, next_token] + + # --- Static buffers for graph capture --- + static_token = torch.zeros(batch_size, 1, dtype=torch.long, device=device) + static_attn = torch.zeros(batch_size, max_len, dtype=torch.long, device=device) + static_attn[:, :prompt_len] = prompt_attn + static_attn[:, prompt_len] = 1 # first decode position + static_pos = torch.tensor(prompt_len, dtype=torch.long, device=device) + static_cache_pos = static_pos.unsqueeze(0) # [1] for cache_position + static_pos_ids = static_pos.reshape(1, 1).expand(batch_size, 1) # [batch, 1] + + write_pos.fill_(prompt_len) + + # Remove forward hooks (they synchronize — illegal during graph capture) + saved_pre = dict(module._forward_pre_hooks) + saved_post = dict(module._forward_hooks) + module._forward_pre_hooks.clear() + module._forward_hooks.clear() + + try: + # Warmup on side stream + static_token.copy_(next_token) + s = get_accelerator().Stream() + s.wait_stream(get_accelerator().current_stream()) + with get_accelerator().stream(s): + for _ in range(3): + out = module( + static_token, + attention_mask=static_attn, + past_key_values=ds_cache, + use_cache=True, + cache_position=static_cache_pos, + position_ids=static_pos_ids, + ) + get_accelerator().current_stream().wait_stream(s) + + # Capture + graph = get_accelerator().create_graph() + with get_accelerator().capture_to_graph(graph): + out = module( + static_token, + attention_mask=static_attn, + past_key_values=ds_cache, + use_cache=True, + cache_position=static_cache_pos, + position_ids=static_pos_ids, + ) + static_logits = out.logits + finally: + module._forward_pre_hooks.update(saved_pre) + module._forward_hooks.update(saved_post) + + # --- Decode loop --- + eos_mask = torch.zeros(batch_size, dtype=torch.bool, device=device) + for step in range(max_new_tokens - 1): + if eos_mask.all(): + output_ids.append(torch.full((batch_size, 1), pad_token_id, dtype=torch.long, device=device)) + continue + + # Update static inputs + static_token.copy_(next_token) + pos = prompt_len + step + write_pos.fill_(pos) + static_cache_pos.fill_(pos) + static_pos_ids.fill_(pos) + static_attn[:, pos] = 1 + + # Replay + get_accelerator().replay_graph(graph) + next_token = static_logits[:, -1, :].argmax(dim=-1, keepdim=True) + output_ids.append(next_token) + eos_mask |= (next_token.squeeze(1) == eos_token_id) + + return torch.cat(output_ids, dim=1) + + @staticmethod + def _sample_top_p(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0) -> torch.Tensor: + """Sample from logits with temperature and nucleus (top-p) filtering.""" + logits = logits / temperature + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + mask = (cumulative_probs - torch.softmax(sorted_logits, dim=-1)) >= top_p + sorted_logits[mask] = -float('inf') + probs = torch.softmax(sorted_logits, dim=-1) + sampled = torch.multinomial(probs, 1) + tokens = sorted_indices.gather(1, sampled) + else: + probs = torch.softmax(logits, dim=-1) + tokens = torch.multinomial(probs, 1) + return tokens + + def sync_weights(self, step: int) -> None: # noqa: ARG002 + """No-op: hybrid engine reads model weights live.""" + return None diff --git a/deepspeed/utils/static_cache.py b/deepspeed/utils/static_cache.py new file mode 100644 index 000000000000..520bef9c7314 --- /dev/null +++ b/deepspeed/utils/static_cache.py @@ -0,0 +1,230 @@ +# Copyright (c) DeepSpeed Team +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""CUDA-graph-compatible static KV cache for hybrid engine rollout. + +Derived from HuggingFace transformers ``StaticCache`` / ``StaticLayer``, but +with a critical difference: the write position is supplied externally via a +shared tensor instead of an internal ``cumulative_length`` counter. + +Why this matters +---------------- +Transformers' ``StaticLayer.update()`` maintains its own ``cumulative_length`` +tensor that advances on every call. During CUDA graph capture the captured +forward "freezes" this counter at whatever value it had at capture time. +On replay the counter does *not* advance, so subsequent KV writes go to the +wrong positions and the model silently produces incorrect logits. + +Our ``DeepSpeedStaticCache`` instead reads the write position from a shared +tensor (``write_position``) that the caller updates in-place before each graph +replay. Because ``write_position`` is a real tensor at a fixed address, CUDA +graph replays read the current value each time. + +The caller (HybridEngineRollout) must call ``cache.set_write_position(pos)`` +before each replay, where ``pos`` is a scalar ``torch.long`` tensor on the +correct device. +""" + +import torch + + +class DeepSpeedStaticLayer: + """A single layer's static KV cache whose write position is externally set. + + Parameters + ---------- + max_cache_len : int + Maximum number of tokens the cache can hold (last dim size). + """ + + is_compileable = True + is_sliding = False + + def __init__(self, max_cache_len: int): + self.max_cache_len = max_cache_len + self.keys: torch.Tensor | None = None + self.values: torch.Tensor | None = None + self.is_initialized = False + self._write_position: torch.Tensor | None = None + + def set_write_position(self, pos: torch.Tensor): + self._write_position = pos + + def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: + self.dtype = key_states.dtype + self.device = key_states.device + max_batch_size, num_heads = key_states.shape[:2] + self.max_batch_size = max_batch_size + self.num_heads = num_heads + self.k_head_dim = key_states.shape[-1] + self.v_head_dim = value_states.shape[-1] + + self.keys = torch.zeros( + (max_batch_size, num_heads, self.max_cache_len, self.k_head_dim), + dtype=self.dtype, + device=self.device, + ) + self.values = torch.zeros( + (max_batch_size, num_heads, self.max_cache_len, self.v_head_dim), + dtype=self.dtype, + device=self.device, + ) + torch._dynamo.mark_static_address(self.keys) + torch._dynamo.mark_static_address(self.values) + self.is_initialized = True + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + *args, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not self.is_initialized: + self.lazy_initialization(key_states, value_states) + + kv_length = key_states.shape[-2] + + if self._write_position is not None: + cache_position = torch.arange(kv_length, device=self.device) + self._write_position + else: + cache_position = torch.arange(kv_length, device=self.device) + + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + + return self.keys, self.values + + def get_mask_sizes(self, query_length: int) -> tuple[int, int]: + return self.max_cache_len, 0 + + def get_seq_length(self) -> int: + if not self.is_initialized: + return 0 + if self._write_position is not None: + return self._write_position + 1 + return 0 + + def get_max_cache_shape(self) -> int: + return self.max_cache_len + + def reset(self) -> None: + if self.is_initialized: + self.keys.zero_() + self.values.zero_() + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + if self.is_initialized: + self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) + self.values = self.values.index_select(0, beam_idx.to(self.values.device)) + + +class DeepSpeedStaticCache: + """CUDA-graph-compatible static KV cache. + + Drop-in replacement for ``transformers.StaticCache`` in the graph-capture + decode path of ``HybridEngineRollout``. All layers share a single + ``write_position`` tensor that the caller updates before each graph replay. + + Parameters + ---------- + config : PreTrainedConfig + HuggingFace model config (used to determine number of layers and head + dimensions). + batch_size : int + Batch size for eager initialization. + max_cache_len : int + Maximum sequence length (prompt + generated tokens). + device : torch.device | int | str | None + Device for eager initialization. + dtype : torch.dtype | None + Dtype for eager initialization. + """ + + def __init__( + self, + config, + batch_size: int = 1, + max_cache_len: int = 4096, + device=None, + dtype=None, + ): + self.config = config + text_config = getattr(config, "text_config", config) + num_layers = getattr(text_config, "num_hidden_layers", 1) + self._layers = [DeepSpeedStaticLayer(max_cache_len) for _ in range(num_layers)] + self._max_cache_len = max_cache_len + self._write_position: torch.Tensor | None = None + + if dtype is not None and device is not None and batch_size > 0: + num_heads = getattr(text_config, "num_key_value_heads", getattr(text_config, "num_attention_heads", 1)) + head_dim = getattr(text_config, "hidden_size", 1) // getattr(text_config, "num_attention_heads", 1) + self.early_initialization(batch_size, num_heads, head_dim, dtype, device) + + @property + def layers(self): + return self._layers + + def set_write_position(self, pos: torch.Tensor): + """Set the write position shared by all layers. + + Must be called before each graph replay with the decode step position + as a scalar ``torch.long`` tensor on the correct device. The tensor is + stored by reference so subsequent in-place updates (e.g. + ``pos.fill_(new_val)``) are immediately visible to all layers. + """ + self._write_position = pos + for layer in self._layers: + layer.set_write_position(pos) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + *args, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if layer_idx >= len(self._layers): + raise IndexError(f"layer_idx {layer_idx} out of range (cache has {len(self._layers)} layers)") + return self._layers[layer_idx].update(key_states, value_states, *args, **kwargs) + + def early_initialization( + self, + batch_size: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + device, + ): + for layer in self._layers: + fake_k = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device) + fake_v = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device) + layer.lazy_initialization(fake_k, fake_v) + + def get_seq_length(self, layer_idx: int = 0) -> int: + if layer_idx >= len(self._layers): + return 0 + return self._layers[layer_idx].get_seq_length() + + def get_max_cache_shape(self, layer_idx: int = 0) -> int: + if layer_idx >= len(self._layers): + return self._max_cache_len + return self._layers[layer_idx].get_max_cache_shape() + + def get_mask_sizes(self, query_length: int, layer_idx: int = 0) -> tuple[int, int]: + if layer_idx >= len(self._layers): + return self._max_cache_len, 0 + return self._layers[layer_idx].get_mask_sizes(query_length) + + def reset(self): + for layer in self._layers: + layer.reset() + + def __len__(self): + return len(self._layers) diff --git a/tests/unit/runtime/rollout/test_hybrid_engine_rollout.py b/tests/unit/runtime/rollout/test_hybrid_engine_rollout.py new file mode 100644 index 000000000000..2431fc2f652a --- /dev/null +++ b/tests/unit/runtime/rollout/test_hybrid_engine_rollout.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""CPU-only unit tests for HybridEngineRollout (no GPU needed). + +Tests cover configuration defaults and the pure-tensor sampling helper. +""" + +from unittest.mock import MagicMock + +import torch + +from deepspeed.runtime.rollout.hybrid_engine_rollout import ( + HybridEngineRollout, + HybridEngineRolloutConfig, +) + + +def _make_engine(): + engine = MagicMock() + engine.module = MagicMock() + engine.module.parameters.return_value = iter([]) + return engine + + +def _make_tokenizer(): + tok = MagicMock() + tok.pad_token_id = 0 + tok.eos_token_id = 2 + return tok + + +# -- config defaults ---------------------------------------------------- + + +def test_config_defaults(): + cfg = HybridEngineRolloutConfig() + assert cfg.use_graph_capture is False + + +# -- constructor -------------------------------------------------------- + + +def test_constructor_stores_config(): + engine = _make_engine() + tok = _make_tokenizer() + cfg = HybridEngineRolloutConfig(use_graph_capture=True) + rollout = HybridEngineRollout(engine, tok, cfg=cfg) + assert rollout.use_graph_capture is True + assert rollout.engine is engine + assert rollout.tokenizer is tok + + +def test_constructor_defaults_without_cfg(): + rollout = HybridEngineRollout(_make_engine(), _make_tokenizer()) + assert rollout.use_graph_capture is False + + +# -- _sample_top_p ------------------------------------------------------ + + +def test_sample_top_p_returns_correct_shape(): + logits = torch.randn(4, 100) + tokens = HybridEngineRollout._sample_top_p(logits, temperature=1.0, top_p=1.0) + assert tokens.shape == (4, 1) + + +def test_sample_top_p_deterministic_with_low_temp(): + logits = torch.tensor([[1.0, 10.0, 2.0]]) + tok = HybridEngineRollout._sample_top_p(logits, temperature=1e-10, top_p=1.0) + assert tok.item() == 1 + + +def test_sample_top_p_top_p_filters(): + logits = torch.tensor([[0.0, 0.0, 100.0]]) + tok = HybridEngineRollout._sample_top_p(logits, temperature=1.0, top_p=0.5) + assert tok.item() == 2 + + +def test_sample_top_p_batch(): + logits = torch.randn(8, 50) + tokens = HybridEngineRollout._sample_top_p(logits, temperature=0.8, top_p=0.9) + assert tokens.shape == (8, 1) + assert (tokens >= 0).all() and (tokens < 50).all() + + +# -- sync_weights is no-op --------------------------------------------- + + +def test_sync_weights_is_noop(): + rollout = HybridEngineRollout(_make_engine(), _make_tokenizer()) + assert rollout.sync_weights(step=0) is None + + +# -- generate dispatches correctly ------------------------------------- + + +def test_generate_calls_graph_capture_when_enabled(): + engine = _make_engine() + tok = _make_tokenizer() + cfg = HybridEngineRolloutConfig(use_graph_capture=True) + rollout = HybridEngineRollout(engine, tok, cfg=cfg) + rollout._generate_graph = MagicMock(return_value=torch.zeros(1, 5, dtype=torch.long)) + + req = MagicMock() + req.prompt_ids = torch.tensor([[1, 2]]) + req.prompt_attention_mask = torch.ones(1, 2, dtype=torch.long) + sampling = MagicMock() + sampling.temperature = 0 + sampling.n_samples_per_prompt = 1 + sampling.max_new_tokens = 3 + + rollout.generate(req, sampling) + rollout._generate_graph.assert_called_once() diff --git a/tests/unit/runtime/rollout/test_rollout_interface.py b/tests/unit/runtime/rollout/test_rollout_interface.py new file mode 100644 index 000000000000..bb45267ef5ac --- /dev/null +++ b/tests/unit/runtime/rollout/test_rollout_interface.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Conformance tests for the RolloutEngine interface. + +Validates the dataclass invariants and exercises the interface against a +``FakeRollout`` so the contract is testable without GPUs or a model. The real +backends are tested manually with a launched training script (see README). +""" + +import pytest +import torch + +from deepspeed.runtime.rollout import ( + RolloutBatch, + RolloutEngine, + RolloutRequest, + SamplingConfig, + build_rollout, +) + +# --- dataclass invariants --------------------------------------------------- + + +def test_rollout_request_validates_shapes(): + with pytest.raises(ValueError, match="must be 2-D"): + RolloutRequest(prompt_ids=torch.zeros(8), prompt_attention_mask=torch.ones(8)) + with pytest.raises(ValueError, match="does not match"): + RolloutRequest(prompt_ids=torch.zeros(2, 4, dtype=torch.long), prompt_attention_mask=torch.ones(2, 5)) + + +def test_rollout_batch_validates_shapes(): + with pytest.raises(ValueError, match="must be 2-D"): + RolloutBatch(input_ids=torch.zeros(8, dtype=torch.long), + attention_mask=torch.ones(8), + response_start_idx=torch.tensor([4])) + with pytest.raises(ValueError, match="does not match"): + RolloutBatch(input_ids=torch.zeros(2, 4, dtype=torch.long), + attention_mask=torch.ones(2, 5), + response_start_idx=torch.tensor([4, 4])) + with pytest.raises(ValueError, match="1-D of length"): + RolloutBatch(input_ids=torch.zeros(2, 4, dtype=torch.long), + attention_mask=torch.ones(2, 4), + response_start_idx=torch.tensor([4])) + + +def test_rollout_batch_accessors(): + batch = RolloutBatch( + input_ids=torch.zeros(3, 12, dtype=torch.long), + attention_mask=torch.ones(3, 12), + response_start_idx=torch.tensor([4, 5, 6]), + ) + assert batch.batch_size == 3 + assert batch.seq_len == 12 + + +def test_sampling_config_defaults(): + cfg = SamplingConfig(max_new_tokens=32) + assert cfg.temperature == 1.0 + assert cfg.top_p == 1.0 + assert cfg.top_k == -1 + assert cfg.n_samples_per_prompt == 1 + + +# --- interface conformance via FakeRollout --------------------------------- + + +class FakeRollout(RolloutEngine): + """Deterministic stub: appends ``[42] * max_new_tokens`` to each prompt.""" + + name = "fake" + + def __init__(self, response_token: int = 42): + self.response_token = response_token + self.sync_calls: list = [] + + def generate(self, request: RolloutRequest, sampling: SamplingConfig) -> RolloutBatch: + B, T_p = request.prompt_ids.shape + n = sampling.n_samples_per_prompt + T_r = sampling.max_new_tokens + + prompts_expanded = request.prompt_ids.repeat_interleave(n, dim=0) + attn_p_expanded = request.prompt_attention_mask.repeat_interleave(n, dim=0) + response = torch.full((B * n, T_r), self.response_token, dtype=request.prompt_ids.dtype) + response_attn = torch.ones((B * n, T_r), dtype=attn_p_expanded.dtype) + + input_ids = torch.cat([prompts_expanded, response], dim=1) + attention_mask = torch.cat([attn_p_expanded, response_attn], dim=1) + response_start_idx = torch.full((B * n, ), T_p, dtype=torch.long) + return RolloutBatch(input_ids=input_ids, attention_mask=attention_mask, response_start_idx=response_start_idx) + + def sync_weights(self, step: int) -> None: + self.sync_calls.append(step) + + +def test_fake_rollout_shape_basic(): + fake = FakeRollout() + req = RolloutRequest(prompt_ids=torch.tensor([[1, 2, 3], [4, 5, 6]]), + prompt_attention_mask=torch.ones(2, 3, dtype=torch.long)) + out = fake.generate(req, SamplingConfig(max_new_tokens=4)) + assert out.input_ids.shape == (2, 7) + assert out.attention_mask.shape == (2, 7) + # With left-padded (fully real here) prompts of width 3, response begins + # at column 3 for every sample. + assert out.response_start_idx.tolist() == [3, 3] + + +def test_fake_rollout_with_n_samples(): + fake = FakeRollout() + req = RolloutRequest(prompt_ids=torch.tensor([[1, 2], [3, 4]]), + prompt_attention_mask=torch.ones(2, 2, dtype=torch.long)) + out = fake.generate(req, SamplingConfig(max_new_tokens=3, n_samples_per_prompt=4)) + assert out.input_ids.shape == (8, 5) + assert out.response_start_idx.tolist() == [2] * 8 + + +def test_fake_rollout_left_padded_prompts(): + fake = FakeRollout() + # left-padded prompts: prompt B has only the last 2 positions real, but + # response_start_idx still equals the prompt column width T_p. + prompt_ids = torch.tensor([[1, 2, 3, 4], [0, 0, 5, 6]]) + attn = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.long) + req = RolloutRequest(prompt_ids=prompt_ids, prompt_attention_mask=attn) + out = fake.generate(req, SamplingConfig(max_new_tokens=2)) + assert out.response_start_idx.tolist() == [4, 4] + + +def test_sync_records_steps(): + fake = FakeRollout() + fake.sync_weights(0) + fake.sync_weights(5) + assert fake.sync_calls == [0, 5] + + +def test_engine_factory_unknown_raises(): + from deepspeed.runtime.rollout.base import RolloutConfig + + with pytest.raises(ValueError, match="Unknown rollout engine"): + build_rollout(RolloutConfig(engine="totally_made_up")) + + +def test_engine_factory_hybrid_requires_student_engine(): + from deepspeed.runtime.rollout.base import RolloutConfig + + with pytest.raises(ValueError, match="needs both"): + build_rollout(RolloutConfig(engine="hybrid_engine"))