Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.muon.original_muon import muon_update
from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
Expand Down Expand Up @@ -218,6 +219,12 @@ def __init__(self,

self.reduce_scatter = reduce_scatter

# Muon's Newton-Schulz orthogonalization needs the full all-reduced gradient on each
# rank; reduce_scatter delivers only this rank's partition slice and silently corrupts
# cross-partition parameters (#7807). ZeRO-3 already guards this (see stage3.py).
if isinstance(self.optimizer, MuonWithAuxAdam) and self.reduce_scatter:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Only reject active Muon parameter groups

This check treats every MuonWithAuxAdam instance as unsafe, but the ZeRO-1/2 whole-matrix Muon path is only taken for params/groups with use_muon=True in get_flat_partition(). When users select the muon optimizer for a model whose parameters are all excluded from Muon (for example embeddings/lm_head/1-D params) or pass a MuonWithAuxAdam with all groups marked use_muon=False, training falls back to the auxiliary Adam path, which is elementwise and compatible with reduce_scatter; because reduce_scatter defaults to true, those valid runs now fail during initialization. Gate the error on an active use_muon group instead of the optimizer class alone.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_flat_partition does gate the Newton-Schulz path on use_muon, so strictly the hazard only exists when an active Muon group is present. I kept the condition as isinstance(MuonWithAuxAdam) here to mirror the merged ZeRO-3 guard (stage3.py, #7919), which uses the same check — so the two stages stay consistent, and it matches the conservative style of the existing reduce_scatter guards (e.g. the MoE assertion in stage_1_and_2.py).

The only configuration this over-rejects is a MuonWithAuxAdam whose groups are all use_muon=False — i.e. selecting Muon for a model with no Muon-eligible 2-D params — which is degenerate, and the remedy (reduce_scatter: false) is identical to every other Muon run, so the practical cost is nil.

If you'd prefer the precise gating (reject only when an active use_muon group exists), I'm happy to apply it to both ZeRO-1/2 and ZeRO-3 so they remain consistent — just let me know.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@whycoming I agree with you that there is no need for precise gating which would only complicate the mechanism. Besides. keeping zero1/2/3 with same behavior could avoid surprise when switch from zero 1/2 to zero3.

raise ValueError("Muon and reduce scatter cannot be used together")

self.overlap_comm = overlap_comm

self.deepspeed_adam_offload = self.cpu_offload
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/ops/muon/test_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,37 @@ def test_ns_method_stage3(self, ns_method):
loss = engine(x, y)
engine.backward(loss)
engine.step()


class TestMuonRejectsReduceScatter(DistributedTest):
"""Muon needs the full all-reduced gradient matrix on each rank for its Newton-Schulz
orthogonalization. reduce_scatter only delivers each rank its own partition slice, which
silently corrupts cross-partition parameters in ZeRO-1/2 (#7807). Initialization must fail
loudly, consistent with the ZeRO-3 guard in stage3.py (added in #7919)."""

world_size = 1

@pytest.mark.parametrize('zero_stage', [1, 2])
def test_muon_reduce_scatter_raises(self, zero_stage):
config_dict = {
"train_batch_size": 4,
"optimizer": {
"type": "muon",
"params": {
"lr": 0.01
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage,
"reduce_scatter": True,
},
}
model = SimpleModel(hidden_dim=32, nlayers=2)
with pytest.raises(ValueError, match="Muon and reduce scatter cannot be used together"):
deepspeed.initialize(config=config_dict,
model=model,
model_parameters=model.parameters(),
dist_init_required=False)
Loading