From cc0fb6ab5ba1f12fd8af09f69ed77b95b4c6c33d Mon Sep 17 00:00:00 2001 From: Achyuthan Sivasankar Date: Tue, 30 Jun 2026 15:02:11 +0400 Subject: [PATCH] ZeRO-3: stream partitioning of oversized parameters in zero.Init Under zero.Init, each parameter is broadcast and partitioned by first materializing the full tensor on a single device. A single very large fused parameter (e.g. a 128-expert MoE weight) can exceed device memory during a from_pretrained load even when the sharded model fits; offload_param does not help because it only controls where the resulting partition is stored. Add an opt-in stage3_partition_stream_chunk_size: a parameter larger than the threshold that is not already on the accelerator is partitioned by streaming its flattened data through fixed-size chunks (stage chunk -> broadcast from owner rank -> copy this rank's slice), bounding the partition-time device peak to roughly the chunk size. Defaults to 0 (disabled), leaving the existing path unchanged. Signed-off-by: Achyuthan Sivasankar --- deepspeed/runtime/zero/config.py | 9 ++ .../runtime/zero/partition_parameters.py | 91 +++++++++++- docs/_pages/config-json.md | 7 + .../runtime/zero/test_partition_streaming.py | 133 ++++++++++++++++++ 4 files changed, 238 insertions(+), 2 deletions(-) create mode 100644 tests/unit/runtime/zero/test_partition_streaming.py diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 79fbcb97a188..c5f0ca2fc6af 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -277,6 +277,15 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): Recommended for scenarios with high memory pressure. 
""" + partition_stream_chunk_size: int = Field(pp_int(0), ge=0, alias="stage3_partition_stream_chunk_size") + """ + Partition parameters with more than this many elements by streaming their flattened data through + fixed-size chunks of this size, instead of materializing the full parameter on a single device during + ``zero.Init``. This bounds the peak device memory used while partitioning to roughly the chunk size, + which is required when a single (e.g. fused MoE-expert) parameter is too large to fit on one device. + ``0`` (default) disables streaming and uses the standard broadcast-then-partition path. Used by ZeRO3. + """ + stage3_gather_fp16_weights_on_model_save: bool = Field(False, json_schema_extra={ "deprecated": True, diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 9cdff411237c..9aadde0ed566 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -316,6 +316,27 @@ def free_param(param: Parameter) -> None: param.ds_status = ZeroParamStatus.NOT_AVAILABLE +def _partition_chunk_overlap(chunk_offset, chunk_numel, partition_start, partition_numel): + """Locate where a streamed chunk of a flattened parameter lands in a partition. + + When a parameter is partitioned by streaming its flattened data in fixed-size + chunks, each rank owns the flat slice ``[partition_start, partition_start + + partition_numel)``. This returns the part of the chunk ``[chunk_offset, + chunk_offset + chunk_numel)`` that falls inside that partition as + ``(dst_offset, src_offset, numel)`` -- ``dst_offset`` indexes the partition and + ``src_offset`` indexes the chunk -- or ``None`` when the chunk does not overlap + the partition. 
+ """ + overlap_start = max(chunk_offset, partition_start) + overlap_end = min(chunk_offset + chunk_numel, partition_start + partition_numel) + if overlap_start >= overlap_end: + return None + dst_offset = overlap_start - partition_start + src_offset = overlap_start - chunk_offset + numel = overlap_end - overlap_start + return dst_offset, src_offset, numel + + reuse_buffers = False temp_contiguous_tensor = None empty_buffers = {} @@ -1096,6 +1117,12 @@ def __init__(self, else: self.param_swapper = None + # Threshold/chunk size for streaming the partitioning of very large parameters. + # Read before the module conversion below since partitioning happens there. + self.partition_stream_chunk_size = get_config_default(DeepSpeedZeroConfig, "partition_stream_chunk_size") + if _ds_config is not None: + self.partition_stream_chunk_size = _ds_config.zero_config.partition_stream_chunk_size + # If we are provided an already-allocated module to prepare. if module is not None: assert isinstance(module, torch.nn.Module) @@ -1119,6 +1146,9 @@ def _update_persist_config(self, ds_config): def _zero_init_param(self, param): self._convert_to_deepspeed_param(param) + if self._should_stream_partition(param): + self._partition_param_streaming(param) + return partition_group = self.get_partition_dp_group(param) if dist.get_world_group() == partition_group: dist.broadcast(param.data, 0, partition_group) @@ -1126,12 +1156,69 @@ def _zero_init_param(self, param): dist.broadcast(param.data, dist.get_global_rank(partition_group, 0), partition_group) param.partition() + def _should_stream_partition(self, param): + # Stream the broadcast and partitioning of a parameter that is too large to + # safely materialize in full on a single device. The nvme / quantized / + # ZeRO++ secondary-partition paths stage parameters differently and are left + # on the standard path. 
+ if self.partition_stream_chunk_size <= 0 or self.num_partitions <= 1: + return False + if param.numel() <= self.partition_stream_chunk_size: + return False + if self.remote_device == OffloadDeviceEnum.nvme: + return False + if self.quantized_initialization or self.quantized_nontrainable_weights: + return False + if self.zero_param_process_group is not None: + return False + return True + + def _partition_param_streaming(self, param): + # Partition a very large parameter without ever materializing the full tensor + # on a single device. The full parameter stays on its current (host) device; + # each chunk is staged on the accelerator, broadcast from the owner rank for + # consistency, and only the slice that belongs to this rank's partition is + # copied into ds_tensor. + tensor_size = self._aligned_size(param) + partition_size = tensor_size // self._partition_world_size(param) + + partition_device = self.local_device if param.ds_persist else self.remote_device + partitioned_tensor = torch.empty(partition_size, dtype=param.dtype, device=partition_device) + if partition_device == OffloadDeviceEnum.cpu and self.pin_memory: + partitioned_tensor = get_accelerator().pin_memory(partitioned_tensor) + partitioned_tensor.requires_grad = False + param.ds_tensor = partitioned_tensor + param.ds_tensor.ds_numel = partition_size + param.ds_tensor.status = PartitionedParamStatus.AVAILABLE + param.ds_tensor.final_location = None + param.ds_numel_aligned = tensor_size + + partition_start = partition_size * self._partition_rank(param) + src_rank = dist.get_global_rank(self.get_partition_dp_group(param), 0) + compute_device = get_accelerator().current_device_name() + full_param = param.data.contiguous().view(-1) + + chunk_offset = 0 + while chunk_offset < param.ds_numel: + chunk_numel = min(self.partition_stream_chunk_size, param.ds_numel - chunk_offset) + chunk = full_param.narrow(0, chunk_offset, chunk_numel).to(compute_device) + dist.broadcast(chunk, src_rank, 
self.get_partition_dp_group(param)) + overlap = _partition_chunk_overlap(chunk_offset, chunk_numel, partition_start, partition_size) + if overlap is not None: + dst_offset, src_offset, numel = overlap + with torch.no_grad(): + param.ds_tensor.narrow(0, dst_offset, numel).copy_(chunk.narrow(0, src_offset, numel)) + chunk_offset += chunk_numel + + free_param(param) + def _convert_to_zero_parameters(self, param_list): for param in param_list: if is_zero_param(param): continue - param.data = param.data.to(self.local_device) + if not self._should_stream_partition(param): + param.data = param.data.to(self.local_device) self._zero_init_param(param) def _validate_remote_device(self, remote_device, ds_config): @@ -1159,7 +1246,7 @@ def _post_init_method(self, module): InsertPostInitMethodToModuleSubClasses.num_module_parameters += 1 InsertPostInitMethodToModuleSubClasses.num_module_elements += param.numel() if not is_zero_param(param): - if not get_accelerator().on_accelerator(param): + if not get_accelerator().on_accelerator(param) and not self._should_stream_partition(param): param.data = param.data.to(self.local_device) if name == 'weight' and self.quantized_initialization and type(module) in WEIGHT_QUANTIZATION_LAYERS: diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 6d395ae33370..66309142b0ff 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -574,6 +574,13 @@ Enabling and configuring ZeRO memory optimizations | Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). 
| `1e5` | +***stage3_partition_stream_chunk_size***: [integer] + +| Description | Default | +| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | +| Partition parameters with more than this many elements by streaming their flattened data through fixed-size chunks of this size, bounding peak device memory during `zero.Init` instead of materializing the full parameter on one device. Needed when a single (e.g. fused MoE-expert) parameter is too large to fit on one device. `0` disables streaming. | `0` | + + ***stage3_gather_16bit_weights_on_model_save***: [boolean] | Description | Default | diff --git a/tests/unit/runtime/zero/test_partition_streaming.py b/tests/unit/runtime/zero/test_partition_streaming.py new file mode 100644 index 000000000000..29fda798629b --- /dev/null +++ b/tests/unit/runtime/zero/test_partition_streaming.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +import deepspeed +from deepspeed.runtime.zero import partition_parameters as pp +from deepspeed.runtime.zero.partition_parameters import _partition_chunk_overlap + +from unit.common import DistributedTest + + +def _aligned_numel(numel, num_partitions): + remainder = numel % num_partitions + return numel + ((num_partitions - remainder) if remainder else 0) + + +def _stream_one_partition(full_flat, rank, num_partitions, chunk_numel): + """Rebuild a single rank's partition by streaming the flattened parameter in + fixed-size chunks, mirroring ``_partition_param_streaming``. 
Padding elements + (beyond the real numel) are left as the ``-1`` sentinel so the test can assert + they are never written.""" + ds_numel = full_flat.numel() + partition_numel = _aligned_numel(ds_numel, num_partitions) // num_partitions + partition_start = partition_numel * rank + out = torch.full((partition_numel, ), -1.0) + offset = 0 + while offset < ds_numel: + cur = min(chunk_numel, ds_numel - offset) + overlap = _partition_chunk_overlap(offset, cur, partition_start, partition_numel) + if overlap is not None: + dst_offset, src_offset, numel = overlap + chunk = full_flat.narrow(0, offset, cur) + out.narrow(0, dst_offset, numel).copy_(chunk.narrow(0, src_offset, numel)) + offset += cur + return out + + +@pytest.mark.parametrize("numel,num_partitions,chunk_numel", [ + (64, 4, 8), + (64, 4, 7), + (60, 8, 5), + (10, 4, 3), + (1, 2, 4), + (100, 3, 1), + (128, 1, 16), +]) +def test_streamed_partitions_match_direct_slicing(numel, num_partitions, chunk_numel): + full = torch.arange(numel, dtype=torch.float32) + partition_numel = _aligned_numel(numel, num_partitions) // num_partitions + aligned = partition_numel * num_partitions + + rebuilt = torch.full((aligned, ), -1.0) + for rank in range(num_partitions): + partition = _stream_one_partition(full, rank, num_partitions, chunk_numel) + rebuilt.narrow(0, rank * partition_numel, partition_numel).copy_(partition) + + # The real (non-padded) region must match the original parameter exactly. + assert torch.equal(rebuilt.narrow(0, 0, numel), full) + # Padding elements must never be written by the streaming copy. 
class TestStreamingPartitionMatchesStandard(DistributedTest):
    """Streaming partitioning must produce the same ``ds_tensor`` as the standard path.

    Both tests build the same 64x64 linear layer twice -- once with streaming
    disabled (chunk size 0) and once with 512-element chunks -- and compare the
    resulting partitions. Each test also checks that the streaming method was
    really invoked, so a silent fallback to the standard path cannot make the
    comparison pass vacuously.
    """
    world_size = 2

    @staticmethod
    def _zero3_config(world_size, chunk_size):
        # Minimal ZeRO-3 config; chunk_size == 0 leaves streaming disabled.
        return {
            "train_batch_size": world_size,
            "zero_optimization": {
                "stage": 3,
                "stage3_partition_stream_chunk_size": chunk_size,
            },
        }

    @staticmethod
    def _build_and_count_streaming(build, chunk_size):
        """Run ``build(chunk_size)`` and return ``(module, streaming_call_count)``.

        ``Init._partition_param_streaming`` is temporarily wrapped with a counting
        shim and always restored, even when ``build`` raises.
        """
        calls = {"count": 0}
        original = pp.Init._partition_param_streaming

        def counting_stream(init_self, param, *args, **kwargs):
            calls["count"] += 1
            return original(init_self, param, *args, **kwargs)

        pp.Init._partition_param_streaming = counting_stream
        try:
            module = build(chunk_size)
        finally:
            pp.Init._partition_param_streaming = original
        return module, calls["count"]

    def test_streaming_matches_standard(self):

        def build(chunk_size):
            config = self._zero3_config(self.world_size, chunk_size)
            torch.manual_seed(1234)
            with deepspeed.zero.Init(config_dict_or_path=config):
                linear = torch.nn.Linear(64, 64, bias=False)
            return linear

        # Reference: the standard broadcast-then-partition path (streaming disabled).
        reference = build(0)
        reference_partition = reference.weight.ds_tensor.detach().clone()

        # Streaming the same parameter (4096 elements) in 512-element chunks must
        # produce an identical partition while actually exercising the new path.
        streamed, stream_calls = self._build_and_count_streaming(build, 512)

        assert stream_calls >= 1, "streaming partition path was not exercised"
        assert torch.equal(streamed.weight.ds_tensor, reference_partition)

    def test_streaming_via_module_path(self):
        # zero.Init(module=...) decides whether to stream on plain torch parameters,
        # before they are converted to ZeRO params. The decision must therefore not
        # require ZeRO metadata (e.g. a per-parameter process group) that is only
        # attached during conversion.

        def build(chunk_size):
            config = self._zero3_config(self.world_size, chunk_size)
            torch.manual_seed(99)
            linear = torch.nn.Linear(64, 64, bias=True)  # built on the host, then converted
            deepspeed.zero.Init(module=linear, config_dict_or_path=config)
            return linear

        reference = build(0)
        # weight (4096 elements) streams; bias (64) stays on the standard path but is
        # still passed through the pre-conversion stream check.
        streamed, stream_calls = self._build_and_count_streaming(build, 512)

        assert stream_calls >= 1, "streaming partition path was not exercised"
        assert torch.equal(streamed.weight.ds_tensor, reference.weight.ds_tensor)
        assert torch.equal(streamed.bias.ds_tensor, reference.bias.ds_tensor)