Compare commits

...

22 Commits

Author SHA1 Message Date
Awni Hannun
0a8bb904d7 nits 2025-07-17 11:58:41 -07:00
Gökdeniz Gülmez
c535d8c1b5 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-17 20:10:02 +02:00
Goekdeniz-Guelmez
4b3d7634cd format 2025-07-17 20:03:19 +02:00
Goekdeniz-Guelmez
516d172ba5 remove comments 2025-07-17 20:02:27 +02:00
Goekdeniz-Guelmez
698daee214 replace with mx.addmm 2025-07-17 19:57:18 +02:00
Goekdeniz-Guelmez
4c0f7c713b remove coments 2025-07-17 19:53:56 +02:00
Goekdeniz-Guelmez
3889c805da G.ndim >= 2 to assert G.ndim == 2 2025-07-17 19:52:00 +02:00
Goekdeniz-Guelmez
060404d862 G.astype(mx.bfloat16) to G.astype(G.dtype) 2025-07-17 19:49:26 +02:00
Awni Hannun
7f39e9c299 nits 2025-07-17 06:26:43 -07:00
Gökdeniz Gülmez
baad6e392b Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-17 13:07:54 +02:00
Gökdeniz Gülmez
784e0716fe Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-16 21:58:17 +02:00
Goekdeniz-Guelmez
df6d9e972f nits and adding it to test 2025-07-16 19:13:40 +02:00
Gökdeniz Gülmez
650c956fe6 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-16 16:29:10 +02:00
Gökdeniz Gülmez
d3d575cce7 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-04-21 20:27:33 +02:00
Gökdeniz Gülmez
8f2744dcf3 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-21 08:50:43 +01:00
Gökdeniz Gülmez
b12be4b7e0 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-12 16:52:21 +01:00
Gökdeniz Gülmez
ebfcb4a14f Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-10 17:10:50 +01:00
Gökdeniz Gülmez
79175a1f35 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-07 11:41:19 +01:00
Gökdeniz Gülmez
59d4e4f61d Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-05 23:09:44 +01:00
Gökdeniz Gülmez
44f776921c Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-05 10:05:10 +01:00
Goekdeniz-Guelmez
871ee2b9b0 update ACKNOWLEDGMENTS.md 2025-02-28 23:24:39 +01:00
Goekdeniz-Guelmez
6c048ab4da initial commit with workong optmimizer 2025-02-28 23:16:51 +01:00
4 changed files with 150 additions and 0 deletions

View File

@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -19,3 +19,4 @@ Common Optimizers
Adamax
Lion
MultiOptimizer
Muon

View File

@@ -848,6 +848,107 @@ class Adafactor(Optimizer):
return parameter - update
class Muon(Optimizer):
r"""The Muon optimizer.
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
original implementation: `Muon: An optimizer for hidden layers in neural
networks <https://kellerjordan.github.io/posts/muon/>`_
Note:
- Muon may be sub-optimal for the embedding layer, the final fully
connected layer, or any 0D/1D parameters. Those should be optimized
by a different method (e.g., :class:`AdamW`).
- For 4D convolutional filters, it works by flattening their last
dimensions.
Args:
learning_rate (float or callable): The learning rate.
momentum (float, optional): The momentum strength. Default: ``0.95``
weight_decay (float, optional): The weight decay (L2 penalty).
Default: ``0.01``
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
better performance. Default: ``True``
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
orthogonalization. Default: ``5``
"""
def __init__(
self,
learning_rate: Union[float, Callable[[mx.array], mx.array]],
momentum: float = 0.95,
weight_decay: float = 0.01,
nesterov: bool = True,
ns_steps: int = 5,
):
super().__init__()
self._maybe_schedule("learning_rate", learning_rate)
self.momentum = momentum
self.weight_decay = weight_decay
self.nesterov = nesterov
self.ns_steps = ns_steps
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def _zeropower_via_newtonschulz5(self, X, steps: int):
assert (
X.ndim == 2
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
a, b, c = (3.4445, -4.7750, 2.0315)
transpose_needed = X.shape[-2] > X.shape[-1]
if transpose_needed:
X = X.T
norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
X = X / norm
for _ in range(steps):
A = X @ X.T
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
if transpose_needed:
X = X.T
return X
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Muon parameter update"""
if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter
v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient
state["v"] = v
if self.nesterov:
update = gradient * (1 - self.momentum) + v * self.momentum
else:
update = v
lr = self.learning_rate.astype(gradient.dtype)
if update.ndim >= 2:
original_shape = update.shape
reshape_needed = update.ndim > 2
if reshape_needed:
update = mx.reshape(update, (update.shape[0], -1))
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
if reshape_needed:
update = mx.reshape(update, original_shape)
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
return parameter - lr * update
def clip_grad_norm(grads, max_norm):
"""Clips the global norm of the gradients.

View File

@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
self.assertEqual(xp["x"].shape, x.shape)
self.assertEqual(optimizer.state["step"], 2)
def test_muon(self):
params = {
"first": [mx.zeros((10, 5)), mx.zeros((1,))],
"second": mx.zeros((3, 3)),
"conv": mx.zeros((16, 8, 3, 3)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
# Test update
updated_params = optim.apply_gradients(grads, params)
# Check that shapes are preserved
self.assertTrue(
tree_equal(
lambda p, u: p.shape == u.shape,
params,
updated_params,
)
)
# Check that parameters actually changed
self.assertFalse(
tree_equal(
lambda p, u: mx.array_equal(p, u),
params,
updated_params,
)
)
# Test with different configurations
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
optim_no_nesterov.apply_gradients(grads, params)
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
optim_no_momentum.apply_gradients(grads, params)
def test_compiled_optimizer(self):
model = nn.Linear(10, 10)
x = mx.random.uniform(shape=(2, 10))