Awni Hannun
2025-07-17 11:58:41 -07:00
parent c535d8c1b5
commit 0a8bb904d7


@@ -893,18 +893,17 @@ class Muon(Optimizer):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def _zeropower_via_newtonschulz5(self, G, steps: int):
def _zeropower_via_newtonschulz5(self, X, steps: int):
assert (
G.ndim == 2
), f"Expected a 2D matrix for Newton-Schulz iteration, got shape {G.shape} instead."
X.ndim == 2
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.astype(G.dtype)
transpose_needed = G.shape[-2] > G.shape[-1]
transpose_needed = X.shape[-2] > X.shape[-1]
if transpose_needed:
X = X.T
norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
X = X / norm
for _ in range(steps):
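For context, a minimal standalone sketch of the helper as it reads after this change. The loop body is not shown in this hunk, so the standard Muon quintic iteration is assumed:

import mlx.core as mx

def _zeropower_via_newtonschulz5(X, steps: int):
    # Quintic Newton-Schulz iteration that pushes the singular values of X
    # toward 1, i.e. an approximate orthogonalization of the update matrix.
    assert X.ndim == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    transpose_needed = X.shape[-2] > X.shape[-1]
    if transpose_needed:
        X = X.T
    # For a 2D array, summing without an axis argument equals summing over
    # axes (-2, -1), so dropping the axis keyword leaves the norm unchanged.
    norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
    X = X / norm
    for _ in range(steps):  # loop body assumed: the standard Muon quintic
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transpose_needed:
        X = X.T
    return X

Since the input is asserted to be 2D, removing both the axis argument and the redundant astype copy changes no values; it only simplifies the code.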
@@ -927,39 +926,27 @@ class Muon(Optimizer):
state["v"] = v
if self.nesterov:
effective_grad = gradient * (1 - self.momentum) + v * self.momentum
update = gradient * (1 - self.momentum) + v * self.momentum
else:
effective_grad = v
update = v
if effective_grad.ndim < 2:
orthogonalized_grad = effective_grad
scale_factor = 1.0
else:
original_shape = effective_grad.shape
reshape_needed = effective_grad.ndim > 2
lr = self.learning_rate.astype(gradient.dtype)
if update.ndim >= 2:
original_shape = update.shape
reshape_needed = update.ndim > 2
if reshape_needed:
effective_grad = mx.reshape(
effective_grad, (effective_grad.shape[0], -1)
)
update = mx.reshape(update, (update.shape[0], -1))
orthogonalized_grad = self._zeropower_via_newtonschulz5(
effective_grad, steps=self.ns_steps
)
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
if reshape_needed:
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
update = mx.reshape(update, original_shape)
scale_factor = (
max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
)
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
return (
parameter
- self.learning_rate.astype(gradient.dtype)
* orthogonalized_grad
* scale_factor
)
return parameter - lr * update
def clip_grad_norm(grads, max_norm):
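The second hunk folds the sqrt(max(1, rows / cols)) shape scaling into the learning rate instead of keeping a separate scale_factor, and skips orthogonalization for parameters with fewer than two dimensions by leaving update as the plain momentum term. A hedged sketch of the resulting rule for a 2D weight, where W, grad, v, learning_rate, momentum, nesterov, and ns_steps are illustrative stand-ins for the optimizer's parameters and state:

# Sketch only: names are stand-ins, not the class's attributes.
update = grad * (1 - momentum) + v * momentum if nesterov else v
update = _zeropower_via_newtonschulz5(update, steps=ns_steps)
# the shape factor now scales the learning rate rather than the update
lr = learning_rate * max(1, W.shape[-2] / W.shape[-1]) ** 0.5
W_new = W - lr * update  # same value as lr * orthogonalized_grad * scale_factor before

The returned parameter is unchanged numerically; the rewrite only renames variables and moves the scaling onto lr.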