Add muon optimizer

Dhruv Srikanth 2025-05-07 03:29:44 +01:00
parent 5a1a5d5ed1
commit e8ff59451e


@@ -846,6 +846,113 @@ class Adafactor(Optimizer):
        return parameter - update
@mx.compile
def _zeropower_via_newtonschulz5(
G: mx.array, steps: int, eps: float = 1e-7
) -> mx.array:
"""Approximate the orthogonal factor U Vᵀ of the SVD G = U S Vᵀ using
a quintic Newton-Schulz iteration (see https://kellerjordan.github.io/posts/muon/).
Args:
G (mx.array): 2-D (or batched 2-D) array of any floating dtype.
steps (int): Number of Newton-Schulz iterations.
eps (float): A small value to avoid division by zero.
Returns:
An array with the same shape/dtype as `G`.
"""
assert G.ndim >= 2, "G must be at least 2-D"
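    # Quintic coefficients from the Muon write-up, tuned to inflate small
    # singular values quickly at the cost of exact convergence to 1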
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.astype(mx.bfloat16)
    if G.shape[-2] > G.shape[-1]:
        # Work on the wide orientation; transposed back after the iteration
        X = mx.swapaxes(X, -2, -1)
    # Normalize by the Frobenius norm (an upper bound on the spectral norm)
    # so that the spectral norm of X is at most 1
    X = X / (mx.linalg.norm(X, ord="fro", axis=(-2, -1), keepdims=True) + eps)
# Perform Newton-Schulz iteration
for _ in range(steps):
        A = X @ mx.swapaxes(X, -2, -1)
B = b * A + c * A @ A
X = a * X + B @ X
    if G.shape[-2] > G.shape[-1]:
        X = mx.swapaxes(X, -2, -1)
return X
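# Illustrative sketch (not part of this commit): for a tall input, the result
# should have approximately orthonormal columns, up to the loose tolerance of
# the quintic iteration and bfloat16 arithmetic:
#
#   w = mx.random.normal(shape=(256, 128))
#   q = _zeropower_via_newtonschulz5(w, steps=5)
#   mx.abs(q.T @ q - mx.eye(128)).max()  # small, but not exactly zero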
class Muon(Optimizer):
    r"""Muon - MomentUm Orthogonalized by Newton-Schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD with momentum and then performs an
    orthogonalization post-processing step, in which each 2-D parameter's
    update is replaced with the nearest orthogonal matrix. The updates are
    orthogonalized efficiently with a Newton-Schulz iteration, which has the
    advantage that it can be run stably in bfloat16 on the GPU.

    Some warnings:

    - This optimizer should not be used for the embedding layer, the final
      fully connected layer, or any {0,1}-D parameters; those should all be
      optimized by a standard method (e.g., AdamW).
    - To use it with 4-D convolutional filters, it works well to just flatten
      their last 3 dimensions.

    Args:
        learning_rate (float or callable): The learning rate used by the
          internal SGD. Default: ``0.02``
        weight_decay (float): The weight decay (L2 penalty). Default: ``0.01``
        momentum (float): The momentum used by the internal SGD.
          Default: ``0.95``
        nesterov (bool): Whether to use Nesterov-style momentum in the
          internal SGD (recommended). Default: ``True``
        ns_steps (int): The number of Newton-Schulz iteration steps to use.
          Default: ``5``
    """
def __init__(
self,
learning_rate: Union[float, Callable[[mx.array], mx.array]] = 0.02,
weight_decay: float = 0.01,
momentum: float = 0.95,
nesterov: bool = True,
ns_steps: int = 5,
):
super().__init__()
self._maybe_schedule("learning_rate", learning_rate)
self.momentum = momentum
self.weight_decay = weight_decay
self.nesterov = nesterov
self.ns_steps = ns_steps
def _flatten_if_conv(self, x: mx.array) -> mx.array:
if x.ndim == 4: # [out, in, kH, kW] → [out, -1]
return mx.reshape(x, (x.shape[0], -1))
return x
def init_single(self, parameter: mx.array, state: dict):
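        # Momentum buffer for the internal SGD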
state["B"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
original_dtype = gradient.dtype
original_shape = gradient.shape
lr = self.learning_rate.astype(original_dtype)
        # Update the momentum buffer
        state["B"] = self.momentum * state["B"] + gradient
        # Nesterov momentum looks ahead along the updated buffer;
        # otherwise use the buffer directly
        gradient = (
            gradient + self.momentum * state["B"] if self.nesterov else state["B"]
        )
        # Fall back to standard SGD with momentum (and decoupled weight decay)
        # for {0,1}-D parameters
        if gradient.ndim <= 1:
            return parameter * (1 - lr * self.weight_decay) - lr * gradient
        # Flatten conv kernels to 2-D
        gradient = self._flatten_if_conv(gradient)
        # Newton-Schulz orthogonalization (runs in bfloat16)
        gradient = _zeropower_via_newtonschulz5(gradient, steps=self.ns_steps)
        # Scale-invariant step size: keep the update RMS roughly independent
        # of the matrix shape, computed on the (possibly flattened) 2-D shape
        scale = mx.sqrt(
            mx.maximum(1.0, gradient.shape[-2] / gradient.shape[-1])
        ).astype(original_dtype)
        # Reshape back to the original shape and restore the original dtype
        gradient = mx.reshape(gradient, original_shape).astype(original_dtype)
        return parameter * (1 - lr * self.weight_decay) - lr * scale * gradient
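# Illustrative usage (not part of this diff), following the docstring's advice
# to reserve Muon for 2-D hidden weights; a minimal end-to-end sketch:
#
#   import mlx.core as mx
#   import mlx.nn as nn
#
#   model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
#   optimizer = Muon(learning_rate=0.02, momentum=0.95, nesterov=True)
#
#   def loss_fn(m, x, y):
#       return nn.losses.mse_loss(m(x), y)
#
#   loss_and_grad = nn.value_and_grad(model, loss_fn)
#   loss, grads = loss_and_grad(model, x, y)
#   optimizer.update(model, grads)  # 1-D biases fall back to SGD internally
#   mx.eval(model.parameters(), optimizer.state)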
def clip_grad_norm(grads, max_norm):
    """Clips the global norm of the gradients.