diff --git a/tests/test_optim.py b/tests/test_optim.py
index 995e4fd8fb..4abf8dab57 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -402,6 +402,14 @@ def test_muon(optimizer):
     _test_model(optimizer, dict(lr=1e-3))
 
 
+@pytest.mark.parametrize('optimizer', ['adamuon', 'nadamuon'])
+def test_adamuon(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=1e-3))
+
+
 @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
 def test_adopt(optimizer):
     _test_rosenbrock(
diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index 4d99406c75..5cf80d6220 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -897,6 +897,24 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults={'nesterov': True}
         ),
+        OptimInfo(
+            name='adamuon',
+            opt_class=Muon,
+            description='AdaMuon: Muon with adaptive second moment estimation on orthogonalized directions',
+            has_momentum=True,
+            has_eps=True,
+            has_betas=True,
+            defaults={'algo': 'adamuon'}
+        ),
+        OptimInfo(
+            name='nadamuon',
+            opt_class=Muon,
+            description='AdaMuon with Nesterov momentum and NAdamW fallback for 1D params',
+            has_momentum=True,
+            has_eps=True,
+            has_betas=True,
+            defaults={'algo': 'adamuon', 'nesterov': True}
+        ),
         OptimInfo(
             name='novograd',
             opt_class=NvNovoGrad,
diff --git a/timm/optim/muon.py b/timm/optim/muon.py
index 15e0ef1b56..fe52c3d724 100644
--- a/timm/optim/muon.py
+++ b/timm/optim/muon.py
@@ -7,6 +7,10 @@
 - Optional spatial normalization
 - Selectable coefficient presets
 - Automatic fallback to AdamW for 1D / scalar parameters (biases, norms, etc.) and optional fallback via param groups
+- AdaMuon (https://arxiv.org/abs/2507.11005)
+- μP eps damping factor (https://arxiv.org/abs/2512.05620v1)
+
+TODO look into μP LR scaling and independent weight-decay scale
 
 Based on implementation by Keller Jordan, see
 - https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -76,6 +80,30 @@
 NSCoeff = Union[str, Tuple[float, float, float], List[Tuple[float, float, float]]]
 
 
+def scale_eps_for_ns(
+        eps: float,
+        shape: Tuple[int, ...],
+) -> float:
+    """Scale epsilon for Newton-Schulz based on matrix dimensions (μP-style).
+
+    For μP compatibility, epsilon should scale as eps * sqrt(din/dout) to maintain
+    consistent damping behavior across different model widths.
+
+    Reference: https://arxiv.org/abs/2512.05620
+
+    Args:
+        eps: Base epsilon value
+        shape: Shape of the matrix (out, in) or (batch, out, in)
+
+    Returns:
+        Scaled epsilon value
+    """
+    # Get din, dout from shape (handle both 2D and 3D batched)
+    # FIXME TBD paper includes depth in the damping scale, e.g: eps * (din / dout) ** 0.5 / N
+    dout, din = (shape[-2], shape[-1])
+    return eps * (din / dout) ** 0.5
+
+
 def zeropower_via_newtonschulz(
         G: torch.Tensor,
         steps: int,
@@ -83,6 +111,7 @@
         eps: float = MUON_EPS,
         safety_factor: float = 1.0,
         dtype: torch.dtype = torch.bfloat16,
+        scale_eps: bool = False,
 ) -> torch.Tensor:
     """Newton-Schulz quintic iteration to compute the zeroth power / orthogonalization of gradient.
 
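For intuition, here is a minimal standalone sketch of the μP-style damping introduced above. It mirrors the `scale_eps_for_ns` logic rather than importing it from timm, and the shapes and base epsilon are purely illustrative:

```python
# Standalone sketch of the mu-P style eps scaling (mirrors scale_eps_for_ns above).
def scale_eps_for_ns(eps: float, shape: tuple) -> float:
    dout, din = shape[-2], shape[-1]
    return eps * (din / dout) ** 0.5

base_eps = 1e-7
# Wide matrix (fan-in > fan-out): damping grows by sqrt(din/dout).
print(scale_eps_for_ns(base_eps, (1024, 4096)))  # 2e-07
# Tall matrix (fan-out > fan-in): damping shrinks by the same factor.
print(scale_eps_for_ns(base_eps, (4096, 1024)))  # 5e-08
```

The scaled epsilon is added to the spectral norm (a damping term) instead of being used as a clamp floor, which is why the normalization branch below distinguishes the two cases.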
@@ -100,6 +129,7 @@ def zeropower_via_newtonschulz(
         eps: Numerical stability epsilon for norm
         safety_factor: Multiplicative safety factor for norm (1.01 is common safety value in 'polar express' variants)
         dtype: Computation dtype
+        scale_eps: If True, scale epsilon by sqrt(din/dout) for μP compatibility
 
     Returns:
         Orthogonalized tensor of same shape as G
@@ -111,6 +141,10 @@
     coeff_sequence = coefficients[:steps] if steps <= num_cs else \
         coefficients + [coefficients[-1]] * (steps - num_cs)
 
+    # Scale epsilon by sqrt(din/dout) for μP compatibility if requested
+    if scale_eps:
+        eps = scale_eps_for_ns(eps, G.shape)
+
     X = G.to(dtype=dtype, copy=True)
 
     # Transpose if needed (operate on dimension with fewer elements)
@@ -119,7 +153,11 @@
         X = X.mT
 
     # Normalize spectral norm to at most 1
-    X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_min(eps))
+    if scale_eps:
+        # more of a damping factor in this case, use add instead of clamp
+        X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).add_(eps))
+    else:
+        X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_min_(eps))
 
     # Batched vs unbatched fused MM
     mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm
@@ -145,9 +183,17 @@
 
 def get_lr_scale(
         param_shape: torch.Size,
-        adjust_lr_fn: str = "match_rms_adamw"
+        adjust_lr_fn: str = "match_rms_adamw",
 ) -> float:
-    """Adjust learning rate based on parameter shape."""
+    """Adjust learning rate based on parameter shape for Muon.
+
+    Args:
+        param_shape: Shape of the parameter tensor
+        adjust_lr_fn: Scaling function name
+            - "original": sqrt(max(1, out/in)) - Original Muon impl
+            - "match_rms_adamw": 0.2 * sqrt(max(out, in)) - Kimi scaling
+            - "rms_to_rms": sqrt(out/in) - Scion/Bernstein scaling
+    """
     out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.)
 
     if adjust_lr_fn == "original":
@@ -161,7 +207,37 @@
         # Bernstein et al. (https://jeremybernste.in/writing/deriving-muon)
         return (out_chs / in_chs) ** 0.5
     else:
-        assert False, f'Invalid scaling function "{adjust_lr_fn}"'
+        assert False, f'Invalid scaling function "{adjust_lr_fn}" for Muon'
+
+
+def get_adamuon_lr_scale(
+        param_shape: torch.Size,
+        adjust_lr_fn: str = "match_rms_adamw",
+) -> Tuple[float, bool]:
+    """Adjust learning rate based on parameter shape for AdaMuon.
+
+    Args:
+        param_shape: Shape of the parameter tensor
+        adjust_lr_fn: Scaling function name
+            - "match_rms_adamw": 0.2 * sqrt(out * in) with RMS norm - AdaMuon paper scaling
+            - "rms_to_rms": sqrt(out/in) - Scion/Bernstein scaling
+            - "rsqrt_in": 1/sqrt(in)
+
+    Returns:
+        Tuple of (scale_factor, use_rms_norm)
+    """
+    out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.)
+
+    if adjust_lr_fn == "match_rms_adamw":
+        # AdaMuon paper: normalize by RMS, then scale by 0.2 * sqrt(numel)
+        # https://arxiv.org/abs/2507.11005
+        return 0.2 * (out_chs * in_chs) ** 0.5, True
+    elif adjust_lr_fn == "rms_to_rms":
+        return (out_chs / in_chs) ** 0.5, False
+    elif adjust_lr_fn == "rsqrt_in":
+        return in_chs ** -0.5, False
+    else:
+        assert False, f'Invalid scaling function "{adjust_lr_fn}" for AdaMuon'
 
 
 def _is_suitable_for_muon(
@@ -287,6 +360,7 @@
     adjust_lr_fn: Optional[str],
     conv_mode: str,
     normalize_spatial: bool,
+    scale_eps: bool,
 ) -> None:
     """Functional API that performs Muon algorithm computation."""
     _single_tensor_muon(
@@ -304,6 +378,58 @@
         adjust_lr_fn=adjust_lr_fn,
         conv_mode=conv_mode,
         normalize_spatial=normalize_spatial,
+        scale_eps=scale_eps,
     )
 
 
+def adamuon(
+    params: List[torch.Tensor],
+    grads: List[torch.Tensor],
+    momentum_bufs: List[torch.Tensor],
+    exp_avg_sqs: List[torch.Tensor],
+    state_steps: List[torch.Tensor],
+    *,
+    lr: float,
+    weight_decay: float,
+    momentum: float,
+    nesterov: bool,
+    beta2: float,
+    ns_steps: int,
+    ns_coefficients: NSCoeff,
+    eps: float,
+    safety_factor: float,
+    adjust_lr_fn: Optional[str],
+    conv_mode: str,
+    normalize_spatial: bool,
+    scale_eps: bool,
+) -> None:
+    """Functional API that performs AdaMuon algorithm computation.
+
+    AdaMuon extends Muon with element-wise second moment estimation applied
+    to orthogonalized update directions, providing Adam-like adaptive scaling
+    while preserving Muon's geometric benefits.
+
+    Reference: https://arxiv.org/abs/2507.11005
+    """
+    _single_tensor_adamuon(
+        params,
+        grads,
+        momentum_bufs,
+        exp_avg_sqs,
+        state_steps,
+        lr=lr,
+        weight_decay=weight_decay,
+        momentum=momentum,
+        nesterov=nesterov,
+        beta2=beta2,
+        ns_steps=ns_steps,
+        ns_coefficients=ns_coefficients,
+        eps=eps,
+        safety_factor=safety_factor,
+        adjust_lr_fn=adjust_lr_fn,
+        conv_mode=conv_mode,
+        normalize_spatial=normalize_spatial,
+        scale_eps=scale_eps,
+    )
 
 
@@ -323,6 +449,7 @@
     adjust_lr_fn: Optional[str],
     conv_mode: str,
     normalize_spatial: bool,
+    scale_eps: bool,
 ) -> None:
     """Single tensor Muon update."""
     ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS)
@@ -352,11 +479,14 @@
             ns_coefficients,
             eps=eps,
             safety_factor=safety_factor,
-            #dtype=torch.bfloat16,  # wire to arg?
+            scale_eps=scale_eps,
         )
 
         # Adjust learning rate based on parameter shape
-        scale = get_lr_scale(update_ortho.shape, adjust_lr_fn)
+        if adjust_lr_fn:
+            scale = get_lr_scale(update_ortho.shape, adjust_lr_fn)
+        else:
+            scale = 1.0
 
         # Apply spatial normalization and permute back if in batched mode
         if conv_mode == "batched" and update_ortho.ndim >= 3:
@@ -372,11 +502,131 @@
         param.add_(update_ortho, alpha=-lr * scale)
 
 
+def _single_tensor_adamuon(
+    params: List[torch.Tensor],
+    grads: List[torch.Tensor],
+    momentum_bufs: List[torch.Tensor],
+    exp_avg_sqs: List[torch.Tensor],
+    state_steps: List[torch.Tensor],
+    *,
+    lr: float,
+    weight_decay: float,
+    momentum: float,
+    nesterov: bool,
+    beta2: float,
+    ns_steps: int,
+    ns_coefficients: NSCoeff,
+    eps: float,
+    safety_factor: float,
+    adjust_lr_fn: Optional[str],
+    conv_mode: str,
+    normalize_spatial: bool,
+    scale_eps: bool,
+) -> None:
+    """Single tensor AdaMuon update.
+
+    AdaMuon applies second-moment estimation to the orthogonalized directions,
+    then rescales using RMS-alignment to maintain stable step sizes.
+
+    Algorithm:
+        1. Update momentum buffer: M = β₁·M + (1-β₁)·G
+        2. Orthogonalize: O = Newton-Schulz(M) or Newton-Schulz(nesterov_update)
+        3. Update second moment: v = β₂·v + (1-β₂)·O²
+        4. Bias correct: v̂ = v/(1-β₂^t)
+        5. Adaptive scaling: Ô = O / (√v̂ + ε)
+        6. RMS-aligned rescaling and apply update
+    """
+    ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS)
+
+    for i, param in enumerate(params):
+        grad = grads[i]
+        momentum_buf = momentum_bufs[i]
+        exp_avg_sq = exp_avg_sqs[i]
+        step_t = state_steps[i]
+
+        # Increment step
+        step_t += 1
+        step = step_t.item()
+
+        # Apply weight decay (decoupled)
+        param.mul_(1 - lr * weight_decay)
+
+        # Update momentum buffer
+        momentum_buf.lerp_(grad, 1. - momentum)
+        update = grad.lerp_(momentum_buf, momentum) if nesterov else momentum_buf.clone()
+
+        # Reshape for processing (handle 3D+ tensors like conv weights)
+        if update.ndim >= 3:
+            update_reshaped, original_shape = reshape_for_muon(update, mode=conv_mode)
+        else:
+            update_reshaped = update
+            original_shape = update.shape
+
+        # Apply Newton-Schulz orthogonalization
+        update_ortho = zeropower_via_newtonschulz(
+            update_reshaped,
+            ns_steps,
+            ns_coefficients,
+            eps=eps,
+            safety_factor=safety_factor,
+            scale_eps=scale_eps,
+        )
+
+        # Reshape back to original shape for second moment tracking
+        if conv_mode == "batched" and update_ortho.ndim >= 3:
+            # Permute back: (spatial_prod, out, in) -> (out, in, spatial_prod)
+            update_ortho = update_ortho.permute(1, 2, 0)
+        update_ortho = update_ortho.reshape(original_shape)
+
+        # Update second moment on orthogonalized directions (element-wise)
+        exp_avg_sq.mul_(beta2).addcmul_(update_ortho, update_ortho, value=1.0 - beta2)
+
+        # Get shape-based LR scaling and whether to apply RMS normalization
+        if adjust_lr_fn:
+            scale, use_rms_norm = get_adamuon_lr_scale(update_ortho.shape, adjust_lr_fn)
+        else:
+            scale, use_rms_norm = 1.0, False
+
+        if use_rms_norm:
+            # Bias correction not needed if scaling by norm
+            denom = exp_avg_sq.sqrt().add_(eps)
+        else:
+            # Bias correction for second moment
+            bias_correction2 = 1.0 - beta2 ** step
+            denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
+
+        # Adaptive scaling: divide by sqrt of bias-corrected second moment
+        # This is the key AdaMuon modification
+        update_adaptive = update_ortho / denom
+
+        # RMS-aligned rescaling: normalize by update norm, then scale by shape factor
+        # Used by the AdaMuon paper approach (match_rms_adamw), not by the μP approach (rms_to_rms)
+        if use_rms_norm:
+            update_norm = update_adaptive.norm().add_(eps)
+            update_adaptive = update_adaptive / update_norm
+
+        # Apply spatial normalization if in batched mode
+        if conv_mode == "batched" and len(original_shape) >= 3:
+            if normalize_spatial:
+                spatial_prod = 1
+                for d in original_shape[2:]:
+                    spatial_prod *= d
+                scale *= spatial_prod ** -0.5
+
+        # Apply update
+        param.add_(update_adaptive, alpha=-lr * scale)
+
+
 class Muon(torch.optim.Optimizer):
     """Muon - MomentUm Orthogonalized by Newton-schulz
 
     Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms)
     and parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility).
+
+    Supports two algorithms:
+    - "muon": Standard Muon algorithm with momentum + orthogonalization
+    - "adamuon": AdaMuon algorithm that adds element-wise second moment estimation
+      to orthogonalized directions for Adam-like adaptive scaling
     """
 
     def __init__(
@@ -395,6 +645,8 @@ def __init__(
             normalize_spatial: bool = True,
             adamw_lr: Optional[float] = None,
             betas: Tuple[float, float] = (0.9, 0.95),
+            algo: str = "muon",
+            scale_eps: bool = False,
             verbose: bool = False,
     ):
         """ Create Muon optimizer.
@@ -408,11 +660,18 @@ def __init__(
            ns_coefficients: Coefficients for NS iteration
            eps: Numerical stability epsilon
            safety_factor: Multiplicative safety factor for NS norm
-           adjust_lr_fn: LR adjustment function - "original" or "match_rms_adamw"
+           adjust_lr_fn: LR adjustment function - "original", "match_rms_adamw", or "rms_to_rms" for muon;
+               "match_rms_adamw", "rms_to_rms", or "rsqrt_in" for adamuon. For adamuon mode, can be set
+               to None to disable shape-based scaling (the RMS rescaling then handles step size).
           conv_mode: How to handle convolutions - "flatten" or "batched"
           normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode
           adamw_lr: Learning rate for AdamW (1D params), defaults to lr if not specified
-          betas: AdamW beta coefficients
+          betas: Beta coefficients - (beta1, beta2) where beta1 is used for AdamW fallback
+              and beta2 is used for both AdamW fallback and AdaMuon second moment
+          algo: Algorithm - "muon" for standard Muon, "adamuon" for AdaMuon with
+              adaptive second moment estimation (https://arxiv.org/abs/2507.11005)
+          scale_eps: If True, scale epsilon by sqrt(din/dout) in Newton-Schulz for μP
+              compatibility (https://arxiv.org/abs/2512.05620)
          verbose: Log parameter routing decisions (Muon vs AdamW)
 
         Example:
@@ -420,6 +678,9 @@ def __init__(
             # Simple usage - automatically uses Muon for 2D+ params, AdamW for 1D
             optimizer = Muon(model.parameters(), lr=0.02)
 
+            # Use AdaMuon algorithm for adaptive scaling
+            optimizer = Muon(model.parameters(), lr=6e-4, algo="adamuon")
+
             # Manual control over parameter groups
             optimizer = Muon([
                 {'params': weight_matrices, 'lr': 0.02},
@@ -437,6 +698,8 @@ def __init__(
             raise ValueError(f"Invalid epsilon value: {eps}")
         if conv_mode not in ["flatten", "batched"]:
             raise ValueError(f"Invalid conv_mode: {conv_mode}")
+        if algo not in ["muon", "adamuon"]:
+            raise ValueError(f"Invalid algo: {algo}. Must be 'muon' or 'adamuon'")
 
         defaults = dict(
             lr=lr,
@@ -452,10 +715,18 @@
             normalize_spatial=normalize_spatial,
             adamw_lr=adamw_lr if adamw_lr is not None else lr,
             betas=betas,
+            algo=algo,
+            scale_eps=scale_eps,
             verbose=verbose,
         )
         super().__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('algo', 'muon')
+            group.setdefault('scale_eps', False)
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step."""
@@ -472,10 +743,15 @@
         routing_reasons = {} if verbose else None
 
         for group in self.param_groups:
+            algo = group.get("algo", "muon")
+
             # Separate params into Muon and AdamW groups
             muon_params = []
            muon_grads = []
            muon_momentum_bufs = []
+            # Additional state for adamuon mode
+            muon_exp_avg_sqs = []
+            muon_state_steps = []
 
            adamw_params = []
            adamw_grads = []
@@ -529,10 +805,18 @@
                     muon_grads.append(p.grad)
                     muon_count += 1
-                    # State initialization for Muon
+                    # State initialization for Muon/AdaMuon
                     if "momentum_buffer" not in state:
                         state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     muon_momentum_bufs.append(state["momentum_buffer"])
+
+                    # Additional state for adamuon mode
+                    if algo == "adamuon":
+                        if "step" not in state:
+                            state["step"] = torch.tensor(0.)
+                            state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                        muon_exp_avg_sqs.append(state["exp_avg_sq"])
+                        muon_state_steps.append(state["step"])
                 else:
                     # Collect AdamW/NAdamW params
                     adamw_params.append(p)
                     adamw_grads.append(p.grad)
@@ -549,24 +833,48 @@
                     adamw_exp_avg_sqs.append(state["exp_avg_sq"])
                     adamw_state_steps.append(state["step"])
 
-            # Apply Muon updates
+            # Apply Muon/AdaMuon updates
             if muon_params:
-                muon(
-                    muon_params,
-                    muon_grads,
-                    muon_momentum_bufs,
-                    lr=group["lr"],
-                    weight_decay=group["weight_decay"],
-                    momentum=group["momentum"],
-                    nesterov=group["nesterov"],
-                    ns_steps=group["ns_steps"],
-                    ns_coefficients=group["ns_coefficients"],
-                    eps=group["eps"],
-                    safety_factor=group["safety_factor"],
-                    adjust_lr_fn=group["adjust_lr_fn"],
-                    conv_mode=group["conv_mode"],
-                    normalize_spatial=group["normalize_spatial"],
-                )
+                if algo == "adamuon":
+                    _, beta2 = group["betas"]
+                    adamuon(
+                        muon_params,
+                        muon_grads,
+                        muon_momentum_bufs,
+                        muon_exp_avg_sqs,
+                        muon_state_steps,
+                        lr=group["lr"],
+                        weight_decay=group["weight_decay"],
+                        momentum=group["momentum"],
+                        nesterov=group["nesterov"],
+                        beta2=beta2,
+                        ns_steps=group["ns_steps"],
+                        ns_coefficients=group["ns_coefficients"],
+                        eps=group["eps"],
+                        safety_factor=group["safety_factor"],
+                        adjust_lr_fn=group["adjust_lr_fn"],
+                        conv_mode=group["conv_mode"],
+                        normalize_spatial=group["normalize_spatial"],
+                        scale_eps=group["scale_eps"],
+                    )
+                else:
+                    muon(
+                        muon_params,
+                        muon_grads,
+                        muon_momentum_bufs,
+                        lr=group["lr"],
+                        weight_decay=group["weight_decay"],
+                        momentum=group["momentum"],
+                        nesterov=group["nesterov"],
+                        ns_steps=group["ns_steps"],
+                        ns_coefficients=group["ns_coefficients"],
+                        eps=group["eps"],
+                        safety_factor=group["safety_factor"],
+                        adjust_lr_fn=group["adjust_lr_fn"],
+                        conv_mode=group["conv_mode"],
+                        normalize_spatial=group["normalize_spatial"],
+                        scale_eps=group["scale_eps"],
+                    )
 
             # Apply AdamW updates
             if adamw_params:
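For a compact view of what the new code path does, here is a self-contained sketch of one AdaMuon update for a single 2D weight, following the `match_rms_adamw` branch of `_single_tensor_adamuon` above. The `orthogonalize` callable stands in for `zeropower_via_newtonschulz`, and the hyperparameter values are illustrative, not timm defaults:

```python
import torch

def adamuon_step_2d(param, grad, momentum_buf, exp_avg_sq, *,
                    lr=6e-4, beta1=0.95, beta2=0.95, eps=1e-8,
                    weight_decay=0.0, orthogonalize=None):
    """One AdaMuon update for a 2D weight (match_rms_adamw path, no Nesterov)."""
    param.mul_(1 - lr * weight_decay)              # decoupled weight decay
    momentum_buf.lerp_(grad, 1 - beta1)            # M = b1*M + (1-b1)*G
    ortho = orthogonalize(momentum_buf)            # O = NewtonSchulz(M)
    exp_avg_sq.mul_(beta2).addcmul_(ortho, ortho, value=1 - beta2)  # v on O, elementwise
    adaptive = ortho / (exp_avg_sq.sqrt() + eps)   # Adam-like scaling; no bias correction,
                                                   # the RMS normalization below absorbs it
    adaptive = adaptive / (adaptive.norm() + eps)  # RMS-align to unit Frobenius norm...
    scale = 0.2 * param.numel() ** 0.5             # ...then scale by 0.2 * sqrt(numel)
    param.add_(adaptive, alpha=-lr * scale)
```

A unit-norm update scaled by 0.2 * sqrt(numel) has an element-wise RMS of roughly 0.2, which is the same "match AdamW RMS" rationale used by the existing Muon scaling.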
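And a quick numeric comparison of the shape-based scale factors, computed directly from the formulas in `get_lr_scale` and `get_adamuon_lr_scale` for a hypothetical 4096x1024 weight (values are illustrative arithmetic, not measurements):

```python
out_chs, in_chs = 4096, 1024

# Muon variants (get_lr_scale)
original = max(1.0, out_chs / in_chs) ** 0.5    # "original": 2.0
match_rms = 0.2 * max(out_chs, in_chs) ** 0.5   # "match_rms_adamw" (Kimi): 12.8
rms_to_rms = (out_chs / in_chs) ** 0.5          # "rms_to_rms" (Scion/Bernstein): 2.0

# AdaMuon variants (get_adamuon_lr_scale); match_rms_adamw also RMS-normalizes the update
adamuon_rms = 0.2 * (out_chs * in_chs) ** 0.5   # 0.2 * sqrt(numel): 409.6
rsqrt_in = in_chs ** -0.5                       # "rsqrt_in": 0.03125
```

The much larger AdaMuon factor is only applied after the update has been normalized to unit norm, so the effective per-element step remains comparable to the Muon variants.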