mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 19:48:15 +08:00
Adding support for the Muon Optimizer (#1914)
* initial commit with workong optmimizer * update ACKNOWLEDGMENTS.md * nits and adding it to test * nits * G.astype(mx.bfloat16) to G.astype(G.dtype) * G.ndim >= 2 to assert G.ndim == 2 * remove coments * replace with mx.addmm * remove comments * format * nits * match muon * fix addmm --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -848,6 +848,106 @@ class Adafactor(Optimizer):
|
||||
return parameter - update
|
||||
|
||||
|
||||
class Muon(Optimizer):
|
||||
r"""The Muon optimizer.
|
||||
|
||||
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
|
||||
original implementation: `Muon: An optimizer for hidden layers in neural
|
||||
networks <https://kellerjordan.github.io/posts/muon/>`_
|
||||
|
||||
Note:
|
||||
- Muon may be sub-optimal for the embedding layer, the final fully
|
||||
connected layer, or any 0D/1D parameters. Those should be optimized
|
||||
by a different method (e.g., :class:`AdamW`).
|
||||
- For 4D convolutional filters, it works by flattening their last
|
||||
dimensions.
|
||||
|
||||
Args:
|
||||
learning_rate (float or callable): The learning rate.
|
||||
momentum (float, optional): The momentum strength. Default: ``0.95``
|
||||
weight_decay (float, optional): The weight decay (L2 penalty).
|
||||
Default: ``0.01``
|
||||
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
|
||||
better performance. Default: ``True``
|
||||
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
|
||||
orthogonalization. Default: ``5``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate: Union[float, Callable[[mx.array], mx.array]],
|
||||
momentum: float = 0.95,
|
||||
weight_decay: float = 0.01,
|
||||
nesterov: bool = True,
|
||||
ns_steps: int = 5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.momentum = momentum
|
||||
self.weight_decay = weight_decay
|
||||
self.nesterov = nesterov
|
||||
self.ns_steps = ns_steps
|
||||
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def _zeropower_via_newtonschulz5(self, X, steps: int):
|
||||
assert (
|
||||
X.ndim == 2
|
||||
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
transpose_needed = X.shape[-2] > X.shape[-1]
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
|
||||
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
|
||||
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
return X
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Muon parameter update"""
|
||||
|
||||
if self.weight_decay != 0:
|
||||
gradient = gradient + self.weight_decay * parameter
|
||||
|
||||
v = self.momentum * state["v"]
|
||||
v = v + (1 - self.momentum) * gradient
|
||||
state["v"] = v
|
||||
|
||||
if self.nesterov:
|
||||
update = gradient * (1 - self.momentum) + v * self.momentum
|
||||
else:
|
||||
update = v
|
||||
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
|
||||
if update.ndim >= 2:
|
||||
original_shape = update.shape
|
||||
reshape_needed = update.ndim > 2
|
||||
|
||||
if reshape_needed:
|
||||
update = mx.reshape(update, (update.shape[0], -1))
|
||||
|
||||
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
|
||||
|
||||
if reshape_needed:
|
||||
update = mx.reshape(update, original_shape)
|
||||
|
||||
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
|
||||
|
||||
return parameter - lr * update
|
||||
|
||||
|
||||
def clip_grad_norm(grads, max_norm):
|
||||
"""Clips the global norm of the gradients.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user