ZeRO-3: stream partitioning of oversized parameters in zero.Init#8103
ZeRO-3: stream partitioning of oversized parameters in zero.Init#8103Achyuthan-S wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: cc35972244
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if not self._should_stream_partition(param): | ||
| param.data = param.data.to(self.local_device) |
There was a problem hiding this comment.
Avoid checking streaming before ZeRO metadata exists
When stage3_partition_stream_chunk_size is set and zero.Init(module=prebuilt_model, ...) is used, this new pre-check runs on ordinary torch.nn.Parameters before _zero_init_param() calls _convert_to_deepspeed_param(). _should_stream_partition() immediately asks for _partition_world_size(param), which dereferences param.ds_process_group; that attribute is only installed later in _convert_to_deepspeed_param(), so the module-conversion path raises AttributeError even for parameters smaller than the chunk size. Move the stream decision until after conversion, or make the pre-check use the default process group without requiring ZeRO metadata.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Fixed. _should_stream_partition now gates on the global num_partitions instead of _partition_world_size(param), so the zero.Init(module=...) path no longer dereferences param.ds_process_group before _convert_to_deepspeed_param attaches it. The per-parameter group is still used in the actual partitioning (_partition_param_streaming), which runs after conversion. Added a DistributedTest that exercises the module= path with streaming enabled to guard this.
Under zero.Init, each parameter is broadcast and partitioned by first materializing the full tensor on a single device. A single very large fused parameter (e.g. a 128-expert MoE weight) can exceed device memory during a from_pretrained load even when the sharded model fits; offload_param does not help because it only controls where the resulting partition is stored. Add an opt-in stage3_partition_stream_chunk_size: a parameter larger than the threshold that is not already on the accelerator is partitioned by streaming its flattened data through fixed-size chunks (stage chunk -> broadcast from owner rank -> copy this rank's slice), bounding the partition-time device peak to roughly the chunk size. Defaults to 0 (disabled), leaving the existing path unchanged. Signed-off-by: Achyuthan Sivasankar <achyuthan.sivasankar@gmail.com>
cc35972 to
cc0fb6a
Compare
|
Hey @tohtana , I have been working on this issue and opened a PR with the solution. |
Problem
Under zero.Init (ZeRO-3), every parameter is moved to the accelerator, broadcast in full, and then sliced into per-rank partitions. A single very large fused parameter — e.g. a 128-expert MoE weight — must be fully materialized on one device during this step, which can OOM that device during a from_pretrained load even when the sharded model fits. offload_param: {device: cpu} does not help: it only controls where the resulting partition is stored, not where the full tensor is staged.
Closes #8085.
Change
Adds an opt-in ZeRO-3 config stage3_partition_stream_chunk_size (default 0 = disabled). When set, a parameter with more elements than the threshold that is not already on the accelerator (the host-staged from_pretrained / low_cpu_mem_usage path) is partitioned by streaming its flattened data through fixed-size chunks: stage a chunk on the accelerator → broadcast from the owner rank → copy only this rank's slice into ds_tensor. The full tensor is never materialized on a single device, bounding the partition-time peak to roughly the chunk size.
With the default (0) the standard broadcast-then-partition path runs unchanged. Streaming is skipped for the nvme / quantized / ZeRO++ secondary-partition paths, which stage parameters differently.
Validation
Correctness — new unit test covers the chunk/partition overlap math (incl. padding, single-rank). End-to-end, the streamed partition reconstructs bit-for-bit identically to the standard path across world sizes 1–3, with padding, all_gather round-trip, and offload_param: cpu.
NCCL + peak memory (2× NVIDIA L40S):
[A] NCCL correctness (gathered streamed == standard): True
[B] peak GPU memory during zero.Init (world=2, dim=22528, fp32)
full param : 2.03 GB partition/rank: 1.02 GB chunk: 40 MB
streaming OFF peak : 3.05 GB
streaming ON peak : 1.10 GB
peak reduction : 1.95 GB (64% lower)
Scope
Applies to parameters that reach partitioning off-GPU (the from_pretrained / low_cpu_mem_usage path this issue targets). Parameters constructed directly on the accelerator inside zero.Init are unaffected — the spike there happens at construction time, which can be addressed as a follow-up.
cc @tohtana @tjruwase @loadams