add Muon optimizer

Dhruv Srikanth 2025-05-10 16:17:24 +01:00
parent e8ff59451e
commit aef320bf84


@@ -865,19 +865,19 @@ def _zeropower_via_newtonschulz5(
     a, b, c = (3.4445, -4.7750, 2.0315)
     X = G.astype(mx.bfloat16)
     if G.shape[-2] > G.shape[-1]:
-        X = mx.transpose(X, (-2, -1))
+        X = X.T
     # Frobenius-norm normalisation (≈ spectral-norm ≤ 1)
     X = X / (mx.linalg.norm(X, ord="fro", axis=(-2, -1), keepdims=True) + eps)
     # Perform Newton-Schulz iteration
     for _ in range(steps):
-        A = X @ mx.transpose(X, (-2, -1))
+        A = X @ X.T
         B = b * A + c * A @ A
         X = a * X + B @ X
     if G.shape[-2] > G.shape[-1]:
-        X = mx.transpose(X, (-2, -1))
+        X = X.T
     return X
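
For context, the Newton-Schulz iteration above approximately orthogonalizes the momentum matrix: the tuned quintic coefficients push the singular values into a band around 1 rather than exactly to 1. A rough sanity check, assuming the helper above is in scope (it is module-private in this file) and the steps keyword inferred from this diff:

import mlx.core as mx

G = mx.random.normal(shape=(256, 64))
O = _zeropower_via_newtonschulz5(G, steps=5)          # runs in bfloat16 internally
gram = mx.transpose(O).astype(mx.float32) @ O.astype(mx.float32)
err = mx.abs(gram - mx.eye(64)).max()
print(O.shape, err)  # (256, 64); a noticeable but bounded deviation from the identity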
@@ -918,6 +918,7 @@ class Muon(Optimizer):
         self.weight_decay = weight_decay
         self.nesterov = nesterov
         self.ns_steps = ns_steps
+        self.backup_optimizer = AdamW(learning_rate=1e-3, weight_decay=weight_decay)
 
     def _flatten_if_conv(self, x: mx.array) -> mx.array:
         if x.ndim == 4:  # [out, in, kH, kW] → [out, -1]
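
One detail from this hunk worth calling out: the fallback AdamW is constructed with a hard-coded learning rate of 1e-3, so it does not follow the learning rate passed to Muon itself. A tiny illustration, with a constructor signature inferred from the attributes set in this diff (treat it as an assumption):

opt = Muon(learning_rate=0.02, momentum=0.95, weight_decay=0.01, nesterov=True, ns_steps=5)
print(opt.learning_rate)                   # 0.02
print(opt.backup_optimizer.learning_rate)  # 1e-3, fixed independently of Muon's lr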
@@ -925,11 +926,21 @@ class Muon(Optimizer):
         return x
 
     def init_single(self, parameter: mx.array, state: dict):
+        # If the parameter is not a 2D array, use the backup optimizer
+        if parameter.ndim != 2:
+            return self.backup_optimizer.init_single(parameter, state)
         state["B"] = mx.zeros_like(parameter)
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        # If the parameter is not a 2D array, use the backup optimizer
+        if "B" not in state:
+            return self.backup_optimizer.apply_single(
+                gradient=gradient, parameter=parameter, state=state
+            )
+        # Record the original dtype and shape
         original_dtype = gradient.dtype
         original_shape = gradient.shape
+        # Cast lr to the original dtype
         lr = self.learning_rate.astype(original_dtype)
         # Compute new buffer and store
         state["B"] = self.momentum * state["B"] + gradient
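
The two guards added here implement a simple dispatch: parameters that receive a "B" momentum buffer in init_single take the Muon path, while everything else (biases, norm scales, and any other non-2D parameter) is delegated to the backup AdamW. A minimal sketch of that dispatch, calling the per-parameter hooks directly and assuming the constructor signature mentioned above:

import mlx.core as mx

opt = Muon(learning_rate=0.02)        # assumed signature, see note above

w = mx.zeros((128, 64))               # 2D weight -> Muon path
b = mx.zeros((64,))                   # 1D bias   -> AdamW fallback

w_state, b_state = {}, {}
opt.init_single(w, w_state)           # sets w_state["B"]
opt.init_single(b, b_state)           # delegated to AdamW, which sets "m"/"v" instead
print("B" in w_state, "B" in b_state) # True False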
@@ -947,10 +958,11 @@ class Muon(Optimizer):
         # Reshape back to original shape
         gradient = mx.reshape(gradient, original_shape)
         # scale-invariant step size
-        scale = mx.sqrt(
-            mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1])
-        ).astype(original_dtype)
-        return parameter * (1 - lr * self.weight_decay) - lr * scale * gradient
+        scale = mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        return (
+            parameter * (1 - lr * self.weight_decay)
+            - lr * scale.astype(original_dtype) * gradient
+        )
 
 
 def clip_grad_norm(grads, max_norm):
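
Finally, a hypothetical end-to-end usage sketch. It assumes the class is exported as mlx.optimizers.Muon and takes the constructor arguments visible in this diff; the surrounding training-step calls (nn.value_and_grad, optimizer.update, mx.eval) are the standard MLX pattern:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
opt = optim.Muon(learning_rate=0.02, weight_decay=0.01)   # assumed export and signature

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

x = mx.random.normal(shape=(8, 64))
y = mx.random.randint(0, 10, (8,))

loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
opt.update(model, grads)   # 2D weights take the Muon step, biases fall back to AdamW
mx.eval(model.parameters(), opt.state)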