diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 5137168f6..6ab94c76d 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -865,19 +865,19 @@ def _zeropower_via_newtonschulz5(
     a, b, c = (3.4445, -4.7750, 2.0315)
     X = G.astype(mx.bfloat16)
     if G.shape[-2] > G.shape[-1]:
-        X = mx.transpose(X, (-2, -1))
+        X = X.T
 
     # Frobenius-norm normalisation (≈ spectral-norm ≤ 1)
     X = X / (mx.linalg.norm(X, ord="fro", axis=(-2, -1), keepdims=True) + eps)
 
     # Perform Newton-Schulz iteration
     for _ in range(steps):
-        A = X @ mx.transpose(X, (-2, -1))
+        A = X @ X.T
         B = b * A + c * A @ A
         X = a * X + B @ X
 
     if G.shape[-2] > G.shape[-1]:
-        X = mx.transpose(X, (-2, -1))
+        X = X.T
 
     return X
 
@@ -918,6 +918,7 @@ class Muon(Optimizer):
         self.weight_decay = weight_decay
         self.nesterov = nesterov
         self.ns_steps = ns_steps
+        self.backup_optimizer = AdamW(learning_rate=1e-3, weight_decay=weight_decay)
 
     def _flatten_if_conv(self, x: mx.array) -> mx.array:
         if x.ndim == 4:  # [out, in, kH, kW] → [out, -1]
@@ -925,11 +926,21 @@ class Muon(Optimizer):
         return x
 
     def init_single(self, parameter: mx.array, state: dict):
+        # If the parameter is not a 2D array, use the backup optimizer
+        if parameter.ndim != 2:
+            return self.backup_optimizer.init_single(parameter, state)
         state["B"] = mx.zeros_like(parameter)
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        # If the parameter is not a 2D array, use the backup optimizer
+        if "B" not in state:
+            return self.backup_optimizer.apply_single(
+                gradient=gradient, parameter=parameter, state=state
+            )
+        # Record the original dtype and shape
         original_dtype = gradient.dtype
         original_shape = gradient.shape
+        # Cast lr to the original dtype
         lr = self.learning_rate.astype(original_dtype)
         # Compute new buffer and store
         state["B"] = self.momentum * state["B"] + gradient
@@ -947,10 +958,11 @@ class Muon(Optimizer):
         # Reshape back to original shape
         gradient = mx.reshape(gradient, original_shape)
         # scale-invariant step size
-        scale = mx.sqrt(
-            mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1])
-        ).astype(original_dtype)
-        return parameter * (1 - lr * self.weight_decay) - lr * scale * gradient
+        scale = mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        return (
+            parameter * (1 - lr * self.weight_decay)
+            - lr * scale.astype(original_dtype) * gradient
+        )
 
 
 def clip_grad_norm(grads, max_norm):
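
For context (not part of the patch): a minimal standalone sketch that mirrors the Newton-Schulz iteration above and sanity-checks that orthogonalizing a tall random matrix yields a roughly orthonormal result. The helper name newton_schulz, the test shape, and the steps/eps defaults are illustrative assumptions, not taken from the MLX source.

    import mlx.core as mx

    def newton_schulz(G, steps=5, eps=1e-7):
        # Same quintic iteration as _zeropower_via_newtonschulz5 in the patch.
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G.astype(mx.bfloat16)
        if G.shape[-2] > G.shape[-1]:
            X = X.T
        # Frobenius normalisation so the spectral norm is at most ~1
        X = X / (mx.linalg.norm(X, ord="fro", axis=(-2, -1), keepdims=True) + eps)
        for _ in range(steps):
            A = X @ X.T
            B = b * A + c * A @ A
            X = a * X + B @ X
        if G.shape[-2] > G.shape[-1]:
            X = X.T
        return X

    # Orthogonalize a tall random "gradient" and compare O^T O against identity.
    G = mx.random.normal((256, 128))
    O = newton_schulz(G).astype(mx.float32)
    err = mx.linalg.norm(O.T @ O - mx.eye(128))
    # The Muon coefficients only approximately normalize singular values,
    # so err is small relative to ||I||_F but not exactly zero.
    print(err)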