mirror of https://github.com/ml-explore/mlx.git
nits
@@ -893,18 +893,17 @@ class Muon(Optimizer):
         """Initialize optimizer state"""
         state["v"] = mx.zeros_like(parameter)
 
-    def _zeropower_via_newtonschulz5(self, G, steps: int):
+    def _zeropower_via_newtonschulz5(self, X, steps: int):
         assert (
-            G.ndim == 2
-        ), f"Expected a 2D matrix for Newton-Schulz iteration, got shape {G.shape} instead."
+            X.ndim == 2
+        ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
         a, b, c = (3.4445, -4.7750, 2.0315)
-        X = G.astype(G.dtype)
-        transpose_needed = G.shape[-2] > G.shape[-1]
+        transpose_needed = X.shape[-2] > X.shape[-1]
 
         if transpose_needed:
             X = X.T
 
-        norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
+        norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
         X = X / norm
 
         for _ in range(steps):
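Since the hunk elides the loop body, here is a minimal sketch of how the renamed helper reads after this change, written as a free function so it runs standalone. The quintic step inside the loop is not shown in the diff; the body below assumes the standard Muon Newton-Schulz iteration using the coefficients (a, b, c) from the hunk.

import mlx.core as mx

def zeropower_via_newtonschulz5(X, steps: int):
    # Approximate the nearest semi-orthogonal matrix to X by driving its
    # singular values toward 1 with a quintic Newton-Schulz iteration.
    assert (
        X.ndim == 2
    ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
    a, b, c = (3.4445, -4.7750, 2.0315)
    transpose_needed = X.shape[-2] > X.shape[-1]

    # Keep X wide so X @ X.T is the smaller Gram matrix.
    if transpose_needed:
        X = X.T

    # Scale so the spectral norm is at most ~1; the iteration only
    # converges for inputs inside the unit spectral ball.
    norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
    X = X / norm

    for _ in range(steps):
        # Assumed standard Muon quintic step (the hunk elides the loop body):
        # X <- a*X + (b*A + c*A@A) @ X  with  A = X @ X.T
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    if transpose_needed:
        X = X.T
    return X

A handful of steps (the diff below uses self.ns_steps) is enough in practice to push the singular values close to 1, yielding an approximately semi-orthogonal matrix of the same shape as the input.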
@@ -927,39 +926,27 @@ class Muon(Optimizer):
         state["v"] = v
 
         if self.nesterov:
-            effective_grad = gradient * (1 - self.momentum) + v * self.momentum
+            update = gradient * (1 - self.momentum) + v * self.momentum
         else:
-            effective_grad = v
+            update = v
 
-        if effective_grad.ndim < 2:
-            orthogonalized_grad = effective_grad
-            scale_factor = 1.0
-        else:
-            original_shape = effective_grad.shape
-            reshape_needed = effective_grad.ndim > 2
+        lr = self.learning_rate.astype(gradient.dtype)
+
+        if update.ndim >= 2:
+            original_shape = update.shape
+            reshape_needed = update.ndim > 2
 
             if reshape_needed:
-                effective_grad = mx.reshape(
-                    effective_grad, (effective_grad.shape[0], -1)
-                )
+                update = mx.reshape(update, (update.shape[0], -1))
 
-            orthogonalized_grad = self._zeropower_via_newtonschulz5(
-                effective_grad, steps=self.ns_steps
-            )
+            update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
 
             if reshape_needed:
-                orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
+                update = mx.reshape(update, original_shape)
 
-            scale_factor = (
-                max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
-            )
+            lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
 
-        return (
-            parameter
-            - self.learning_rate.astype(gradient.dtype)
-            * orthogonalized_grad
-            * scale_factor
-        )
+        return parameter - lr * update
 
 
 def clip_grad_norm(grads, max_norm):
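Read end to end, the post-commit update step amounts to the sketch below. It is written as a standalone function rather than the class method: muon_step is a hypothetical name, `v` is taken as the already-updated momentum buffer (its update sits above the hunk and is not shown in the diff), and zeropower_via_newtonschulz5 is the helper sketched earlier. The notable simplification in this commit is that the old per-branch scale_factor is folded into lr, so a tall m x n matrix is stepped with lr * sqrt(m / n); a 1024x256 weight, for example, gets a 2x larger step.

import mlx.core as mx

def muon_step(parameter, gradient, v, lr, momentum=0.95, nesterov=True, ns_steps=5):
    # `v` is assumed to be the momentum buffer already updated for this step.
    if nesterov:
        update = gradient * (1 - momentum) + v * momentum
    else:
        update = v

    lr = mx.array(lr, dtype=gradient.dtype)

    if update.ndim >= 2:
        original_shape = update.shape
        reshape_needed = update.ndim > 2

        # Conv-style weights (ndim > 2) are flattened to 2D for orthogonalization.
        if reshape_needed:
            update = mx.reshape(update, (update.shape[0], -1))

        update = zeropower_via_newtonschulz5(update, steps=ns_steps)

        if reshape_needed:
            update = mx.reshape(update, original_shape)

        # Aspect-ratio compensation, now folded into the learning rate.
        lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5

    return parameter - lr * update

Vectors and scalars (update.ndim < 2) skip the orthogonalization and scaling entirely, which is what the old `if effective_grad.ndim < 2` branch spelled out explicitly with orthogonalized_grad and scale_factor = 1.0.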
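For completeness, a hypothetical end-to-end use of the optimizer this commit touches. The optim.Muon constructor arguments are inferred from the attributes the diff references (momentum, nesterov, ns_steps) and may not match the released signature exactly.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(256, 128)
# Argument names inferred from self.momentum / self.nesterov / self.ns_steps
# in the diff; treat them as an assumption, not the documented signature.
opt = optim.Muon(learning_rate=0.02, momentum=0.95, nesterov=True, ns_steps=5)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((32, 256))
y = mx.random.normal((32, 128))

loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, x, y)
opt.update(model, grads)
mx.eval(model.parameters(), opt.state)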