mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Add vmap for SVD and inverse (#849)
This commit is contained in:
@@ -314,6 +314,64 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_vmap_svd(self):
|
||||
a = mx.random.uniform(shape=(3, 4, 2))
|
||||
|
||||
cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
|
||||
|
||||
# Vmap over the first axis (this is already supported natively by the primitive).
|
||||
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
|
||||
self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
|
||||
self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
|
||||
self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))
|
||||
|
||||
for i in range(a.shape[0]):
|
||||
M = a[i]
|
||||
U, S, Vt = Us[i], Ss[i], Vts[i]
|
||||
self.assertTrue(
|
||||
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
|
||||
)
|
||||
|
||||
# Vmap over the second axis.
|
||||
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
|
||||
self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
|
||||
self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
|
||||
self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))
|
||||
|
||||
for i in range(a.shape[1]):
|
||||
M = a[:, i, :]
|
||||
U, S, Vt = Us[i], Ss[i], Vts[i]
|
||||
self.assertTrue(
|
||||
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
|
||||
)
|
||||
|
||||
def test_vmap_inverse(self):
|
||||
a = mx.random.uniform(shape=(3, 4, 4))
|
||||
|
||||
cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)
|
||||
|
||||
# Vmap over the first axis (this is already supported natively by the primitive).
|
||||
invs = mx.vmap(cpu_inv, in_axes=(0,))(a)
|
||||
|
||||
for i in range(a.shape[0]):
|
||||
self.assertTrue(
|
||||
mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=0, atol=1e-5)
|
||||
)
|
||||
|
||||
a = mx.random.uniform(shape=(4, 3, 4))
|
||||
|
||||
# Without vmapping, each input matrix is not square.
|
||||
with self.assertRaises(ValueError):
|
||||
mx.eval(cpu_inv(a))
|
||||
|
||||
# Vmap over the second axis.
|
||||
invs = mx.vmap(cpu_inv, in_axes=(1,))(a)
|
||||
|
||||
for i in range(a.shape[1]):
|
||||
self.assertTrue(
|
||||
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user