Skip to content
Open

Muon #873

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions init2winit/optimizer_lib/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,6 @@ def orthogonalize_via_newton_schulz(
Returns:
The orthogonalized matrix.
"""
was_reshaped = False
original_shape = updates.shape

if updates.ndim == 3:
updates = updates.reshape(updates.shape[0], -1)
was_reshaped = True

if updates.ndim != 2:
raise ValueError(f'Input must be 2D, got {updates.shape}')
Expand Down Expand Up @@ -99,18 +93,15 @@ def newton_schulz_iterator(x: jax.Array, coeffs: jax.Array) -> jax.Array:
fan_in = orthogonalized_updates.shape[0]

# Scaling factor taken from https://jeremybernste.in/writing/deriving-muon
# and https://github.com/KellerJordan/modded-nanogpt/blame/822ab2dd79140ed34ae43a20450f0bdc36457a24/train_gpt.py#L184 # pylint: disable=line-too-long
scale_factor = jnp.maximum(1.0, jnp.sqrt(fan_out / fan_in))
orthogonalized_updates *= scale_factor

if was_reshaped:
orthogonalized_updates = orthogonalized_updates.reshape(original_shape)

return orthogonalized_updates


class MuonState(NamedTuple):
  """State for Muon.

  Attributes:
    count: Step counter; shape=(), dtype=jnp.int32.
    momentum: Momentum accumulator — a pytree of float32 arrays matching the
      parameter tree (initialized with ``jnp.zeros_like`` in ``init_fn``).
  """

  count: chex.Array
  momentum: optax.Updates

Expand All @@ -128,6 +119,10 @@ def _bias_correction(moment, decay, count):
return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment)


def _default_reshape_fn(x):
return x.reshape(x.shape[0], -1)


def scale_by_muon(
learning_rate: float = 0.0,
beta: float = 0.95,
Expand All @@ -137,9 +132,16 @@ def scale_by_muon(
eps: float = 1e-7,
nesterov: bool = True,
bias_correction: bool = False,
reshape_fn=_default_reshape_fn,
) -> optax.GradientTransformation:
r"""Rescale updates according to the Muon algorithm.

This is the pure Muon transformation: momentum followed by Newton-Schulz
orthogonalization and decoupled weight decay. It does **not** include a
pair optimizer or masking logic. The caller is responsible for partitioning
parameters (e.g. via ``optax.masked``) and constructing a separate
optimizer for non-Muon parameters.

Args:
learning_rate: Learning rate.
beta: Decay rate for the gradient momentum.
Expand All @@ -149,6 +151,10 @@ def scale_by_muon(
eps: Term added to denominators to improve numerical stability.
nesterov: Whether to use Nesterov momentum.
bias_correction: Whether to perform bias correction.
reshape_fn: A function ``(jnp.ndarray,) -> jnp.ndarray`` that reshapes a >2D
parameter to 2D for orthogonalization. The result is reshaped back to the
original shape afterwards. Defaults to ``lambda x: x.reshape(x.shape[0],
-1)``. Pass ``None`` to disable (will error on >2D inputs).

Returns:
A `GradientTransformation` object.
Expand All @@ -165,7 +171,7 @@ def scale_by_muon(
def init_fn(params):
momentum = jax.tree_util.tree_map(
lambda p: jnp.zeros_like(p, jnp.float32), params
) # First moment
)

return MuonState(
count=jnp.zeros([], jnp.int32),
Expand All @@ -184,19 +190,22 @@ def update_fn(updates, state, params=None):
if bias_correction:
momentum = _bias_correction(momentum, beta, new_count)

# Apply Newton-schulz orthogonalization.
def _orthogonalize(x):
orig_shape = x.shape
if x.ndim > 2 and reshape_fn is not None:
x = reshape_fn(x)
x = orthogonalize_via_newton_schulz(x, ns_coeffs, ns_steps, eps)
return x.reshape(orig_shape)

scaled_orthogonalized_momentum = jax.tree_util.tree_map(
lambda x: orthogonalize_via_newton_schulz(
x, ns_coeffs, ns_steps, eps
),
_orthogonalize,
momentum,
)

# Apply weight decay similar to how it's being applied here :
# https://github.com/KellerJordan/Muon/commit/e0ffefd4f7ea88f2db724caa2c7cfe859155995d#diff-ff0575a769b2390ce3256edb1c20e4d741d514a77c4f0697c2fa628f810f46b1R60-R80
new_updates = jax.tree_util.tree_map(
lambda u, p: -learning_rate * (u + weight_decay * p),
scaled_orthogonalized_momentum, params
scaled_orthogonalized_momentum,
params,
)

return new_updates, MuonState(
Expand Down
44 changes: 35 additions & 9 deletions init2winit/optimizer_lib/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,41 @@ def get_optimizer(hps, model=None, batch_axis_name=None):
weight_decay=weight_decay,
)
elif hps.optimizer == 'muon':
opt_init, opt_update = utils.static_inject_hyperparams(muon.scale_by_muon)(
learning_rate=0.0, # Manually injected on each train step.
weight_decay=hps.opt_hparams.get('weight_decay', 0.01),
beta=hps.opt_hparams.get('beta', 0.95),
nesterov=hps.opt_hparams.get('nesterov', True),
ns_coeffs=hps.opt_hparams.get('ns_coeffs', (3.4445, -4.7750, 2.0315)),
ns_steps=hps.opt_hparams.get('ns_steps', 5),
eps=hps.opt_hparams.get('eps', 1e-7),
bias_correction=hps.opt_hparams.get('bias_correction', False),
muon_hparams = hps.opt_hparams.get('muon_hparams', {})
pair_optimizer = hps.opt_hparams.get('pair_optimizer', 'rmsprop')
pair_hparams = hps.opt_hparams.get('pair_hparams', {})

pair_hps = copy.deepcopy(hps)
pair_hps.optimizer = pair_optimizer
pair_hps.opt_hparams = pair_hparams
pair_hps.l2_decay_factor = None
pair_tx = optax.GradientTransformation(*get_optimizer(pair_hps))

muon_mask = hps.opt_hparams.get('muon_mask', None)
if muon_mask is None:
muon_mask = lambda p: jax.tree_util.tree_map_with_path(
lambda path, x: 'muon'
if (x.ndim >= 2 and 'embed' not in jax.tree_util.keystr(path).lower())
else 'pair',
p,
)

def _muon_with_pair(learning_rate=0.0):
_muon_hparams = dict(**muon_hparams)
learning_rate_multiplier = _muon_hparams.pop('lr_multiplier', 1.0)
_muon_opt = optax.chain(
muon.scale_by_muon(learning_rate, **_muon_hparams),
optax.scale_by_learning_rate(
learning_rate_multiplier, flip_sign=False
),
)
return optax.partition(
transforms={'muon': _muon_opt, 'pair': pair_tx},
param_labels=muon_mask,
)

opt_init, opt_update = utils.static_inject_hyperparams(_muon_with_pair)(
learning_rate=0.0,
)
elif hps.optimizer == 'diag_bubbles':
opt_init, opt_update = utils.static_inject_hyperparams(
Expand Down
166 changes: 163 additions & 3 deletions init2winit/optimizer_lib/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

"""Tests for optimizers."""

import shutil
import tempfile

Expand All @@ -23,10 +24,12 @@
from init2winit.init_lib import initializers
from init2winit.model_lib import model_utils
from init2winit.model_lib import models
from init2winit.optimizer_lib import muon
from init2winit.optimizer_lib import optimizers
from init2winit.optimizer_lib import utils as optimizers_utils
from init2winit.optimizer_lib.kitchen_sink._src.transform import ScaleByAdapropState # pylint: disable=g-importing-member
import jax
import jax.numpy as jnp
from ml_collections import config_dict
import optax
from optax._src.transform import ScaleByAdamState # pylint: disable=g-importing-member
Expand Down Expand Up @@ -97,7 +100,8 @@ def test_generic_multi_optimizer_init(self):
experiment_config.initializer,
experiment_config.dataset,
hparam_file=None,
hparam_overrides=experiment_config.hparam_overrides)
hparam_overrides=experiment_config.hparam_overrides,
)

model = model_cls(
merged_hps,
Expand All @@ -117,8 +121,8 @@ def test_generic_multi_optimizer_init(self):

unreplicated_optimizer_state = opt_init_fn(unreplicated_params)
self.assertIsInstance(
unreplicated_optimizer_state,
optax.transforms.PartitionState)
unreplicated_optimizer_state, optax.transforms.PartitionState
)

# unreplicated_optimizer_state should be a Dict mapping param type
# to opt_state where only params mapping to that param_type have non-empty
Expand Down Expand Up @@ -159,5 +163,161 @@ def tearDown(self):




class MuonTest(absltest.TestCase):
  """Tests for the Muon optimizer and its integration with get_optimizer."""

  def test_muon_split(self):
    """Verifies that Muon handles 2D params and RMSProp handles 1D params."""
    params = {
        'p1d': jnp.ones((10,)),
        'p2d': jnp.ones((10, 10)),
    }
    grads = {
        'p1d': jnp.full((10,), 0.1),
        'p2d': jnp.eye(10) * 0.1,
    }

    lr = 0.1
    beta = 0.95
    wd = 0.01

    muon_tx = muon.scale_by_muon(
        learning_rate=lr,
        beta=beta,
        weight_decay=wd,
    )
    rms_prop = optax.chain(
        optax.scale_by_rms(decay=beta, eps=1e-7),
        optax.add_decayed_weights(wd),
        optax.scale_by_learning_rate(lr, flip_sign=True),
    )
    # Route >=2D params to Muon and the rest to RMSProp via complementary
    # masks; the two masked transforms are disjoint so chaining is safe.
    muon_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim >= 2, p)
    rmsprop_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim < 2, p)
    tx = optax.chain(
        optax.masked(muon_tx, mask=muon_mask),
        optax.masked(rms_prop, mask=rmsprop_mask),
    )

    state = tx.init(params)
    updates, _ = tx.update(grads, state, params)

    # Expected RMSProp step for the 1D param (precomputed constant).
    expected_1d = -0.4482136
    self.assertTrue(jnp.allclose(updates['p1d'], expected_1d, rtol=1e-3))

    # The 2D param went through Muon, so it must NOT match the RMSProp value.
    self.assertFalse(jnp.allclose(updates['p2d'], -0.4482136, rtol=1e-1))

  def test_muon_get_optimizer(self):
    """Verifies get_optimizer('muon') matches manual construction."""
    params = {
        'p1d': jnp.ones((10,)),
        'p2d': jnp.ones((10, 10)),
    }
    grads = {
        'p1d': jnp.full((10,), 0.1),
        'p2d': jnp.eye(10) * 0.1,
    }

    lr = 0.1
    beta = 0.95
    wd = 0.01

    hps = config_dict.ConfigDict({
        'optimizer': 'muon',
        'l2_decay_factor': None,
        'batch_size': 8,
        'opt_hparams': {
            'muon_hparams': {
                'beta': beta,
                'weight_decay': wd,
            },
            'pair_optimizer': 'sgd',
            'pair_hparams': {
                'weight_decay': wd,
            },
        },
    })

    opt_init, opt_update = optimizers.get_optimizer(hps)
    state = opt_init(params)
    # Learning rate is injected per-step rather than baked into the hps.
    state = optimizers.inject_learning_rate(state, lr)
    updates, _ = opt_update(grads, state, params=params)

    # Manual reference: Muon on the 2D param, decoupled-WD SGD on the 1D one.
    muon_tx = muon.scale_by_muon(
        learning_rate=lr,
        beta=beta,
        weight_decay=wd,
    )
    sgd_tx = optax.chain(
        optax.add_decayed_weights(wd),
        optax.sgd(learning_rate=lr, momentum=None, nesterov=False),
    )
    muon_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim >= 2, p)
    pair_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim < 2, p)
    manual_tx = optax.chain(
        optax.masked(muon_tx, mask=muon_mask),
        optax.masked(sgd_tx, mask=pair_mask),
    )
    manual_state = manual_tx.init(params)
    manual_updates, _ = manual_tx.update(grads, manual_state, params)

    self.assertTrue(
        jnp.allclose(updates['p2d'], manual_updates['p2d'], atol=1e-5)
    )
    self.assertTrue(
        jnp.allclose(updates['p1d'], manual_updates['p1d'], atol=1e-5)
    )

  def test_muon_lr_multiplier(self):
    """Verifies lr_multiplier scales Muon's effective learning rate."""
    # NOTE(review): `scale_by_muon`'s visible signature has no `lr_multiplier`
    # parameter — the wrapper in optimizers.py pops it from muon_hparams
    # before calling scale_by_muon. Confirm scale_by_muon actually accepts
    # this kwarg, otherwise these two constructions raise TypeError.
    params = {'w': jnp.ones((10, 10))}
    grads = {'w': jnp.eye(10) * 0.1}

    lr = 0.1
    base = muon.scale_by_muon(learning_rate=lr, lr_multiplier=1.0)
    scaled = muon.scale_by_muon(learning_rate=lr, lr_multiplier=2.0)

    base_state = base.init(params)
    scaled_state = scaled.init(params)

    base_updates, _ = base.update(grads, base_state, params)
    scaled_updates, _ = scaled.update(grads, scaled_state, params)

    # Doubling the multiplier must exactly double the update.
    self.assertTrue(
        jnp.allclose(scaled_updates['w'], 2.0 * base_updates['w'], atol=1e-6)
    )

  def test_muon_3d_reshape(self):
    """Verifies Muon reshapes >2D params to (s0, -1) for orthogonalization."""
    params_3d = {'w': jnp.ones((4, 5, 6))}
    grads_3d = {'w': jax.random.normal(jax.random.PRNGKey(0), (4, 5, 6))}

    tx = muon.scale_by_muon(learning_rate=0.1)
    state = tx.init(params_3d)
    updates, _ = tx.update(grads_3d, state, params_3d)

    # The update must come back in the original (unflattened) shape.
    self.assertEqual(updates['w'].shape, (4, 5, 6))

  def test_muon_embed_mask(self):
    """Verifies default mask excludes embed params and 1D params."""
    params = {
        'embed': jnp.ones((100, 64)),
        'dense': jnp.ones((64, 64)),
        'bias': jnp.ones((64,)),
    }

    # Mirrors the default mask in optimizers.py: Muon only for >=2D params
    # whose path does not contain 'embed' (case-insensitive).
    muon_mask = lambda p: jax.tree_util.tree_map_with_path(
        lambda path, x: (
            x.ndim >= 2 and 'embed' not in jax.tree_util.keystr(path).lower()
        ),
        p,
    )
    mask = muon_mask(params)

    self.assertTrue(mask['dense'])
    self.assertFalse(mask['embed'])
    self.assertFalse(mask['bias'])


# Standard absltest entry point: run the tests when executed as a script.
if __name__ == '__main__':
  absltest.main()