Awni Hannun
2025-07-17 06:26:43 -07:00
parent baad6e392b
commit 7f39e9c299
3 changed files with 50 additions and 47 deletions


@@ -19,3 +19,4 @@ Common Optimizers
    Adamax
    Lion
    MultiOptimizer
+   Muon
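
The docs hunk above registers the new Muon optimizer. A minimal end-to-end sketch of the API this commit exposes (the model, data, and loss below are illustrative, not part of the change):

# Sketch: one training step with the new Muon optimizer.
# Model, data, and hyperparameters are illustrative.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(64, 32)  # 2D weight, the case Muon is designed for

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((8, 64))
y = mx.random.normal((8, 32))

optimizer = optim.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)  # momentum step + Newton-Schulz orthogonalization
mx.eval(model.parameters(), optimizer.state)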


@@ -849,28 +849,28 @@ class Adafactor(Optimizer):
 class Muon(Optimizer):
-    r"""The Muon optimizer - MomentUm Orthogonalized by Newton-schulz.
+    r"""The Muon optimizer.
 
-    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
-    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
-    matrix. To efficiently orthogonalize each update, a Newton-Schulz iteration is used, which has
-    the advantage that it can be stably run in bfloat16 on the GPU.
-    For more details, see: https://kellerjordan.github.io/posts/muon/
+    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
+    original implementation: `Muon: An optimizer for hidden layers in neural
+    networks <https://kellerjordan.github.io/posts/muon/>`_
 
     Note:
-    - This optimizer may not be optimal for the embedding layer, the final fully connected layer,
-      or any 0D/1D parameters; those should be optimized by a standard method (e.g., AdamW).
-    - For 4D convolutional filters, it works by flattening their last dimensions.
+      - Muon may be sub-optimal for the embedding layer, the final fully
+        connected layer, or any 0D/1D parameters. Those should be optimized
+        by a different method (e.g., :class:`AdamW`).
+      - For 4D convolutional filters, it works by flattening their last
+        dimensions.
 
     Args:
-        learning_rate (float or callable): The learning rate used by the internal SGD.
+        learning_rate (float or callable): The learning rate.
         momentum (float, optional): The momentum strength. Default: ``0.95``
-        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0.01``
-        nesterov (bool, optional): Enables Nesterov momentum. Recommended for better performance.
-            Default: ``True``
-        ns_steps (int, optional): Number of Newton-Schulz iteration steps for orthogonalization.
-            Default: ``5``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+            Default: ``0.01``
+        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
+            better performance. Default: ``True``
+        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
+            orthogonalization. Default: ``5``
     """
 
     def __init__(
@@ -882,7 +882,7 @@ class Muon(Optimizer):
         ns_steps: int = 5,
     ):
         super().__init__()
         self._maybe_schedule("learning_rate", learning_rate)
         self.momentum = momentum
         self.weight_decay = weight_decay
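
A note on the constructor: `learning_rate` may be a float or a callable schedule, and `_maybe_schedule` above registers either. A short sketch using an MLX schedule (the decay horizon is illustrative):

import mlx.optimizers as optim

# Decay Muon's learning rate from 1e-2 toward 0 over 1000 steps
lr_schedule = optim.cosine_decay(1e-2, decay_steps=1000)
optimizer = optim.Muon(learning_rate=lr_schedule, momentum=0.95)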
@@ -894,55 +894,46 @@ class Muon(Optimizer):
         state["v"] = mx.zeros_like(parameter)
 
     def _zeropower_via_newtonschulz5(self, G, steps: int):
         """
         Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
         quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
         of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
         zero even beyond the point where the iteration no longer converges all the way to one everywhere
         on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
         where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
         performance at all relative to UV^T, where USV^T = G is the SVD.
         """
         assert G.ndim >= 2
         a, b, c = (3.4445, -4.7750, 2.0315)
         X = G.astype(mx.bfloat16)
         transpose_needed = G.shape[-2] > G.shape[-1]
         if transpose_needed:
             X = X.T
 
         # Ensure spectral norm is at most 1
         norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
         X = X / norm
 
         # Perform the NS iterations
         for _ in range(steps):
             A = X @ X.T
             B = b * A + c * (A @ A)
             X = a * X + B @ X
 
         if transpose_needed:
             X = X.T
         return X
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the Muon parameter update"""
         # Apply weight decay
         if self.weight_decay != 0:
             gradient = gradient + self.weight_decay * parameter
 
         # Update momentum buffer
         v = self.momentum * state["v"]
         v = v + (1 - self.momentum) * gradient
         state["v"] = v
 
         # Get effective gradient
         if self.nesterov:
             effective_grad = gradient * (1 - self.momentum) + v * self.momentum
         else:
             effective_grad = v
 
         # For tensors with fewer than 2 dimensions, skip Newton-Schulz
         if effective_grad.ndim < 2:
             orthogonalized_grad = effective_grad
@@ -951,22 +942,33 @@ class Muon(Optimizer):
             # Save original shape for 4D conv filters
             original_shape = effective_grad.shape
             reshape_needed = effective_grad.ndim > 2
 
             if reshape_needed:
-                effective_grad = mx.reshape(effective_grad, (effective_grad.shape[0], -1))
+                effective_grad = mx.reshape(
+                    effective_grad, (effective_grad.shape[0], -1)
+                )
 
             # Apply Newton-Schulz orthogonalization
-            orthogonalized_grad = self._zeropower_via_newtonschulz5(effective_grad, steps=self.ns_steps)
+            orthogonalized_grad = self._zeropower_via_newtonschulz5(
+                effective_grad, steps=self.ns_steps
+            )
 
             # Reshape back if needed
             if reshape_needed:
                 orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
 
         # Calculate scaling factor
-        # scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
-        scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
-        return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
+        scale_factor = (
+            max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
+        )
+        return (
+            parameter
+            - self.learning_rate.astype(gradient.dtype)
+            * orthogonalized_grad
+            * scale_factor
+        )
 
 
 def clip_grad_norm(grads, max_norm):
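
The `_zeropower_via_newtonschulz5` docstring notes that the quintic iteration yields roughly US'V^T with S'_{ii} ~ Uniform(0.5, 1.5), rather than an exact orthogonal factor. A standalone sanity check of the same iteration (same coefficients as the diff; float32 and a fixed shape here purely for the check, not part of the commit):

# Standalone check: after Newton-Schulz, the rows of X are near-orthonormal,
# so the diagonal of X @ X.T should sit near 1 (roughly within [0.25, 2.25]).
import mlx.core as mx

def newton_schulz(G, steps=5):
    a, b, c = (3.4445, -4.7750, 2.0315)  # quintic coefficients from the diff
    X = G / mx.sqrt(mx.sum(G * G) + 1e-7)  # Frobenius norm bounds the spectral norm
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X

X = newton_schulz(mx.random.normal((16, 64)))
print(mx.diag(X @ X.T))  # entries near 1, not exactly 1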

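The `scale_factor` in `apply_single` above is `sqrt(max(1, rows / cols))` of the (possibly flattened) 2D update, so tall matrices get a larger step while wide and square ones are left alone. A quick worked example (the shapes are illustrative):

# Worked example of Muon's shape-dependent scaling
for rows, cols in [(1024, 64), (64, 1024), (256, 256)]:
    print((rows, cols), max(1, rows / cols) ** 0.5)
# (1024, 64) -> 4.0  tall: update scaled up
# (64, 1024) -> 1.0  wide: unscaled
# (256, 256) -> 1.0  square: unscaled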

@@ -307,7 +307,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         # Test update
         updated_params = optim.apply_gradients(grads, params)
 
         # Check that shapes are preserved
         self.assertTrue(
             tree_equal(
@@ -316,7 +316,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
                 updated_params,
             )
         )
 
         # Check that parameters actually changed
         self.assertFalse(
             tree_equal(
@@ -325,11 +325,11 @@ class TestOptimizers(mlx_tests.MLXTestCase):
                 updated_params,
             )
         )
 
         # Test with different configurations
         optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
         optim_no_nesterov.apply_gradients(grads, params)
 
         optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
         optim_no_momentum.apply_gradients(grads, params)
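
Per the docstring's note, 0D/1D parameters and embeddings are better served by a different optimizer such as AdamW. One way to wire that up is MultiOptimizer, already listed in the docs hunk above; this sketch assumes its filter callables receive the parameter path and value, so verify against the current MLX API before relying on it:

# Hedged sketch: Muon for >=2D weights, AdamW for everything else,
# via MultiOptimizer. The filter signature is an assumption; check the
# MLX docs for the exact MultiOptimizer API.
import mlx.optimizers as optim

optimizer = optim.MultiOptimizer(
    [optim.Muon(learning_rate=1e-2), optim.AdamW(learning_rate=3e-4)],
    [lambda path, weight: weight.ndim >= 2],  # remaining params fall through to AdamW
)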