From df6d9e972f2ea998adaa58e4fecca81f207edddf Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 16 Jul 2025 19:13:40 +0200
Subject: [PATCH] nits and adding it to test

---
 python/mlx/optimizers/optimizers.py |  7 +++--
 python/tests/test_optimizers.py     | 47 +++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 9f78aa912..4f73d72d5 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -933,13 +933,13 @@ class Muon(Optimizer):
             gradient = gradient + self.weight_decay * parameter
 
         # Update momentum buffer
-        v = self.momentum * state.get("v")
+        v = self.momentum * state["v"]
         v = v + (1 - self.momentum) * gradient
         state["v"] = v
 
         # Get effective gradient
         if self.nesterov:
-            effective_grad = gradient * self.momentum + v * (1 - self.momentum)
+            effective_grad = gradient * (1 - self.momentum) + v * self.momentum
         else:
             effective_grad = v
 
@@ -963,7 +963,8 @@ class Muon(Optimizer):
         orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
 
         # Calculate scaling factor
-        scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        # scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
 
         return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
 
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index 8f9e33679..962db4161 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertEqual(xp["x"].shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)
 
+    def test_muon(self):
+        params = {
+            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
+            "second": mx.zeros((3, 3)),
+            "conv": mx.zeros((16, 8, 3, 3)),
+        }
+        grads = tree_map(lambda x: mx.ones_like(x), params)
+
+        # Explicit init
+        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
+        optim.init(params)
+        self.assertTrue(
+            tree_equal(
+                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
+                params,
+                optim.state,
+            )
+        )
+
+        # Test update
+        updated_params = optim.apply_gradients(grads, params)
+
+        # Check that shapes are preserved
+        self.assertTrue(
+            tree_equal(
+                lambda p, u: p.shape == u.shape,
+                params,
+                updated_params,
+            )
+        )
+
+        # Check that parameters actually changed
+        self.assertFalse(
+            tree_equal(
+                lambda p, u: mx.array_equal(p, u),
+                params,
+                updated_params,
+            )
+        )
+
+        # Test with different configurations
+        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
+        optim_no_nesterov.apply_gradients(grads, params)
+
+        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
+        optim_no_momentum.apply_gradients(grads, params)
+
     def test_compiled_optimizer(self):
         model = nn.Linear(10, 10)
         x = mx.random.uniform(shape=(2, 10))
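
Reviewer note (not part of the patch): a minimal sketch of the two behavioral changes above, the Nesterov blend and the gradient-shape-based scale factor, using only mlx.core operations. The Newton-Schulz orthogonalization step is omitted and the toy shapes are assumptions chosen to match the "first" entry in the new test.

    # Sketch of the corrected Muon update math (orthogonalization omitted).
    import mlx.core as mx

    momentum = 0.95
    nesterov = True

    parameter = mx.zeros((10, 5))
    gradient = mx.ones_like(parameter)
    v = mx.zeros_like(parameter)  # momentum buffer, as initialized in state["v"]

    # Update momentum buffer, as in the first hunk
    v = momentum * v + (1 - momentum) * gradient

    # Nesterov blend: gradient weighted by (1 - momentum), buffer by momentum
    effective_grad = gradient * (1 - momentum) + v * momentum if nesterov else v

    # Scale factor now derived from the effective gradient's shape, not the parameter's
    scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
    print(scale_factor)  # sqrt(10 / 5) ~= 1.414

For a (10, 5) parameter this prints roughly 1.414 (sqrt(2)), which is the case the new test_muon exercises via params["first"][0].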