diff --git a/init2winit/optimizer_lib/muon.py b/init2winit/optimizer_lib/muon.py index cc799ab0..3c8f09ab 100644 --- a/init2winit/optimizer_lib/muon.py +++ b/init2winit/optimizer_lib/muon.py @@ -59,12 +59,6 @@ def orthogonalize_via_newton_schulz( Returns: The orthogonalized matrix. """ - was_reshaped = False - original_shape = updates.shape - - if updates.ndim == 3: - updates = updates.reshape(updates.shape[0], -1) - was_reshaped = True if updates.ndim != 2: raise ValueError(f'Input must be 2D, got {updates.shape}') @@ -99,18 +93,15 @@ def newton_schulz_iterator(x: jax.Array, coeffs: jax.Array) -> jax.Array: fan_in = orthogonalized_updates.shape[0] # Scaling factor taken from https://jeremybernste.in/writing/deriving-muon - # and https://github.com/KellerJordan/modded-nanogpt/blame/822ab2dd79140ed34ae43a20450f0bdc36457a24/train_gpt.py#L184 # pylint: disable=line-too-long scale_factor = jnp.maximum(1.0, jnp.sqrt(fan_out / fan_in)) orthogonalized_updates *= scale_factor - if was_reshaped: - orthogonalized_updates = orthogonalized_updates.reshape(original_shape) - return orthogonalized_updates class MuonState(NamedTuple): - """State for the Adam algorithm.""" + """State for Muon.""" + count: chex.Array # shape=(), dtype=jnp.int32. momentum: optax.Updates @@ -128,6 +119,10 @@ def _bias_correction(moment, decay, count): return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) +def _default_reshape_fn(x): + return x.reshape(x.shape[0], -1) + + def scale_by_muon( learning_rate: float = 0.0, beta: float = 0.95, @@ -137,9 +132,16 @@ def scale_by_muon( eps: float = 1e-7, nesterov: bool = True, bias_correction: bool = False, + reshape_fn=_default_reshape_fn, ) -> optax.GradientTransformation: r"""Rescale updates according to the Muon algorithm. + This is the pure Muon transformation: momentum followed by Newton-Schulz + orthogonalization and decoupled weight decay. It does **not** include a + pair optimizer or masking logic. 
The caller is responsible for partitioning + parameters (e.g. via ``optax.masked``) and constructing a separate + optimizer for non-Muon parameters. + Args: learning_rate: Learning rate. beta: Decay rate for the gradient momentum. @@ -149,6 +151,10 @@ def scale_by_muon( eps: Term added to denominators to improve numerical stability. nesterov: Whether to use Nesterov momentum. bias_correction: Whether to perform bias correction. + reshape_fn: A function ``(jnp.ndarray,) -> jnp.ndarray`` that reshapes a >2D + parameter to 2D for orthogonalization. The result is reshaped back to the + original shape afterwards. Defaults to ``lambda x: x.reshape(x.shape[0], + -1)``. Pass ``None`` to disable (will error on >2D inputs). Returns: A `GradientTransformation` object. @@ -165,7 +171,7 @@ def scale_by_muon( def init_fn(params): momentum = jax.tree_util.tree_map( lambda p: jnp.zeros_like(p, jnp.float32), params - ) # First moment + ) return MuonState( count=jnp.zeros([], jnp.int32), @@ -184,19 +190,22 @@ def update_fn(updates, state, params=None): if bias_correction: momentum = _bias_correction(momentum, beta, new_count) - # Apply Newton-schulz orthogonalization. 
+ def _orthogonalize(x): + orig_shape = x.shape + if x.ndim > 2 and reshape_fn is not None: + x = reshape_fn(x) + x = orthogonalize_via_newton_schulz(x, ns_coeffs, ns_steps, eps) + return x.reshape(orig_shape) + scaled_orthogonalized_momentum = jax.tree_util.tree_map( - lambda x: orthogonalize_via_newton_schulz( - x, ns_coeffs, ns_steps, eps - ), + _orthogonalize, momentum, ) - # Apply weight decay similar to how it's being applied here : - # https://github.com/KellerJordan/Muon/commit/e0ffefd4f7ea88f2db724caa2c7cfe859155995d#diff-ff0575a769b2390ce3256edb1c20e4d741d514a77c4f0697c2fa628f810f46b1R60-R80 new_updates = jax.tree_util.tree_map( lambda u, p: -learning_rate * (u + weight_decay * p), - scaled_orthogonalized_momentum, params + scaled_orthogonalized_momentum, + params, ) return new_updates, MuonState( diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index a20344d8..0ba13646 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -167,15 +167,41 @@ def get_optimizer(hps, model=None, batch_axis_name=None): weight_decay=weight_decay, ) elif hps.optimizer == 'muon': - opt_init, opt_update = utils.static_inject_hyperparams(muon.scale_by_muon)( - learning_rate=0.0, # Manually injected on each train step. 
- weight_decay=hps.opt_hparams.get('weight_decay', 0.01), - beta=hps.opt_hparams.get('beta', 0.95), - nesterov=hps.opt_hparams.get('nesterov', True), - ns_coeffs=hps.opt_hparams.get('ns_coeffs', (3.4445, -4.7750, 2.0315)), - ns_steps=hps.opt_hparams.get('ns_steps', 5), - eps=hps.opt_hparams.get('eps', 1e-7), - bias_correction=hps.opt_hparams.get('bias_correction', False), + muon_hparams = hps.opt_hparams.get('muon_hparams', {}) + pair_optimizer = hps.opt_hparams.get('pair_optimizer', 'rmsprop') + pair_hparams = hps.opt_hparams.get('pair_hparams', {}) + + pair_hps = copy.deepcopy(hps) + pair_hps.optimizer = pair_optimizer + pair_hps.opt_hparams = pair_hparams + pair_hps.l2_decay_factor = None + pair_tx = optax.GradientTransformation(*get_optimizer(pair_hps)) + + muon_mask = hps.opt_hparams.get('muon_mask', None) + if muon_mask is None: + muon_mask = lambda p: jax.tree_util.tree_map_with_path( + lambda path, x: 'muon' + if (x.ndim >= 2 and 'embed' not in jax.tree_util.keystr(path).lower()) + else 'pair', + p, + ) + + def _muon_with_pair(learning_rate=0.0): + _muon_hparams = dict(**muon_hparams) + learning_rate_multiplier = _muon_hparams.pop('lr_multiplier', 1.0) + _muon_opt = optax.chain( + muon.scale_by_muon(learning_rate, **_muon_hparams), + optax.scale_by_learning_rate( + learning_rate_multiplier, flip_sign=False + ), + ) + return optax.partition( + transforms={'muon': _muon_opt, 'pair': pair_tx}, + param_labels=muon_mask, + ) + + opt_init, opt_update = utils.static_inject_hyperparams(_muon_with_pair)( + learning_rate=0.0, ) elif hps.optimizer == 'diag_bubbles': opt_init, opt_update = utils.static_inject_hyperparams( diff --git a/init2winit/optimizer_lib/test_optimizers.py b/init2winit/optimizer_lib/test_optimizers.py index 7a3484fb..75403a94 100644 --- a/init2winit/optimizer_lib/test_optimizers.py +++ b/init2winit/optimizer_lib/test_optimizers.py @@ -14,6 +14,7 @@ # limitations under the License. 
"""Tests for optimizers.""" + import shutil import tempfile @@ -23,10 +24,12 @@ from init2winit.init_lib import initializers from init2winit.model_lib import model_utils from init2winit.model_lib import models +from init2winit.optimizer_lib import muon from init2winit.optimizer_lib import optimizers from init2winit.optimizer_lib import utils as optimizers_utils from init2winit.optimizer_lib.kitchen_sink._src.transform import ScaleByAdapropState # pylint: disable=g-importing-member import jax +import jax.numpy as jnp from ml_collections import config_dict import optax from optax._src.transform import ScaleByAdamState # pylint: disable=g-importing-member @@ -97,7 +100,8 @@ def test_generic_multi_optimizer_init(self): experiment_config.initializer, experiment_config.dataset, hparam_file=None, - hparam_overrides=experiment_config.hparam_overrides) + hparam_overrides=experiment_config.hparam_overrides, + ) model = model_cls( merged_hps, @@ -117,8 +121,8 @@ def test_generic_multi_optimizer_init(self): unreplicated_optimizer_state = opt_init_fn(unreplicated_params) self.assertIsInstance( - unreplicated_optimizer_state, - optax.transforms.PartitionState) + unreplicated_optimizer_state, optax.transforms.PartitionState + ) # unreplicated_optimizer_state should be a Dict mapping param type # to opt_state where only params mapping to that param_type have non-empty @@ -159,5 +163,161 @@ def tearDown(self): + +class MuonTest(absltest.TestCase): + """Tests for Muon optimizer.""" + + def test_muon_split(self): + """Verifies that Muon handles 2D params and RMSProp handles 1D params.""" + params = { + 'p1d': jnp.ones((10,)), + 'p2d': jnp.ones((10, 10)), + } + grads = { + 'p1d': jnp.full((10,), 0.1), + 'p2d': jnp.eye(10) * 0.1, + } + + lr = 0.1 + beta = 0.95 + wd = 0.01 + + muon_tx = muon.scale_by_muon( + learning_rate=lr, + beta=beta, + weight_decay=wd, + ) + rms_prop = optax.chain( + optax.scale_by_rms(decay=beta, eps=1e-7), + optax.add_decayed_weights(wd), + 
optax.scale_by_learning_rate(lr, flip_sign=True), + ) + muon_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim >= 2, p) + rmsprop_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim < 2, p) + tx = optax.chain( + optax.masked(muon_tx, mask=muon_mask), + optax.masked(rms_prop, mask=rmsprop_mask), + ) + + state = tx.init(params) + updates, _ = tx.update(grads, state, params) + + expected_1d = -0.4482136 + self.assertTrue(jnp.allclose(updates['p1d'], expected_1d, rtol=1e-3)) + + self.assertFalse(jnp.allclose(updates['p2d'], -0.4482136, rtol=1e-1)) + + def test_muon_get_optimizer(self): + """Verifies get_optimizer('muon') matches manual construction.""" + params = { + 'p1d': jnp.ones((10,)), + 'p2d': jnp.ones((10, 10)), + } + grads = { + 'p1d': jnp.full((10,), 0.1), + 'p2d': jnp.eye(10) * 0.1, + } + + lr = 0.1 + beta = 0.95 + wd = 0.01 + + hps = config_dict.ConfigDict({ + 'optimizer': 'muon', + 'l2_decay_factor': None, + 'batch_size': 8, + 'opt_hparams': { + 'muon_hparams': { + 'beta': beta, + 'weight_decay': wd, + }, + 'pair_optimizer': 'sgd', + 'pair_hparams': { + 'weight_decay': wd, + }, + }, + }) + + opt_init, opt_update = optimizers.get_optimizer(hps) + state = opt_init(params) + state = optimizers.inject_learning_rate(state, lr) + updates, _ = opt_update(grads, state, params=params) + + muon_tx = muon.scale_by_muon( + learning_rate=lr, + beta=beta, + weight_decay=wd, + ) + sgd_tx = optax.chain( + optax.add_decayed_weights(wd), + optax.sgd(learning_rate=lr, momentum=None, nesterov=False), + ) + muon_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim >= 2, p) + pair_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim < 2, p) + manual_tx = optax.chain( + optax.masked(muon_tx, mask=muon_mask), + optax.masked(sgd_tx, mask=pair_mask), + ) + manual_state = manual_tx.init(params) + manual_updates, _ = manual_tx.update(grads, manual_state, params) + + self.assertTrue( + jnp.allclose(updates['p2d'], manual_updates['p2d'], atol=1e-5) + ) + 
+    self.assertTrue(
+        jnp.allclose(updates['p1d'], manual_updates['p1d'], atol=1e-5)
+    )
+
+  def test_muon_lr_multiplier(self):
+    """Verifies lr_multiplier scales Muon's effective learning rate."""
+    params = {'w': jnp.ones((10, 10))}
+    grads = {'w': jnp.eye(10) * 0.1}
+
+    lr = 0.1
+    base = optax.chain(muon.scale_by_muon(learning_rate=lr), optax.scale_by_learning_rate(1.0, flip_sign=False))
+    scaled = optax.chain(muon.scale_by_muon(learning_rate=lr), optax.scale_by_learning_rate(2.0, flip_sign=False))
+
+    base_state = base.init(params)
+    scaled_state = scaled.init(params)
+
+    base_updates, _ = base.update(grads, base_state, params)
+    scaled_updates, _ = scaled.update(grads, scaled_state, params)
+
+    self.assertTrue(
+        jnp.allclose(scaled_updates['w'], 2.0 * base_updates['w'], atol=1e-6)
+    )
+
+  def test_muon_3d_reshape(self):
+    """Verifies Muon reshapes >2D params to (s0, -1) for orthogonalization."""
+    params_3d = {'w': jnp.ones((4, 5, 6))}
+    grads_3d = {'w': jax.random.normal(jax.random.PRNGKey(0), (4, 5, 6))}
+
+    tx = muon.scale_by_muon(learning_rate=0.1)
+    state = tx.init(params_3d)
+    updates, _ = tx.update(grads_3d, state, params_3d)
+
+    self.assertEqual(updates['w'].shape, (4, 5, 6))
+
+  def test_muon_embed_mask(self):
+    """Verifies default mask excludes embed params and 1D params."""
+    params = {
+        'embed': jnp.ones((100, 64)),
+        'dense': jnp.ones((64, 64)),
+        'bias': jnp.ones((64,)),
+    }
+
+    muon_mask = lambda p: jax.tree_util.tree_map_with_path(
+        lambda path, x: (
+            x.ndim >= 2 and 'embed' not in jax.tree_util.keystr(path).lower()
+        ),
+        p,
+    )
+    mask = muon_mask(params)
+
+    self.assertTrue(mask['dense'])
+    self.assertFalse(mask['embed'])
+    self.assertFalse(mask['bias'])
+
+
 if __name__ == '__main__':
   absltest.main()