From 3aa6f82598dcbd9d13992f6e75e10d9388412c58 Mon Sep 17 00:00:00 2001 From: whycoming Date: Thu, 25 Jun 2026 20:13:24 +0800 Subject: [PATCH] fix(zero): reject Muon + reduce_scatter in ZeRO-1/2 Muon's Newton-Schulz orthogonalization requires the full all-reduced gradient matrix on each rank. With reduce_scatter (the default), ZeRO-1/2 delivers each rank only its own partition slice, so a parameter whose flattened gradient crosses a partition boundary is orthogonalized on a partially-reduced, rank-divergent gradient and silently receives an incorrect update (#7807). Raise a clear error at initialization when Muon is combined with reduce_scatter, consistent with the existing ZeRO-3 guard (#7919), and add a regression test. Users should set "reduce_scatter": false to run Muon with ZeRO-1/2, as the Muon tests already do. Closes #7807 Signed-off-by: whycoming --- deepspeed/runtime/zero/stage_1_and_2.py | 7 +++++ tests/unit/ops/muon/test_muon.py | 34 +++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 89a45aa0fa41..a8624075ef05 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -35,6 +35,7 @@ from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.accelerator import get_accelerator from deepspeed.runtime.zero.muon.original_muon import muon_update +from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) @@ -218,6 +219,12 @@ def __init__(self, self.reduce_scatter = reduce_scatter + # Muon's Newton-Schulz orthogonalization needs the full all-reduced gradient on each + # rank; reduce_scatter delivers only this rank's partition slice and silently corrupts + # 
cross-partition parameters (#7807). ZeRO-3 already guards this (see stage3.py). + if isinstance(self.optimizer, MuonWithAuxAdam) and self.reduce_scatter: + raise ValueError('Muon and reduce scatter cannot be used together; set "reduce_scatter": false') + self.overlap_comm = overlap_comm self.deepspeed_adam_offload = self.cpu_offload diff --git a/tests/unit/ops/muon/test_muon.py b/tests/unit/ops/muon/test_muon.py index 84b06dd96265..5be9d505239d 100644 --- a/tests/unit/ops/muon/test_muon.py +++ b/tests/unit/ops/muon/test_muon.py @@ -177,3 +177,37 @@ def test_ns_method_stage3(self, ns_method): loss = engine(x, y) engine.backward(loss) engine.step() + + +class TestMuonRejectsReduceScatter(DistributedTest): + """Muon needs the full all-reduced gradient matrix on each rank for its Newton-Schulz + orthogonalization. reduce_scatter only delivers each rank its own partition slice, which + silently corrupts cross-partition parameters in ZeRO-1/2 (#7807). Initialization must fail + loudly, consistent with the ZeRO-3 guard in stage3.py (added in #7919).""" + + world_size = 1 + + @pytest.mark.parametrize('zero_stage', [1, 2]) + def test_muon_reduce_scatter_raises(self, zero_stage): + config_dict = { + "train_batch_size": 4, + "optimizer": { + "type": "muon", + "params": { + "lr": 0.01 + } + }, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage, + "reduce_scatter": True, + }, + } + model = SimpleModel(hidden_dim=32, nlayers=2) + with pytest.raises(ValueError, match="Muon and reduce scatter cannot be used together"): + deepspeed.initialize(config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False)