From 6c048ab4da277b084f8497003ecdee857f7eafa4 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Fri, 28 Feb 2025 23:16:51 +0100
Subject: [PATCH] initial commit with working optimizer

---
 python/mlx/optimizers/optimizers.py | 120 ++++++++++++++++++++++++++++
 1 file changed, 120 insertions(+)

diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 3d40dd0d1..b3b701a72 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -773,6 +773,126 @@ class Adafactor(Optimizer):
         return parameter - update
 
 
+class Muon(Optimizer):
+    r"""The Muon optimizer - MomentUm Orthogonalized by Newton-Schulz.
+
+    Muon internally runs standard SGD with momentum and then performs an
+    orthogonalization post-processing step, in which each 2D parameter's
+    update is replaced with the nearest orthogonal matrix. Each update is
+    orthogonalized with a Newton-Schulz iteration, which has the advantage
+    that it can be run stably in bfloat16 on the GPU.
+
+    For more details, see: https://kellerjordan.github.io/posts/muon/
+
+    Note:
+        - This optimizer may not be optimal for the embedding layer, the
+          final fully connected layer, or any 0D/1D parameters; those should
+          be optimized by a standard method (e.g., AdamW).
+        - 4D convolutional filters are handled by flattening their last
+          dimensions.
+
+    Args:
+        learning_rate (float or callable): The learning rate used by the
+          internal SGD.
+        momentum (float, optional): The momentum strength. Default: ``0.95``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+          Default: ``0.01``
+        nesterov (bool, optional): Enables Nesterov momentum. Recommended
+          for better performance. Default: ``True``
+        ns_steps (int, optional): Number of Newton-Schulz iteration steps
+          used for orthogonalization. Default: ``5``
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        momentum: float = 0.95,
+        weight_decay: float = 0.01,
+        nesterov: bool = True,
+        ns_steps: int = 5,
+    ):
+        super().__init__()
+
+        self._maybe_schedule("learning_rate", learning_rate)
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.nesterov = nesterov
+        self.ns_steps = ns_steps
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["v"] = mx.zeros_like(parameter)
+
+    def _zeropower_via_newtonschulz5(self, G, steps: int):
+        """Newton-Schulz iteration to compute the zeroth power /
+        orthogonalization of G.
+
+        We opt to use a quintic iteration whose coefficients are selected to
+        maximize the slope at zero. For the purpose of minimizing steps, it
+        turns out to be empirically effective to keep increasing the slope
+        at zero even beyond the point where the iteration no longer
+        converges all the way to one everywhere on the interval. This
+        iteration therefore does not produce UV^T but rather something like
+        US'V^T, where S' is diagonal with S'_{ii} ~ Uniform(0.5, 1.5), which
+        turns out not to hurt model performance at all relative to UV^T,
+        where USV^T = G is the SVD.
+ """ + assert G.ndim >= 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.astype(mx.bfloat16) + transpose_needed = G.shape[-2] > G.shape[-1] + + if transpose_needed: + X = X.T + + # Ensure spectral norm is at most 1 + norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7) + X = X / norm + + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = b * A + c * (A @ A) + X = a * X + B @ X + + if transpose_needed: + X = X.T + return X + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): + """Performs the Muon parameter update""" + + # Apply weight decay + if self.weight_decay != 0: + gradient = gradient + self.weight_decay * parameter + + # Update momentum buffer + v = self.momentum * state.get("v") + v = v + (1 - self.momentum) * gradient + state["v"] = v + + # Get effective gradient + if self.nesterov: + effective_grad = gradient * self.momentum + v * (1 - self.momentum) + else: + effective_grad = v + + # For tensors with fewer than 2 dimensions, skip Newton-Schulz + if effective_grad.ndim < 2: + orthogonalized_grad = effective_grad + scale_factor = 1.0 + else: + # Save original shape for 4D conv filters + original_shape = effective_grad.shape + reshape_needed = effective_grad.ndim > 2 + + if reshape_needed: + effective_grad = mx.reshape(effective_grad, (effective_grad.shape[0], -1)) + + # Apply Newton-Schulz orthogonalization + orthogonalized_grad = self._zeropower_via_newtonschulz5(effective_grad, steps=self.ns_steps) + + # Reshape back if needed + if reshape_needed: + orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape) + + # Calculate scaling factor + scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5 + + return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor + + def clip_grad_norm(grads, max_norm): """Clips the global norm of the gradients.