mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
Adding support for the Muon Optimizer (#1914)
* initial commit with working optimizer * update ACKNOWLEDGMENTS.md * nits and adding it to test * nits * G.astype(mx.bfloat16) to G.astype(G.dtype) * G.ndim >= 2 to assert G.ndim == 2 * remove comments * replace with mx.addmm * remove comments * format * nits * match muon * fix addmm --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||
|
||||
# Transposed c
|
||||
a = mx.ones((10, 5)).T
|
||||
b = mx.ones((5, 5))
|
||||
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
|
||||
expected = 1.5 * a + 0.5 * (b @ a)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# Broadcast c
|
||||
a = mx.ones((5, 5))
|
||||
b = mx.ones((5, 5))
|
||||
c = mx.ones((1, 5))
|
||||
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
|
||||
expected = 1.5 * c + 0.5 * (a @ b)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_addmm_grad(self):
|
||||
def make_ref_addmm(alpha, beta):
|
||||
return lambda c, a, b: alpha * (a @ b) + beta * c
|
||||
|
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(xp["x"].shape, x.shape)
|
||||
self.assertEqual(optimizer.state["step"], 2)
|
||||
|
||||
def test_muon(self):
    """Smoke-test the Muon optimizer.

    Covers: explicit state initialization (momentum buffer ``v`` starts at
    zero), a single gradient step (shapes preserved, values changed), and
    alternative configurations (no Nesterov, zero momentum) running cleanly.
    """
    params = {
        "first": [mx.zeros((10, 5)), mx.zeros((1,))],
        "second": mx.zeros((3, 3)),
        "conv": mx.zeros((16, 8, 3, 3)),
    }
    grads = tree_map(lambda x: mx.ones_like(x), params)

    # Explicit init: every momentum buffer "v" must start as zeros
    # matching its parameter.
    optimizer = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
    optimizer.init(params)
    buffers_zeroed = tree_equal(
        lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
        params,
        optimizer.state,
    )
    self.assertTrue(buffers_zeroed)

    # One update step.
    new_params = optimizer.apply_gradients(grads, params)

    # Shapes must be preserved across the update...
    shapes_preserved = tree_equal(
        lambda p, u: p.shape == u.shape,
        params,
        new_params,
    )
    self.assertTrue(shapes_preserved)

    # ...while the parameter values actually move.
    values_unchanged = tree_equal(
        lambda p, u: mx.array_equal(p, u),
        params,
        new_params,
    )
    self.assertFalse(values_unchanged)

    # Alternative configurations should also apply without error.
    no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
    no_nesterov.apply_gradients(grads, params)

    no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
    no_momentum.apply_gradients(grads, params)
|
||||
|
||||
def test_compiled_optimizer(self):
|
||||
model = nn.Linear(10, 10)
|
||||
x = mx.random.uniform(shape=(2, 10))
|
||||
|
Reference in New Issue
Block a user