fix addmm

This commit is contained in:
Awni Hannun
2025-07-18 08:41:22 -07:00
parent 508bd25e29
commit 8435c047e1
2 changed files with 35 additions and 7 deletions

View File

@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
return lambda c, a, b: alpha * (a @ b) + beta * c