	Adding support for the Muon Optimizer (#1914)
* initial commit with working optimizer
* update ACKNOWLEDGMENTS.md
* nits and adding it to test
* nits
* G.astype(mx.bfloat16) to G.astype(G.dtype)
* G.ndim >= 2 to assert G.ndim == 2
* remove comments
* replace with mx.addmm
* remove comments
* format
* nits
* match muon
* fix addmm

---------

Co-authored-by: Awni Hannun <awni@apple.com>
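The commit message references the gradient matrix G, an assert that G.ndim == 2, and mx.addmm, all consistent with the Newton-Schulz orthogonalization at the core of Muon. For context, a hedged sketch of that iteration, with coefficients taken from the reference Muon implementation; this is an illustration, not the exact code merged in this PR:

import mlx.core as mx

def newton_schulz(G, steps=5, eps=1e-7):
    # Muon orthogonalizes 2-D gradient matrices only; the PR asserts this.
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315
    # Scale so the spectral norm is at most 1 before iterating.
    X = G / (mx.linalg.norm(G) + eps)
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = mx.addmm(A, A, A, alpha=c, beta=b)    # c * (A @ A) + b * A
        X = mx.addmm(X, B, X, alpha=1.0, beta=a)  # (B @ X) + a * X
    return X.T if transposed else X

Per the commit message, the iteration runs in the gradient's own dtype (G.astype(G.dtype)) rather than casting to bfloat16 as the reference implementation does.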
		@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
            self.assertEqual(xp["x"].shape, x.shape)

        self.assertEqual(optimizer.state["step"], 2)

    def test_muon(self):
        params = {
            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
            "second": mx.zeros((3, 3)),
            "conv": mx.zeros((16, 8, 3, 3)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

        # Test update
        updated_params = optim.apply_gradients(grads, params)

        # Check that shapes are preserved
        self.assertTrue(
            tree_equal(
                lambda p, u: p.shape == u.shape,
                params,
                updated_params,
            )
        )

        # Check that parameters actually changed
        self.assertFalse(
            tree_equal(
                lambda p, u: mx.array_equal(p, u),
                params,
                updated_params,
            )
        )

        # Test with different configurations
        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
        optim_no_nesterov.apply_gradients(grads, params)

        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
        optim_no_momentum.apply_gradients(grads, params)

    def test_compiled_optimizer(self):
        model = nn.Linear(10, 10)
        x = mx.random.uniform(shape=(2, 10))
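For completeness, a minimal end-to-end sketch of using the new optimizer to train a module, assuming the standard mlx.optimizers update API and the Muon hyperparameters exercised by the test above; the model and loss here are placeholders:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Linear(10, 10)
optimizer = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.uniform(shape=(2, 10))
y = mx.random.uniform(shape=(2, 10))

# One training step: compute loss and gradients, then apply Muon updates.
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)

Alternatively, apply_gradients(grads, params) returns the updated parameter tree directly, which is what the test above does.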