Compare commits

...

9 Commits

Author             SHA1        Message                                                     Date
Gökdeniz Gülmez    d3d575cce7  Merge branch 'ml-explore:main' into adding-Muon-optimizer  2025-04-21 20:27:33 +02:00
Gökdeniz Gülmez    8f2744dcf3  Merge branch 'ml-explore:main' into adding-Muon-optimizer  2025-03-21 08:50:43 +01:00
Gökdeniz Gülmez    b12be4b7e0  Merge branch 'ml-explore:main' into adding-Muon-optimizer  2025-03-12 16:52:21 +01:00
Gökdeniz Gülmez    ebfcb4a14f  Merge branch 'ml-explore:main' into adding-Muon-optimizer  2025-03-10 17:10:50 +01:00
Gökdeniz Gülmez    79175a1f35  Merge branch 'ml-explore:main' into adding-Muon-optimizer  2025-03-07 11:41:19 +01:00
Gökdeniz Gülmez    59d4e4f61d  Merge branch 'ml-explore:main' into adding-Muon-optimizer  2025-03-05 23:09:44 +01:00
Gökdeniz Gülmez    44f776921c  Merge branch 'ml-explore:main' into adding-Muon-optimizer  2025-03-05 10:05:10 +01:00
Goekdeniz-Guelmez  871ee2b9b0  update ACKNOWLEDGMENTS.md                                   2025-02-28 23:24:39 +01:00
Goekdeniz-Guelmez  6c048ab4da  initial commit with workong optmimizer                      2025-02-28 23:16:51 +01:00
2 changed files with 121 additions and 0 deletions

ACKNOWLEDGMENTS.md

@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
  <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

python/mlx/optimizers/optimizers.py

@@ -846,6 +846,126 @@ class Adafactor(Optimizer):
        return parameter - update
class Muon(Optimizer):
    r"""The Muon optimizer - MomentUm Orthogonalized by Newton-schulz.

    Muon internally runs standard SGD-momentum, and then performs an
    orthogonalization post-processing step, in which each 2D parameter's
    update is replaced with the nearest orthogonal matrix. To efficiently
    orthogonalize each update, a Newton-Schulz iteration is used, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    For more details, see: https://kellerjordan.github.io/posts/muon/

    Note:
        - This optimizer may not be optimal for the embedding layer, the
          final fully connected layer, or any 0D/1D parameters; those should
          be optimized by a standard method (e.g., AdamW).
        - For 4D convolutional filters, it works by flattening their last
          dimensions.

    Args:
        learning_rate (float or callable): The learning rate used by the
          internal SGD.
        momentum (float, optional): The momentum strength. Default: ``0.95``
        weight_decay (float, optional): The weight decay (L2 penalty).
          Default: ``0.01``
        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
          better performance. Default: ``True``
        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
          orthogonalization. Default: ``5``
    """
    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        momentum: float = 0.95,
        weight_decay: float = 0.01,
        nesterov: bool = True,
        ns_steps: int = 5,
    ):
        super().__init__()
        self._maybe_schedule("learning_rate", learning_rate)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.nesterov = nesterov
        self.ns_steps = ns_steps

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        # "v" is the SGD momentum buffer, one per parameter
        state["v"] = mx.zeros_like(parameter)
    def _zeropower_via_newtonschulz5(self, G, steps: int):
        """
        Newton-Schulz iteration to compute the zeroth power / orthogonalization
        of G. We opt to use a quintic iteration whose coefficients are selected
        to maximize the slope at zero. For the purpose of minimizing steps, it
        turns out to be empirically effective to keep increasing the slope at
        zero even beyond the point where the iteration no longer converges all
        the way to one everywhere on the interval. This iteration therefore
        does not produce UV^T but rather something like US'V^T where S' is
        diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt
        model performance at all relative to UV^T, where USV^T = G is the SVD.
        """
        assert G.ndim >= 2
        # Quintic iteration coefficients from the Muon reference implementation
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G.astype(mx.bfloat16)

        # Work in the wide orientation (rows <= cols)
        transpose_needed = G.shape[-2] > G.shape[-1]
        if transpose_needed:
            X = X.T

        # Normalize by the Frobenius norm, which upper-bounds the spectral
        # norm, so all singular values of X start at most 1
        norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
        X = X / norm

        # Perform the NS iterations:
        # each step applies X <- a*X + b*(X X^T) X + c*(X X^T)^2 X
        for _ in range(steps):
            A = X @ X.T
            B = b * A + c * (A @ A)
            X = a * X + B @ X

        if transpose_needed:
            X = X.T
        return X
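
    # Sanity-check sketch: running the routine above on a random matrix
    # should give an approximately orthogonal result (singular values near 1,
    # only approximately, per the docstring caveat):
    #
    #     G = mx.random.normal((64, 128))
    #     O = Muon(0.02)._zeropower_via_newtonschulz5(G, steps=5)
    #     mx.eval(O @ O.T)  # roughly the 64x64 identity
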
    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Muon parameter update"""
        # Apply weight decay
        if self.weight_decay != 0:
            gradient = gradient + self.weight_decay * parameter

        # Update momentum buffer: v <- m * v + (1 - m) * g
        v = self.momentum * state["v"] + (1 - self.momentum) * gradient
        state["v"] = v

        # Get effective gradient; the Nesterov variant blends the raw
        # gradient back in, as in the reference Muon: g_eff = (1 - m) * g + m * v
        if self.nesterov:
            effective_grad = gradient * (1 - self.momentum) + v * self.momentum
        else:
            effective_grad = v

        # For tensors with fewer than 2 dimensions, skip Newton-Schulz
        if effective_grad.ndim < 2:
            orthogonalized_grad = effective_grad
            scale_factor = 1.0
        else:
            # Save original shape for 4D conv filters
            original_shape = effective_grad.shape
            reshape_needed = effective_grad.ndim > 2
            if reshape_needed:
                effective_grad = mx.reshape(effective_grad, (effective_grad.shape[0], -1))

            # Apply Newton-Schulz orthogonalization
            orthogonalized_grad = self._zeropower_via_newtonschulz5(effective_grad, steps=self.ns_steps)

            # Reshape back if needed
            if reshape_needed:
                orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)

            # Scale tall matrices up so the per-element update magnitude is
            # comparable across parameter shapes
            scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5

        return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
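
# Summary of the update rule implemented above:
#   v_t   = m * v_{t-1} + (1 - m) * g_t
#   g_eff = (1 - m) * g_t + m * v_t        (Nesterov; otherwise g_eff = v_t)
#   theta <- theta - lr * NS5(g_eff) * sqrt(max(1, rows / cols))
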
def clip_grad_norm(grads, max_norm):
"""Clips the global norm of the gradients.