commit 7f39e9c299
parent baad6e392b
Author: Awni Hannun
Date: 2025-07-17 06:26:43 -07:00
3 changed files with 50 additions and 47 deletions


@@ -19,3 +19,4 @@ Common Optimizers
    Adamax
    Lion
    MultiOptimizer
+   Muon


@@ -849,28 +849,28 @@ class Adafactor(Optimizer):
 class Muon(Optimizer):
-    r"""The Muon optimizer - MomentUm Orthogonalized by Newton-schulz.
-
-    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
-    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
-    matrix. To efficiently orthogonalize each update, a Newton-Schulz iteration is used, which has
-    the advantage that it can be stably run in bfloat16 on the GPU.
-
-    For more details, see: https://kellerjordan.github.io/posts/muon/
+    r"""The Muon optimizer.
+
+    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
+    original implementation: `Muon: An optimizer for hidden layers in neural
+    networks <https://kellerjordan.github.io/posts/muon/>`_

     Note:
-        - This optimizer may not be optimal for the embedding layer, the final fully connected layer,
-          or any 0D/1D parameters; those should be optimized by a standard method (e.g., AdamW).
-        - For 4D convolutional filters, it works by flattening their last dimensions.
+        - Muon may be sub-optimal for the embedding layer, the final fully
+          connected layer, or any 0D/1D parameters. Those should be optimized
+          by a different method (e.g., :class:`AdamW`).
+        - For 4D convolutional filters, it works by flattening their last
+          dimensions.

     Args:
-        learning_rate (float or callable): The learning rate used by the internal SGD.
+        learning_rate (float or callable): The learning rate.
         momentum (float, optional): The momentum strength. Default: ``0.95``
-        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0.01``
-        nesterov (bool, optional): Enables Nesterov momentum. Recommended for better performance.
-            Default: ``True``
-        ns_steps (int, optional): Number of Newton-Schulz iteration steps for orthogonalization.
-            Default: ``5``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+            Default: ``0.01``
+        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
+            better performance. Default: ``True``
+        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
+            orthogonalization. Default: ``5``
     """

     def __init__(
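
For reference, a minimal usage sketch based on the arguments documented in the hunk above. The learning-rate value and the training-loop names (model, grads) are illustrative, not taken from this commit:

    import mlx.optimizers as optim

    # Hyperparameter values here are illustrative; momentum, weight_decay,
    # nesterov, and ns_steps shown below match the documented defaults.
    muon = optim.Muon(
        learning_rate=0.02,
        momentum=0.95,
        weight_decay=0.01,
        nesterov=True,
        ns_steps=5,
    )

    # Per the Note above, embeddings, the final linear layer, and 0D/1D
    # parameters are better handled by e.g. AdamW; MultiOptimizer (added to
    # the docs list in the first hunk) can route parameters accordingly.
    # Training then follows the usual MLX pattern:
    #     muon.update(model, grads)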
@@ -894,15 +894,6 @@ class Muon(Optimizer):
state["v"] = mx.zeros_like(parameter) state["v"] = mx.zeros_like(parameter)
def _zeropower_via_newtonschulz5(self, G, steps: int): def _zeropower_via_newtonschulz5(self, G, steps: int):
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 assert G.ndim >= 2
a, b, c = (3.4445, -4.7750, 2.0315) a, b, c = (3.4445, -4.7750, 2.0315)
X = G.astype(mx.bfloat16) X = G.astype(mx.bfloat16)
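
For context, here is a self-contained sketch of the quintic Newton-Schulz orthogonalization that the removed docstring describes. The actual method is _zeropower_via_newtonschulz5 (only its first lines are visible in this diff); this standalone version assumes a 2D input and may differ from the committed code in detail:

    import mlx.core as mx

    def newton_schulz_orthogonalize(G, steps=5):
        # Quintic iteration whose coefficients maximize the slope at zero,
        # run in bfloat16; returns an approximately orthogonal matrix.
        assert G.ndim == 2
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G.astype(mx.bfloat16)
        transposed = X.shape[0] > X.shape[1]
        if transposed:
            X = mx.swapaxes(X, 0, 1)  # work in the wide orientation
        # Frobenius norm upper-bounds the spectral norm, so this keeps ||X|| <= 1.
        X = X / (mx.sqrt(mx.sum(X * X)) + 1e-7)
        for _ in range(steps):
            A = X @ mx.swapaxes(X, 0, 1)
            B = b * A + c * (A @ A)
            X = a * X + B @ X
        if transposed:
            X = mx.swapaxes(X, 0, 1)
        return X

The transpose keeps the Gram matrix A = X X^T on the smaller side of the rectangle; callers typically cast the result back to the gradient's original dtype.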
@@ -953,10 +944,14 @@ class Muon(Optimizer):
         reshape_needed = effective_grad.ndim > 2
         if reshape_needed:
-            effective_grad = mx.reshape(effective_grad, (effective_grad.shape[0], -1))
+            effective_grad = mx.reshape(
+                effective_grad, (effective_grad.shape[0], -1)
+            )

         # Apply Newton-Schulz orthogonalization
-        orthogonalized_grad = self._zeropower_via_newtonschulz5(effective_grad, steps=self.ns_steps)
+        orthogonalized_grad = self._zeropower_via_newtonschulz5(
+            effective_grad, steps=self.ns_steps
+        )

         # Reshape back if needed
         if reshape_needed:
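
The reshape in this hunk implements the 4D-filter note from the docstring: only the trailing dimensions are flattened. A shape illustration with a hypothetical convolution gradient:

    import mlx.core as mx

    g = mx.zeros((64, 3, 3, 32))              # hypothetical 4D conv gradient
    flat = mx.reshape(g, (g.shape[0], -1))    # -> (64, 288), now a 2D matrix
    # ... orthogonalize `flat` with Newton-Schulz ...
    restored = mx.reshape(flat, g.shape)      # back to (64, 3, 3, 32)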
@@ -964,9 +959,16 @@ class Muon(Optimizer):
         # Calculate scaling factor
         # scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
-        scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
+        scale_factor = (
+            max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
+        )

-        return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
+        return (
+            parameter
+            - self.learning_rate.astype(gradient.dtype)
+            * orthogonalized_grad
+            * scale_factor
+        )


 def clip_grad_norm(grads, max_norm):
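
A worked example of the scale factor in the hunk above, using hypothetical weight shapes; tall matrices get a boost while the max() clamps wide ones to 1:

    # scale_factor = max(1, rows / cols) ** 0.5
    tall = (4096, 1024)
    wide = (1024, 4096)
    print(max(1, tall[0] / tall[1]) ** 0.5)   # 2.0
    print(max(1, wide[0] / wide[1]) ** 0.5)   # 1.0

The parameter then moves against the orthogonalized update scaled by this factor and the learning rate, as shown in the reformatted return expression.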