nits
@@ -19,3 +19,4 @@ Common Optimizers
    Adamax
    Lion
    MultiOptimizer
+   Muon
@@ -849,28 +849,28 @@ class Adafactor(Optimizer):


 class Muon(Optimizer):
-    r"""The Muon optimizer - MomentUm Orthogonalized by Newton-schulz.
+    r"""The Muon optimizer.

-    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
-    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
-    matrix. To efficiently orthogonalize each update, a Newton-Schulz iteration is used, which has
-    the advantage that it can be stably run in bfloat16 on the GPU.
-
-    For more details, see: https://kellerjordan.github.io/posts/muon/
+    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
+    original implementation: `Muon: An optimizer for hidden layers in neural
+    networks <https://kellerjordan.github.io/posts/muon/>`_

     Note:
-      - This optimizer may not be optimal for the embedding layer, the final fully connected layer,
-        or any 0D/1D parameters; those should be optimized by a standard method (e.g., AdamW).
-      - For 4D convolutional filters, it works by flattening their last dimensions.
+      - Muon may be sub-optimal for the embedding layer, the final fully
+        connected layer, or any 0D/1D parameters. Those should be optimized
+        by a different method (e.g., :class:`AdamW`).
+      - For 4D convolutional filters, it works by flattening their last
+        dimensions.

     Args:
-        learning_rate (float or callable): The learning rate used by the internal SGD.
+        learning_rate (float or callable): The learning rate.
         momentum (float, optional): The momentum strength. Default: ``0.95``
-        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0.01``
-        nesterov (bool, optional): Enables Nesterov momentum. Recommended for better performance.
-            Default: ``True``
-        ns_steps (int, optional): Number of Newton-Schulz iteration steps for orthogonalization.
-            Default: ``5``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+            Default: ``0.01``
+        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
+            better performance. Default: ``True``
+        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
+            orthogonalization. Default: ``5``
     """

     def __init__(
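As a quick orientation for the docstring changes above, here is a minimal usage sketch built only from the documented arguments; the model, data, and shapes are hypothetical, and the bias-free layer is chosen because the Note recommends handling 0D/1D parameters with a different optimizer:

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    # A single 2D weight, no bias: the Note above suggests leaving
    # 0D/1D parameters to another optimizer such as AdamW.
    model = nn.Linear(256, 10, bias=False)

    # Arguments and defaults exactly as documented in the docstring.
    optimizer = optim.Muon(
        learning_rate=0.02,
        momentum=0.95,
        weight_decay=0.01,
        nesterov=True,
        ns_steps=5,
    )

    def loss_fn(model, x, y):
        return nn.losses.cross_entropy(model(x), y, reduction="mean")

    loss_and_grad = nn.value_and_grad(model, loss_fn)
    x, y = mx.random.normal((32, 256)), mx.random.randint(0, 10, (32,))
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)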
@@ -894,15 +894,6 @@ class Muon(Optimizer):
         state["v"] = mx.zeros_like(parameter)

     def _zeropower_via_newtonschulz5(self, G, steps: int):
-        """
-        Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-        quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-        of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-        zero even beyond the point where the iteration no longer converges all the way to one everywhere
-        on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-        where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-        performance at all relative to UV^T, where USV^T = G is the SVD.
-        """
         assert G.ndim >= 2
         a, b, c = (3.4445, -4.7750, 2.0315)
         X = G.astype(mx.bfloat16)
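The deleted docstring was the only in-code explanation of the quintic iteration, so a standalone sketch may be useful. The coefficients are the ones visible in the diff; the transpose handling and Frobenius pre-normalization follow the reference Muon write-up and are assumptions about details not shown here:

    import mlx.core as mx

    def newton_schulz5(G, steps=5, eps=1e-7):
        # Repeatedly applies the odd polynomial
        #   p(X) = a*X + b*(X X^T) X + c*(X X^T)^2 X
        # whose coefficients maximize the slope at zero, pushing every
        # singular value of G toward ~1 (i.e. toward US'V^T with S' near
        # the identity) without an explicit, slow SVD.
        assert G.ndim == 2  # this sketch covers plain matrices only
        a, b, c = (3.4445, -4.7750, 2.0315)
        transpose = G.shape[-2] > G.shape[-1]
        X = G.T if transpose else G  # iterate on the wide orientation
        # Frobenius normalization bounds the spectral norm by 1, which
        # keeps the iteration stable even in bfloat16.
        X = X / (mx.linalg.norm(X) + eps)
        X = X.astype(mx.bfloat16)
        for _ in range(steps):
            A = X @ X.T
            X = a * X + (b * A + c * (A @ A)) @ X
        return X.T if transpose else X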
@@ -953,10 +944,14 @@ class Muon(Optimizer):
         reshape_needed = effective_grad.ndim > 2

         if reshape_needed:
-            effective_grad = mx.reshape(effective_grad, (effective_grad.shape[0], -1))
+            effective_grad = mx.reshape(
+                effective_grad, (effective_grad.shape[0], -1)
+            )

         # Apply Newton-Schulz orthogonalization
-        orthogonalized_grad = self._zeropower_via_newtonschulz5(effective_grad, steps=self.ns_steps)
+        orthogonalized_grad = self._zeropower_via_newtonschulz5(
+            effective_grad, steps=self.ns_steps
+        )

         # Reshape back if needed
         if reshape_needed:
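For context, the reformatted calls above are where 4D convolutional filters get flattened into matrices before orthogonalization, per the docstring's Note. A tiny round trip with hypothetical shapes:

    import mlx.core as mx

    W = mx.random.normal((64, 3, 3, 32))   # hypothetical conv filter
    G = mx.reshape(W, (W.shape[0], -1))    # (64, 288): one row per filter
    # ... orthogonalize G, e.g. with the newton_schulz5 sketch above ...
    W_restored = mx.reshape(G, W.shape)    # back to (64, 3, 3, 32)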
@@ -964,9 +959,16 @@ class Muon(Optimizer):

         # Calculate scaling factor
         # scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
-        scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
+        scale_factor = (
+            max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
+        )

-        return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
+        return (
+            parameter
+            - self.learning_rate.astype(gradient.dtype)
+            * orthogonalized_grad
+            * scale_factor
+        )


 def clip_grad_norm(grads, max_norm):
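The wrapped return expression is the whole update: an SGD-style step with the orthogonalized gradient, rescaled by sqrt(max(1, rows / cols)) so tall matrices keep a consistent update magnitude. A worked instance of the scale factor with hypothetical shapes:

    tall = max(1, 1024 / 256) ** 0.5  # tall (1024, 256) weight -> 2.0
    wide = max(1, 256 / 1024) ** 0.5  # wide (256, 1024) weight -> 1.0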