vmap matmul and admm (#836)

This commit is contained in:
Awni Hannun
2024-03-14 14:38:22 -07:00
committed by GitHub
parent 63ab0ab580
commit 19ec023256
3 changed files with 79 additions and 3 deletions

View File

@@ -275,6 +275,45 @@ class TestVmap(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)
def test_vmap_matmul(self):
a = mx.random.uniform(shape=(2, 3, 4))
b = mx.random.uniform(shape=(4, 3))
# matmul
out = mx.vmap(mx.matmul, in_axes=(0, None))(a, b)
self.assertTrue(mx.allclose(out, a @ b))
# addmm
c = mx.random.uniform(shape=(3,))
out = mx.vmap(mx.addmm, in_axes=(None, 0, None))(c, a, b)
self.assertTrue(mx.allclose(out, mx.addmm(c, a, b)))
b = mx.random.uniform(shape=(4, 2))
# matmul
out = mx.vmap(mx.matmul, in_axes=(1, None), out_axes=(1,))(a, b)
expected = mx.moveaxis(mx.moveaxis(a, 1, 0) @ b, 0, 1)
self.assertTrue(mx.allclose(out, expected))
# addmm
c = mx.random.uniform(shape=(2,))
out = mx.vmap(mx.addmm, in_axes=(None, 1, None))(c, a, b)
self.assertTrue(mx.allclose(out, mx.addmm(c, mx.moveaxis(a, 1, 0), b)))
a = mx.random.uniform(shape=(2, 3, 4))
b = mx.random.uniform(shape=(4, 2, 3))
# matmul
out = mx.vmap(mx.matmul, in_axes=(0, 1))(a, b)
expected = a @ mx.moveaxis(b, 1, 0)
self.assertTrue(mx.allclose(out, expected))
# addmm
c = mx.random.uniform(shape=(3, 3, 2))
out = mx.vmap(mx.addmm, in_axes=(2, 0, 1))(c, a, b)
expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
self.assertTrue(mx.allclose(out, expected))
if __name__ == "__main__":
unittest.main()