diff --git a/tests/unit/ops/muon/test_muon.py b/tests/unit/ops/muon/test_muon.py index 5be9d505239d..57194f0bf71b 100644 --- a/tests/unit/ops/muon/test_muon.py +++ b/tests/unit/ops/muon/test_muon.py @@ -4,6 +4,7 @@ # DeepSpeed Team import deepspeed +import deepspeed.comm as dist import torch import pytest @@ -211,3 +212,132 @@ def test_muon_reduce_scatter_raises(self, zero_stage): model=model, model_parameters=model.parameters(), dist_init_required=False) + + +class TestMuonZero12NumericalCorrectness(DistributedTest): + """Numerical-correctness regression for #7807. + + Under ZeRO-1/2, Muon's Newton-Schulz orthogonalization must run on the FULL DP-averaged + gradient on every rank. The existing Muon tests only assert that parameters changed, which + cannot detect a wrong-but-nonzero update. Here we run the supported reduce_scatter=False + path on >=2 ranks, sized so a 2D weight straddles the gradient-partition boundary (exactly + the case #7807 corrupted), and compare the applied Muon update against an independent + reference that applies the real muon_update to the full averaged gradient. A + partition-then-orthogonalize bug diverges by O(1) -- far above fp16/bf16 NS rounding.""" + + world_size = 2 + + @pytest.mark.parametrize('ns_method', ['gram', 'standard']) + @pytest.mark.parametrize('zero_stage', [1, 2]) + def test_update_matches_full_gradient_reference(self, zero_stage, ns_method): + import copy + from deepspeed.utils import safe_get_full_fp32_param + from deepspeed.runtime.zero.muon.original_muon import muon_update + + hidden_dim, nlayers = 256, 3 + lr, momentum = 0.02, 0.95 + micro = 8 + world = dist.get_world_size() + rank = dist.get_rank() + + torch.manual_seed(1234) + model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers) + init_state = copy.deepcopy(model.state_dict()) + + config_dict = { + "train_micro_batch_size_per_gpu": micro, + "gradient_accumulation_steps": 1, + # No clipping: keep the applied update exactly -lr * muon_update(grad) for the + # reference comparison (Muon's orthogonalized update has a large global norm, so the + # default gradient_clipping=1.0 would otherwise rescale it). + "gradient_clipping": 0.0, + "optimizer": { + "type": "muon", + "params": { + "lr": lr, + "momentum": momentum, + "ns_method": ns_method + } + }, + # Static loss scale so the update is unscaled and matches the reference. + "fp16": { + "enabled": True, + "loss_scale": 1.0 + }, + "zero_optimization": { + "stage": zero_stage, + "reduce_scatter": False + }, + } + engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=model.parameters(), + dist_init_required=False) + device = engine.device + + # Precondition on the ACTUAL flattened ZeRO partition (includes alignment padding and the + # real param ordering): a 2D Muon weight must straddle the rank-0/rank-1 boundary, else + # #7807 (which only corrupts cross-partition weights) cannot be exercised at all. + opt = engine.optimizer + muon_groups = [gi for gi, ps in enumerate(opt.bit16_groups) if ps and all(p.dim() >= 2 for p in ps)] + assert muon_groups, "could not locate the Muon (2D-weight) param group in the optimizer" + crosses = False + for gi in muon_groups: + boundary = opt.bit16_groups_flat[gi].numel() // world + offset = 0 + for p in opt.bit16_groups[gi]: + if offset < boundary < offset + p.numel(): + crosses = True + offset += p.numel() + assert crosses, "no 2D Muon weight straddles the partition boundary; resize the model" + + # Deterministic global batch, identical on every rank; each rank consumes its own slice so + # the DP-averaged gradient equals the full-batch gradient used by the reference. + gen = torch.Generator().manual_seed(999) + gx = torch.randn(world * micro, hidden_dim, generator=gen) + gy = torch.randint(0, hidden_dim, (world * micro, ), generator=gen) + x = gx[rank * micro:(rank + 1) * micro].to(device).half() + y = gy[rank * micro:(rank + 1) * micro].to(device) + + muon_named = [(n, p) for n, p in engine.module.named_parameters() if p.ndim >= 2] + pre = {n: safe_get_full_fp32_param(p).clone() for n, p in muon_named} + + loss = engine(x, y) + engine.backward(loss) + engine.step() + + post = {n: safe_get_full_fp32_param(p).clone() for n, p in muon_named} + + # The post-step weight is all-gathered to every rank, so rank 0's assembled weight already + # reflects every rank's contribution (including the cross-partition slices owned by others). + if rank != 0: + return + + # Independent reference: same init, full global batch, real muon_update on the full grad. + # Run in fp16 to mirror the engine's forward/backward precision (minimizes the legitimate + # gap). weight_decay=0 and gradient_clipping=0 make the applied update exactly -lr*update. + ref = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers).to(device).half() + ref.load_state_dict({k: v.to(device).half() for k, v in init_state.items()}) + ref.zero_grad(set_to_none=True) + ref(gx.to(device).half(), gy.to(device)).backward() + ref_grad = {n: p.grad.detach().float() for n, p in ref.named_parameters() if p.ndim >= 2} + + changed = False + for n in pre: + applied_update = ((pre[n] - post[n]) / lr).float().cpu() # delta = -lr * update (wd=0, no clip) + if applied_update.abs().max().item() > 0: + changed = True + g = ref_grad[n] + # muon_update mutates grad/momentum in place; pass clones and a fresh zero buffer + # (matches the engine's lazily-zeroed first-step momentum buffer). + ref_update = muon_update(g.clone(), torch.zeros_like(g), beta=momentum, ns_method=ns_method).float().cpu() + rel_err = ((applied_update - ref_update).norm() / (ref_update.norm() + 1e-8)).item() + # Newton-Schulz amplifies fp16 gradient rounding, so a correct update still differs from + # the reference by a few percent (measured up to ~0.07 for gram, ~0.22 for standard); the + # #7807 partition-then-orthogonalize bug diverges by O(1) (measured ~0.6-0.67 on the + # cross-partition weight). 0.40 separates them robustly for both ns_method values. + assert rel_err < 0.40, ( + f"{n} (ZeRO-{zero_stage}, ns_method={ns_method}): Muon update rel error {rel_err:.3f} vs " + f"full-gradient reference -- orthogonalization likely ran on a partition slice rather than " + f"the full averaged gradient (#7807)") + assert changed, "optimizer step did not update any Muon weight (skipped step?)"