mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 13:07:29 +08:00
vmap matmul and admm (#836)
This commit is contained in:
@@ -275,6 +275,45 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)
|
||||
|
||||
def test_vmap_matmul(self):
|
||||
a = mx.random.uniform(shape=(2, 3, 4))
|
||||
b = mx.random.uniform(shape=(4, 3))
|
||||
|
||||
# matmul
|
||||
out = mx.vmap(mx.matmul, in_axes=(0, None))(a, b)
|
||||
self.assertTrue(mx.allclose(out, a @ b))
|
||||
|
||||
# addmm
|
||||
c = mx.random.uniform(shape=(3,))
|
||||
out = mx.vmap(mx.addmm, in_axes=(None, 0, None))(c, a, b)
|
||||
self.assertTrue(mx.allclose(out, mx.addmm(c, a, b)))
|
||||
|
||||
b = mx.random.uniform(shape=(4, 2))
|
||||
|
||||
# matmul
|
||||
out = mx.vmap(mx.matmul, in_axes=(1, None), out_axes=(1,))(a, b)
|
||||
expected = mx.moveaxis(mx.moveaxis(a, 1, 0) @ b, 0, 1)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# addmm
|
||||
c = mx.random.uniform(shape=(2,))
|
||||
out = mx.vmap(mx.addmm, in_axes=(None, 1, None))(c, a, b)
|
||||
self.assertTrue(mx.allclose(out, mx.addmm(c, mx.moveaxis(a, 1, 0), b)))
|
||||
|
||||
a = mx.random.uniform(shape=(2, 3, 4))
|
||||
b = mx.random.uniform(shape=(4, 2, 3))
|
||||
|
||||
# matmul
|
||||
out = mx.vmap(mx.matmul, in_axes=(0, 1))(a, b)
|
||||
expected = a @ mx.moveaxis(b, 1, 0)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# addmm
|
||||
c = mx.random.uniform(shape=(3, 3, 2))
|
||||
out = mx.vmap(mx.addmm, in_axes=(2, 0, 1))(c, a, b)
|
||||
expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user