Reject Muon optimizer with reduce_scatter in ZeRO-1/2#8090
Conversation
Muon's Newton-Schulz orthogonalization requires the full all-reduced gradient matrix on each rank. With reduce_scatter (the default), ZeRO-1/2 delivers each rank only its own partition slice, so a parameter whose flattened gradient crosses a partition boundary is orthogonalized on a partially-reduced, rank-divergent gradient and silently receives an incorrect update (deepspeedai#7807). Raise a clear error at initialization when Muon is combined with reduce_scatter, consistent with the existing ZeRO-3 guard (deepspeedai#7919), and add a regression test. Users should set "reduce_scatter": false to run Muon with ZeRO-1/2, as the Muon tests already do. Closes deepspeedai#7807 Signed-off-by: whycoming <alwaysxd666@gmail.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3aa6f82598
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # Muon's Newton-Schulz orthogonalization needs the full all-reduced gradient on each | ||
| # rank; reduce_scatter delivers only this rank's partition slice and silently corrupts | ||
| # cross-partition parameters (#7807). ZeRO-3 already guards this (see stage3.py). | ||
| if isinstance(self.optimizer, MuonWithAuxAdam) and self.reduce_scatter: |
There was a problem hiding this comment.
Only reject active Muon parameter groups
This check treats every MuonWithAuxAdam instance as unsafe, but the ZeRO-1/2 whole-matrix Muon path is only taken for params/groups with use_muon=True in get_flat_partition(). When users select the muon optimizer for a model whose parameters are all excluded from Muon (for example embeddings/lm_head/1-D params) or pass a MuonWithAuxAdam with all groups marked use_muon=False, training falls back to the auxiliary Adam path, which is elementwise and compatible with reduce_scatter; because reduce_scatter defaults to true, those valid runs now fail during initialization. Gate the error on an active use_muon group instead of the optimizer class alone.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
get_flat_partition does gate the Newton-Schulz path on use_muon, so strictly the hazard only exists when an active Muon group is present. I kept the condition as isinstance(MuonWithAuxAdam) here to mirror the merged ZeRO-3 guard (stage3.py, #7919), which uses the same check — so the two stages stay consistent, and it matches the conservative style of the existing reduce_scatter guards (e.g. the MoE assertion in stage_1_and_2.py).
The only configuration this over-rejects is a MuonWithAuxAdam whose groups are all use_muon=False — i.e. selecting Muon for a model with no Muon-eligible 2-D params — which is degenerate, and the remedy (reduce_scatter: false) is identical to every other Muon run, so the practical cost is nil.
If you'd prefer the precise gating (reject only when an active use_muon group exists), I'm happy to apply it to both ZeRO-1/2 and ZeRO-3 so they remain consistent — just let me know.
There was a problem hiding this comment.
@whycoming I agree with you that there is no need for precise gating which would only complicate the mechanism. Besides. keeping zero1/2/3 with same behavior could avoid surprise when switch from zero 1/2 to zero3.
|
@whycoming Thanks for looking into this bug and provide the fix! I have put your fix and test case into the merge queue. Thanks! |
) ## Summary Adds a numerical-correctness regression test for the Muon optimizer under ZeRO-1/2. The existing Muon tests only assert that parameters changed, which cannot detect a wrong-but-nonzero update — exactly the failure mode of deepspeedai#7807, where `reduce_scatter` fed Newton-Schulz orthogonalization a partition slice instead of the full DP-averaged gradient. This complements the guard in deepspeedai#8090 by verifying the supported `reduce_scatter: false` path is actually numerically correct. ## What the test does `TestMuonZero12NumericalCorrectness` (in `tests/unit/ops/muon/test_muon.py`), `world_size=2`, parametrized over ZeRO stage `[1, 2]` and `ns_method ['gram', 'standard']`: 1. Builds a model sized so a 2-D weight's flattened gradient straddles the rank-0/rank-1 partition boundary, and asserts this from the **actual** flattened ZeRO partition (`optimizer.bit16_groups` / `bit16_groups_flat`, accounting for alignment padding) — the exact case deepspeedai#7807 corrupts. 2. Runs one step on the supported `reduce_scatter: false` path with `gradient_clipping=0` and `loss_scale=1`, so the applied master-weight update is exactly `-lr * muon_update(grad)`. 3. Compares that applied update against an independent reference that applies the real `muon_update` to the full DP-averaged gradient (using the library function, so Newton-Schulz rounding cancels), via relative Frobenius error. A correct update differs from the reference by only a few percent; the partition-then-orthogonalize bug diverges by O(1) on the cross-partition weight, so the assertion uses a 0.40 threshold that cleanly separates the two. ## Verification (2x RTX 4090, torch 2.9.1+cu128, ZeRO stage 1 and 2) Relative error of the applied Muon update vs the full-gradient reference: | ns_method | correct path (max over weights) | buggy path (cross-partition weight) | |---|---|---| | gram (default) | 0.068 | 0.603 | | standard | 0.216 | 0.673 | The cross-partition weight is the only one affected; wholly-owned weights are identical on both paths. With `reduce_scatter: false` the test passes for both stages and both ns_methods; injecting the bug (`reduce_scatter: true`, pre-guard) makes the cross-partition assertion fail by a wide margin — i.e. this test would have caught deepspeedai#7807. ## Notes Follow-up to deepspeedai#8090 (which adds the guard and closes deepspeedai#7807). Kept in the existing `test_muon.py`. Requires >=2 GPUs (fp16). Refs deepspeedai#7807 cc @PKUWZP @pengdurice (ZeRO-3 Muon guard, deepspeedai#7919) @tohtana Signed-off-by: whycoming <alwaysxd666@gmail.com>
## Summary ZeRO-1/2 silently produces incorrect, rank-divergent parameter updates when the Muon optimizer is used together with `reduce_scatter` (the default). This adds an explicit error at initialization, mirroring the existing ZeRO-3 guard, and includes a regression test. Closes deepspeedai#7807. ## Root cause Muon's Newton-Schulz orthogonalization is a whole-matrix operation: the rank that updates a parameter must hold that parameter's complete, fully-reduced gradient matrix, then take its partition slice of the orthogonalized result. - `get_flat_partition()` (`deepspeed/runtime/zero/stage_1_and_2.py`) applies `muon_update()` to each parameter's gradient reshaped to its full 2-D shape, and only then narrows to this rank's partition. - With `reduce_scatter=True`, `average_tensor()` reduce-scatters the gradients: each rank receives the averaged values only for its own partition slice. For the rest of a parameter whose flattened gradient crosses a partition boundary, the rank still holds its local, un-all-reduced gradient. - So for any cross-partition parameter, no rank holds the full reduced matrix. `muon_update` orthogonalizes a partly-reduced, rank-divergent matrix, and each rank silently applies a different, incorrect update. Parameters that lie wholly inside one partition are unaffected — exactly matching the report. ZeRO-3 already guards this exact conflict in `deepspeed/runtime/zero/stage3.py` (added in deepspeedai#7919): ```python if self.use_muon and self.reduce_scatter: raise ValueError("Muon and reduce scatter cannot be used together") ``` ZeRO-1/2 had no equivalent. The existing Muon unit tests pin `"reduce_scatter": false` everywhere, which implicitly acknowledges the path is unsupported but never enforces it for users — and since `reduce_scatter` defaults to `true`, a default Muon + ZeRO-1/2 run is silently wrong. ## Fix Mirror the ZeRO-3 guard in ZeRO-1/2: raise the same `ValueError` at initialization when the optimizer is `MuonWithAuxAdam` and `reduce_scatter` is enabled. To run Muon under ZeRO-1/2, set `"reduce_scatter": false` (as the Muon tests already do). The change is the import plus the guard, with no other behavioral change. ## Verification (2x RTX 4090, torch 2.9.1+cu128, ZeRO stage 1 and 2) - **Before**: `deepspeed.initialize` with Muon + `reduce_scatter=true` succeeds silently. With `world_size=2` and a model sized so a 2-D weight straddles the gradient-partition boundary, that weight's post-step update diverges by ~0.67 in relative Frobenius norm from the correct full-gradient result, while wholly-owned weights are unaffected — confirming the silent cross-partition corruption. - **After**: the same configuration raises `ValueError: Muon and reduce scatter cannot be used together` for both ZeRO stage 1 and 2. The existing Muon tests (which use `reduce_scatter: false`) remain green. ## Notes This supersedes deepspeedai#7878 and deepspeedai#7808, which aimed at the same issue by trying to force a full all-reduce for Muon but ended up with a self-contradictory guard. Aligning ZeRO-1/2 with the merged ZeRO-3 behavior (deepspeedai#7919) keeps the two code paths consistent and turns silent numerical corruption into a clear, actionable error. A follow-up PR adds a numerical-correctness regression test for the supported `reduce_scatter: false` Muon path, since the current Muon tests only assert that parameters changed. Closes deepspeedai#7807 cc @PKUWZP @pengdurice (ZeRO-3 Muon guard, deepspeedai#7919) @tohtana Signed-off-by: whycoming <alwaysxd666@gmail.com> Co-authored-by: Ma, Guokai <guokai.ma@gmail.com> Signed-off-by: nathon-lee <leejianwoo@gmail.com>
## Summary ZeRO-1/2 silently produces incorrect, rank-divergent parameter updates when the Muon optimizer is used together with `reduce_scatter` (the default). This adds an explicit error at initialization, mirroring the existing ZeRO-3 guard, and includes a regression test. Closes deepspeedai#7807. ## Root cause Muon's Newton-Schulz orthogonalization is a whole-matrix operation: the rank that updates a parameter must hold that parameter's complete, fully-reduced gradient matrix, then take its partition slice of the orthogonalized result. - `get_flat_partition()` (`deepspeed/runtime/zero/stage_1_and_2.py`) applies `muon_update()` to each parameter's gradient reshaped to its full 2-D shape, and only then narrows to this rank's partition. - With `reduce_scatter=True`, `average_tensor()` reduce-scatters the gradients: each rank receives the averaged values only for its own partition slice. For the rest of a parameter whose flattened gradient crosses a partition boundary, the rank still holds its local, un-all-reduced gradient. - So for any cross-partition parameter, no rank holds the full reduced matrix. `muon_update` orthogonalizes a partly-reduced, rank-divergent matrix, and each rank silently applies a different, incorrect update. Parameters that lie wholly inside one partition are unaffected — exactly matching the report. ZeRO-3 already guards this exact conflict in `deepspeed/runtime/zero/stage3.py` (added in deepspeedai#7919): ```python if self.use_muon and self.reduce_scatter: raise ValueError("Muon and reduce scatter cannot be used together") ``` ZeRO-1/2 had no equivalent. The existing Muon unit tests pin `"reduce_scatter": false` everywhere, which implicitly acknowledges the path is unsupported but never enforces it for users — and since `reduce_scatter` defaults to `true`, a default Muon + ZeRO-1/2 run is silently wrong. ## Fix Mirror the ZeRO-3 guard in ZeRO-1/2: raise the same `ValueError` at initialization when the optimizer is `MuonWithAuxAdam` and `reduce_scatter` is enabled. To run Muon under ZeRO-1/2, set `"reduce_scatter": false` (as the Muon tests already do). The change is the import plus the guard, with no other behavioral change. ## Verification (2x RTX 4090, torch 2.9.1+cu128, ZeRO stage 1 and 2) - **Before**: `deepspeed.initialize` with Muon + `reduce_scatter=true` succeeds silently. With `world_size=2` and a model sized so a 2-D weight straddles the gradient-partition boundary, that weight's post-step update diverges by ~0.67 in relative Frobenius norm from the correct full-gradient result, while wholly-owned weights are unaffected — confirming the silent cross-partition corruption. - **After**: the same configuration raises `ValueError: Muon and reduce scatter cannot be used together` for both ZeRO stage 1 and 2. The existing Muon tests (which use `reduce_scatter: false`) remain green. ## Notes This supersedes deepspeedai#7878 and deepspeedai#7808, which aimed at the same issue by trying to force a full all-reduce for Muon but ended up with a self-contradictory guard. Aligning ZeRO-1/2 with the merged ZeRO-3 behavior (deepspeedai#7919) keeps the two code paths consistent and turns silent numerical corruption into a clear, actionable error. A follow-up PR adds a numerical-correctness regression test for the supported `reduce_scatter: false` Muon path, since the current Muon tests only assert that parameters changed. Closes deepspeedai#7807 cc @PKUWZP @pengdurice (ZeRO-3 Muon guard, deepspeedai#7919) @tohtana Signed-off-by: whycoming <alwaysxd666@gmail.com> Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
Summary
ZeRO-1/2 silently produces incorrect, rank-divergent parameter updates when the Muon optimizer is used together with
reduce_scatter(the default). This adds an explicit error at initialization, mirroring the existing ZeRO-3 guard, and includes a regression test. Closes #7807.Root cause
Muon's Newton-Schulz orthogonalization is a whole-matrix operation: the rank that updates a parameter must hold that parameter's complete, fully-reduced gradient matrix, then take its partition slice of the orthogonalized result.
get_flat_partition()(deepspeed/runtime/zero/stage_1_and_2.py) appliesmuon_update()to each parameter's gradient reshaped to its full 2-D shape, and only then narrows to this rank's partition.reduce_scatter=True,average_tensor()reduce-scatters the gradients: each rank receives the averaged values only for its own partition slice. For the rest of a parameter whose flattened gradient crosses a partition boundary, the rank still holds its local, un-all-reduced gradient.muon_updateorthogonalizes a partly-reduced, rank-divergent matrix, and each rank silently applies a different, incorrect update. Parameters that lie wholly inside one partition are unaffected — exactly matching the report.ZeRO-3 already guards this exact conflict in
deepspeed/runtime/zero/stage3.py(added in #7919):ZeRO-1/2 had no equivalent. The existing Muon unit tests pin
"reduce_scatter": falseeverywhere, which implicitly acknowledges the path is unsupported but never enforces it for users — and sincereduce_scatterdefaults totrue, a default Muon + ZeRO-1/2 run is silently wrong.Fix
Mirror the ZeRO-3 guard in ZeRO-1/2: raise the same
ValueErrorat initialization when the optimizer isMuonWithAuxAdamandreduce_scatteris enabled. To run Muon under ZeRO-1/2, set"reduce_scatter": false(as the Muon tests already do). The change is the import plus the guard, with no other behavioral change.Verification (2x RTX 4090, torch 2.9.1+cu128, ZeRO stage 1 and 2)
deepspeed.initializewith Muon +reduce_scatter=truesucceeds silently. Withworld_size=2and a model sized so a 2-D weight straddles the gradient-partition boundary, that weight's post-step update diverges by ~0.67 in relative Frobenius norm from the correct full-gradient result, while wholly-owned weights are unaffected — confirming the silent cross-partition corruption.ValueError: Muon and reduce scatter cannot be used togetherfor both ZeRO stage 1 and 2. The existing Muon tests (which usereduce_scatter: false) remain green.Notes
This supersedes #7878 and #7808, which aimed at the same issue by trying to force a full all-reduce for Muon but ended up with a self-contradictory guard. Aligning ZeRO-1/2 with the merged ZeRO-3 behavior (#7919) keeps the two code paths consistent and turns silent numerical corruption into a clear, actionable error.
A follow-up PR adds a numerical-correctness regression test for the supported
reduce_scatter: falseMuon path, since the current Muon tests only assert that parameters changed.Closes #7807
cc @PKUWZP @pengdurice (ZeRO-3 Muon guard, #7919) @tohtana