Adding support for the Muon Optimizer (#1914)

* initial commit with working optimizer

* update ACKNOWLEDGMENTS.md

* nits and adding it to the tests

* nits

* change G.astype(mx.bfloat16) to G.astype(G.dtype)

* change the G.ndim >= 2 check to assert G.ndim == 2

* remove comments

* replace with mx.addmm (see the sketch after this list)

* remove comments

* format

* nits

* match muon

* fix addmm
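
For context, the heart of Muon is an approximate orthogonalization of each 2-D gradient matrix via a quintic Newton-Schulz iteration, as in the reference Muon implementation. The sketch below is illustrative only and is not the code added by this commit; the function name, iteration count, and coefficients are assumptions. It simply folds in the points listed above: assert G.ndim == 2, keep the gradient's own dtype, and fuse the multiply-adds with mx.addmm.

import mlx.core as mx

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    # Hypothetical helper: Muon orthogonalizes 2-D gradient matrices;
    # higher-rank (e.g. conv) gradients are reshaped to 2-D beforehand.
    assert G.ndim == 2
    # Quintic iteration coefficients from the reference Muon implementation.
    a, b, c = (3.4445, -4.7750, 2.0315)
    # Keep the gradient's own dtype (G.astype(G.dtype) is a no-op), per the commit message.
    X = G.astype(G.dtype)
    transposed = G.shape[0] > G.shape[1]
    if transposed:
        X = X.T
    X = X / (mx.linalg.norm(X) + eps)
    for _ in range(steps):
        A = X @ X.T
        # B = b * A + c * (A @ A), fused with mx.addmm
        B = mx.addmm(A, A, A, alpha=c, beta=b)
        # X = a * X + B @ X, again via mx.addmm
        X = mx.addmm(X, B, X, alpha=1.0, beta=a)
    if transposed:
        X = X.T
    return X

In Muon proper, this orthogonalized direction is then combined with (optionally Nesterov) momentum and a learning rate, which is exactly what the new test below exercises.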

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Gökdeniz Gülmez
Date: 2025-07-18 21:25:28 +02:00
Committed by: GitHub
Parent: 45adec102c
Commit: deee214a95
6 changed files with 184 additions and 7 deletions


@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
        self.assertEqual(xp["x"].shape, x.shape)
        self.assertEqual(optimizer.state["step"], 2)

    def test_muon(self):
        params = {
            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
            "second": mx.zeros((3, 3)),
            "conv": mx.zeros((16, 8, 3, 3)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

        # Test update
        updated_params = optim.apply_gradients(grads, params)

        # Check that shapes are preserved
        self.assertTrue(
            tree_equal(
                lambda p, u: p.shape == u.shape,
                params,
                updated_params,
            )
        )

        # Check that parameters actually changed
        self.assertFalse(
            tree_equal(
                lambda p, u: mx.array_equal(p, u),
                params,
                updated_params,
            )
        )

        # Test with different configurations
        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
        optim_no_nesterov.apply_gradients(grads, params)

        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
        optim_no_momentum.apply_gradients(grads, params)

    def test_compiled_optimizer(self):
        model = nn.Linear(10, 10)
        x = mx.random.uniform(shape=(2, 10))
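
A minimal usage sketch, assuming the standard MLX training-loop API (nn.value_and_grad, optimizer.update) and only the Muon constructor arguments exercised by the test above; the model, loss, and hyperparameter values are illustrative.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Linear(10, 10)
optimizer = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.uniform(shape=(2, 10))
y = mx.random.uniform(shape=(2, 10))

# One training step: compute loss and grads, then let Muon update the model in place.
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)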