diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 7b7c50454874..f3b07eb81eae 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -880,6 +880,35 @@ def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandl return NoGatherCoalescedHandle(params) +def _unsharded_single_rank_warning(dp_world_size, data_parallel_group, env=None): + """Detect the silent single-rank fallback described in #8084. + + When a multi-process launcher (``deepspeed``, ``torchrun``, accelerate, ...) sets ``WORLD_SIZE > 1`` but the + distributed process group was not initialized before ``zero.Init`` ran, the group resolved here collapses to a + single rank. ``zero.Init`` then creates every parameter whole on every rank instead of partitioning it, so each + rank allocates the full (unsharded) model and typically OOMs. The failure is otherwise silent and looks exactly + like a "model too big" OOM. Return an actionable warning message in that case, else ``None``. + + Only the default (world-group) path is checked: an explicitly supplied ``data_parallel_group`` of size 1 is + treated as intentional. + """ + if dp_world_size != 1 or data_parallel_group is not None: + return None + env = os.environ if env is None else env + try: + launcher_world_size = int(env.get("WORLD_SIZE", "0") or "0") + except (TypeError, ValueError): + return None + if launcher_world_size <= 1: + return None + return ( + "zero.Init resolved a process group of world_size=1, but the launcher environment reports " + f"WORLD_SIZE={launcher_world_size}. The distributed process group was likely not initialized before " + "zero.Init ran (for example, `from_pretrained` executed before `deepspeed.init_distributed()`). Parameters " + "will NOT be partitioned: every rank allocates the full model and will likely OOM. 
Call " + "`deepspeed.init_distributed()` before constructing the model under zero.Init.") + + # Replaces all parameters in module with Scattered Parameters class Init(InsertPostInitMethodToModuleSubClasses): param_id = 0 @@ -1035,6 +1064,10 @@ def __init__(self, self.rank = dist.get_rank(group=self.ds_process_group) self.dp_world_size = dist.get_world_size(group=self.ds_process_group) + _unsharded_warning = _unsharded_single_rank_warning(self.dp_world_size, data_parallel_group) + if _unsharded_warning is not None: + logger.warning(_unsharded_warning) + self.zero_param_process_group = zero_param_parallel_group if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1 and self.zero_param_process_group is None: groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size) diff --git a/tests/unit/runtime/zero/test_zero_init_unsharded_warning.py b/tests/unit/runtime/zero/test_zero_init_unsharded_warning.py new file mode 100644 index 000000000000..11fc045b4345 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_init_unsharded_warning.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Regression coverage for #8084: zero.Init silently falls back to a single-rank (unsharded) group when the +# distributed process group is not initialized before it runs (e.g. `from_pretrained` before +# `deepspeed.init_distributed()`), so every rank allocates the full model and OOMs. The detection helper must warn +# only when the launcher reports a multi-process world but the resolved group collapsed to one rank. 
import pytest

from deepspeed.runtime.zero.partition_parameters import _unsharded_single_rank_warning

# Regression coverage for #8084: the helper must produce a message only when the launcher environment reports a
# multi-process world while the default process group resolved by zero.Init has a single rank.


def _check(env, *, dp_world_size=1, data_parallel_group=None):
    """Call the helper under test with keyword arguments, defaulting to the single-rank default-group case."""
    return _unsharded_single_rank_warning(dp_world_size=dp_world_size,
                                          data_parallel_group=data_parallel_group,
                                          env=env)


def test_warns_when_launcher_multiprocess_but_group_is_single_rank():
    warning = _check({"WORLD_SIZE": "8"})
    assert warning is not None
    # The message must name the launcher's world size and point at the remedy.
    for fragment in ("WORLD_SIZE=8", "init_distributed"):
        assert fragment in warning


def test_no_warning_for_genuine_single_process():
    # Both an explicit WORLD_SIZE=1 and an environment without WORLD_SIZE are genuine single-process runs.
    for single_process_env in ({"WORLD_SIZE": "1"}, {}):
        assert _check(single_process_env) is None


def test_no_warning_when_group_actually_shards():
    assert _check({"WORLD_SIZE": "8"}, dp_world_size=8) is None


def test_no_warning_when_explicit_dp_group_supplied():
    # An explicitly provided size-1 data_parallel_group is treated as intentional.
    explicit_group = object()
    assert _check({"WORLD_SIZE": "8"}, data_parallel_group=explicit_group) is None


@pytest.mark.parametrize("bad", ["", "not-an-int", None])
def test_malformed_world_size_does_not_raise(bad):
    assert _check({"WORLD_SIZE": bad}) is None