mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
Adding support for the Muon Optimizer (#1914)
* initial commit with workong optmimizer * update ACKNOWLEDGMENTS.md * nits and adding it to test * nits * G.astype(mx.bfloat16) to G.astype(G.dtype) * G.ndim >= 2 to assert G.ndim == 2 * remove coments * replace with mx.addmm * remove comments * format * nits * match muon * fix addmm --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||
|
||||
# Transposed c
|
||||
a = mx.ones((10, 5)).T
|
||||
b = mx.ones((5, 5))
|
||||
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
|
||||
expected = 1.5 * a + 0.5 * (b @ a)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# Broadcast c
|
||||
a = mx.ones((5, 5))
|
||||
b = mx.ones((5, 5))
|
||||
c = mx.ones((1, 5))
|
||||
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
|
||||
expected = 1.5 * c + 0.5 * (a @ b)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_addmm_grad(self):
|
||||
def make_ref_addmm(alpha, beta):
|
||||
return lambda c, a, b: alpha * (a @ b) + beta * c
|
||||
|
||||
Reference in New Issue
Block a user