Add vmap for SVD and inverse (#849)

This commit is contained in:
nicolov
2024-03-21 21:18:27 +01:00
committed by GitHub
parent 53e6a9367c
commit 105d236889
7 changed files with 116 additions and 5 deletions

View File

@@ -314,6 +314,64 @@ class TestVmap(mlx_tests.MLXTestCase):
expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
self.assertTrue(mx.allclose(out, expected))
def test_vmap_svd(self):
    """Vmapped CPU SVD yields factors of the expected shapes that rebuild each input slice."""

    def svd_on_cpu(mat):
        return mx.linalg.svd(mat, stream=mx.cpu)

    def rebuilds(U, S, Vt, target):
        # U is square here, so keep only as many columns as there are
        # singular values before multiplying the factors back together.
        product = U[:, : len(S)] @ mx.diag(S) @ Vt
        return mx.allclose(product, target, rtol=1e-5, atol=1e-7)

    a = mx.random.uniform(shape=(3, 4, 2))
    dim0, dim1, dim2 = a.shape

    # Map over the leading axis (this is already supported natively by the primitive).
    Us, Ss, Vts = mx.vmap(svd_on_cpu, in_axes=(0,))(a)
    self.assertEqual(Us.shape, (dim0, dim1, dim1))
    self.assertEqual(Ss.shape, (dim0, dim2))
    self.assertEqual(Vts.shape, (dim0, dim2, dim2))
    for idx in range(dim0):
        self.assertTrue(rebuilds(Us[idx], Ss[idx], Vts[idx], a[idx]))

    # Map over the middle axis: every slice a[:, idx, :] is one input matrix.
    Us, Ss, Vts = mx.vmap(svd_on_cpu, in_axes=(1,))(a)
    self.assertEqual(Us.shape, (dim1, dim0, dim0))
    self.assertEqual(Ss.shape, (dim1, dim2))
    self.assertEqual(Vts.shape, (dim1, dim2, dim2))
    for idx in range(dim1):
        self.assertTrue(rebuilds(Us[idx], Ss[idx], Vts[idx], a[:, idx, :]))
def test_vmap_inverse(self):
    """Vmapped CPU matrix inverse gives, per slice, a matrix whose product with the input is close to identity."""

    def inv_on_cpu(mat):
        return mx.linalg.inv(mat, stream=mx.cpu)

    def close_to_identity(product, size):
        return mx.allclose(product, mx.eye(size), rtol=0, atol=1e-5)

    # Map over the leading axis (this is already supported natively by the primitive).
    a = mx.random.uniform(shape=(3, 4, 4))
    invs = mx.vmap(inv_on_cpu, in_axes=(0,))(a)
    side = a.shape[1]
    for idx in range(a.shape[0]):
        self.assertTrue(close_to_identity(a[idx] @ invs[idx], side))

    a = mx.random.uniform(shape=(4, 3, 4))

    # Called directly, the trailing two dims are (3, 4), i.e. not square,
    # so the un-vmapped inverse is expected to raise.
    with self.assertRaises(ValueError):
        mx.eval(inv_on_cpu(a))

    # Map over the middle axis: every slice a[:, idx, :] is a 4x4 matrix.
    invs = mx.vmap(inv_on_cpu, in_axes=(1,))(a)
    side = a.shape[0]
    for idx in range(a.shape[1]):
        self.assertTrue(close_to_identity(a[:, idx, :] @ invs[idx], side))
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()