diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 09857dd0a..5137168f6 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -846,6 +846,114 @@ class Adafactor(Optimizer):
         return parameter - update
 
 
+@mx.compile
+def _zeropower_via_newtonschulz5(
+    G: mx.array, steps: int, eps: float = 1e-7
+) -> mx.array:
+    """Approximate the orthogonal factor U Vᵀ of the SVD G = U S Vᵀ using
+    a quintic Newton-Schulz iteration (see https://kellerjordan.github.io/posts/muon/).
+
+    Args:
+        G (mx.array): 2-D (or batched ≥ 2-D) array of any floating dtype.
+        steps (int): Number of Newton-Schulz iterations.
+        eps (float): A small value to avoid division by zero.
+
+    Returns:
+        An array with the same shape as ``G``, computed in ``bfloat16``.
+    """
+    assert G.ndim >= 2, "G must be at least 2-D"
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.astype(mx.bfloat16)
+    if G.shape[-2] > G.shape[-1]:
+        X = mx.swapaxes(X, -2, -1)
+
+    # Normalize by the Frobenius norm so the spectral norm is at most 1
+    X = X / (mx.linalg.norm(X, ord="fro", axis=(-2, -1), keepdims=True) + eps)
+
+    # Perform the Newton-Schulz iteration
+    for _ in range(steps):
+        A = X @ mx.swapaxes(X, -2, -1)
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+
+    if G.shape[-2] > G.shape[-1]:
+        X = mx.swapaxes(X, -2, -1)
+
+    return X
+
+
+class Muon(Optimizer):
+    r"""Muon - MomentUm Orthogonalized by Newton-Schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD with momentum, and then performs an orthogonalization post-
+    processing step, in which each 2-D parameter's update is replaced with the nearest semi-orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be run stably in bfloat16 on the GPU.
+
+    Some warnings:
+    - This optimizer should not be used for the embedding layer, the final fully connected layer,
+      or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
+    - To use it with 4-D convolutional filters, it works well to just flatten their last 3 dimensions.
+
+    Args:
+        learning_rate (float or callable): The learning rate used by the internal SGD. Default: ``0.02``
+        weight_decay (float): The weight decay (L2 penalty). Default: ``0.01``
+        momentum (float): The momentum used by the internal SGD. Default: ``0.95``
+        nesterov (bool): Whether to use Nesterov-style momentum in the internal SGD (recommended). Default: ``True``
+        ns_steps (int): The number of Newton-Schulz iteration steps to use. Default: ``5``
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]] = 0.02,
+        weight_decay: float = 0.01,
+        momentum: float = 0.95,
+        nesterov: bool = True,
+        ns_steps: int = 5,
+    ):
+        super().__init__()
+        self._maybe_schedule("learning_rate", learning_rate)
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.nesterov = nesterov
+        self.ns_steps = ns_steps
+
+    def _flatten_if_conv(self, x: mx.array) -> mx.array:
+        if x.ndim == 4:  # [out, in, kH, kW] → [out, -1]
+            return mx.reshape(x, (x.shape[0], -1))
+        return x
+
+    def init_single(self, parameter: mx.array, state: dict):
+        state["B"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        original_dtype = gradient.dtype
+        original_shape = gradient.shape
+        lr = self.learning_rate.astype(original_dtype)
+        # Update the momentum buffer and store it back in the state
+        state["B"] = self.momentum * state["B"] + gradient
+        # Compute the (optionally Nesterov) momentum update
+        gradient = (
+            gradient + self.momentum * state["B"] if self.nesterov else state["B"]
+        )
+        # Perform standard SGD with momentum for layers with {0,1}-D parameters
+        if gradient.ndim <= 1:
+            return parameter * (1 - lr * self.weight_decay) - lr * gradient
+        # Flatten conv kernels
+        gradient = self._flatten_if_conv(gradient)
+        # Newton-Schulz orthogonalisation
+        gradient = _zeropower_via_newtonschulz5(gradient, steps=self.ns_steps)
+        # Reshape back to the original shape
+        gradient = mx.reshape(gradient, original_shape)
+        # Scale-invariant step size
+        scale = mx.sqrt(
+            mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1])
+        ).astype(original_dtype)
+        return parameter * (1 - lr * self.weight_decay) - lr * scale * gradient
+
+
 def clip_grad_norm(grads, max_norm):
     """Clips the global norm of the gradients.