Add muon optimizer

parent 5a1a5d5ed1
commit e8ff59451e
@@ -846,6 +846,113 @@ class Adafactor(Optimizer):
        return parameter - update

@mx.compile
def _zeropower_via_newtonschulz5(
    G: mx.array, steps: int, eps: float = 1e-7
) -> mx.array:
    """Approximate the orthogonal factor U Vᵀ of the SVD G = U S Vᵀ using
    a quintic Newton-Schulz iteration (see https://kellerjordan.github.io/posts/muon/).

    Args:
        G (mx.array): 2-D (or batched ≥ 2-D) array of any floating dtype.
        steps (int): Number of Newton-Schulz iterations.
        eps (float): A small value to avoid division by zero.

    Returns:
        An array with the same shape as ``G``, computed in bfloat16.
    """
    assert G.ndim >= 2, "G must be at least 2-D"
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.astype(mx.bfloat16)
    # Work on the wide orientation so the Gram matrix X @ Xᵀ stays small
    if G.shape[-2] > G.shape[-1]:
        X = mx.swapaxes(X, -2, -1)

    # Frobenius-norm normalisation (guarantees spectral norm ≤ 1)
    X = X / (mx.linalg.norm(X, ord="fro", axis=(-2, -1), keepdims=True) + eps)

    # Quintic Newton-Schulz iteration: X ← a·X + b·(XXᵀ)X + c·(XXᵀ)²X
    for _ in range(steps):
        A = X @ mx.swapaxes(X, -2, -1)
        B = b * A + c * A @ A
        X = a * X + B @ X

    # Undo the transpose for tall inputs
    if G.shape[-2] > G.shape[-1]:
        X = mx.swapaxes(X, -2, -1)

    return X
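
As a quick numerical sanity check (not part of the commit): these coefficients trade exactness for speed, so the iteration pushes the singular values into a narrow band around 1 rather than exactly to 1, which is all Muon needs. A minimal sketch, assuming the function above is in scope and that ``mx.linalg.svd`` runs on the CPU stream (it is CPU-only in MLX at the time of writing):

    import mlx.core as mx

    G = mx.random.normal((16, 32))
    # Cast back to float32 since SVD does not support bfloat16
    X = _zeropower_via_newtonschulz5(G, 5).astype(mx.float32)
    _, S, _ = mx.linalg.svd(X, stream=mx.cpu)
    print(S.min().item(), S.max().item())  # both near 1, not exactly 1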

class Muon(Optimizer):
    r"""Muon - MomentUm Orthogonalized by Newton-Schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an
    orthogonalization post-processing step, in which each 2-D parameter's
    update is replaced with the nearest orthogonal matrix. To orthogonalize
    each update efficiently, we use a Newton-Schulz iteration, which has the
    advantage that it can be run stably in bfloat16 on the GPU.

    Some warnings:

    - This optimizer should not be used for the embedding layer, the final
      fully connected layer, or any {0,1}-D parameters; those should all be
      optimized by a standard method (e.g., AdamW).
    - To use it with 4-D convolutional filters, it works well to just
      flatten their last three dimensions.

    Args:
        learning_rate (float or callable): The learning rate used by the
          internal SGD. Default: ``0.02``
        weight_decay (float): The decoupled weight decay. Default: ``0.01``
        momentum (float): The momentum used by the internal SGD.
          Default: ``0.95``
        nesterov (bool): Whether to use Nesterov-style momentum in the
          internal SGD (recommended). Default: ``True``
        ns_steps (int): The number of Newton-Schulz iteration steps to use.
          Default: ``5``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]] = 0.02,
        weight_decay: float = 0.01,
        momentum: float = 0.95,
        nesterov: bool = True,
        ns_steps: int = 5,
    ):
        super().__init__()
        self._maybe_schedule("learning_rate", learning_rate)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.nesterov = nesterov
        self.ns_steps = ns_steps

    def _flatten_if_conv(self, x: mx.array) -> mx.array:
        # Collapse 4-D conv kernels [out, in, kH, kW] -> [out, in * kH * kW]
        if x.ndim == 4:
            return mx.reshape(x, (x.shape[0], -1))
        return x

    def init_single(self, parameter: mx.array, state: dict):
        state["B"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        original_dtype = gradient.dtype
        original_shape = gradient.shape
        lr = self.learning_rate.astype(original_dtype)

        # Update the momentum buffer and store it
        state["B"] = self.momentum * state["B"] + gradient

        # Form the update, with a Nesterov look-ahead if requested
        gradient = (
            gradient + self.momentum * state["B"] if self.nesterov else state["B"]
        )

        # {0,1}-D parameters fall back to standard SGD with momentum
        if gradient.ndim <= 1:
            return parameter * (1 - lr * self.weight_decay) - lr * gradient

        # Flatten conv kernels to 2-D
        gradient = self._flatten_if_conv(gradient)

        # Newton-Schulz orthogonalisation (runs in bfloat16)
        gradient = _zeropower_via_newtonschulz5(gradient, steps=self.ns_steps)

        # Restore the original shape and dtype
        gradient = mx.reshape(gradient, original_shape).astype(original_dtype)

        # Scale-invariant step size for rectangular matrices
        scale = mx.sqrt(
            mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1])
        ).astype(original_dtype)
        return parameter * (1 - lr * self.weight_decay) - lr * scale * gradient
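
For context, a minimal usage sketch (not part of the commit; the model, data, and ``loss_fn`` names are hypothetical). It relies only on the standard MLX training API, and for brevity applies Muon to every layer, which the docstring advises against for the final layer of a real network:

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.optimizers import Muon

    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

    def loss_fn(model, X, y):
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

    optimizer = Muon(learning_rate=0.02, weight_decay=0.01, momentum=0.95)
    X = mx.random.normal((8, 64))
    y = mx.random.randint(0, 10, (8,))
    loss, grads = nn.value_and_grad(model, loss_fn)(model, X, y)
    optimizer.update(model, grads)  # orthogonalized updates for 2-D weights
    mx.eval(model.parameters(), optimizer.state)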
def clip_grad_norm(grads, max_norm):
    """Clips the global norm of the gradients.
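
The ``clip_grad_norm`` helper shown in the surrounding context composes naturally with Muon. Continuing the sketch above, and assuming it returns the clipped gradient tree together with the global norm:

    from mlx.optimizers import clip_grad_norm

    def train_step(model, optimizer, X, y):
        loss, grads = nn.value_and_grad(model, loss_fn)(model, X, y)
        # Assumes clip_grad_norm returns (clipped_grads, total_norm)
        grads, total_norm = clip_grad_norm(grads, max_norm=1.0)
        optimizer.update(model, grads)
        return loss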