mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add vmap for SVD and inverse (#849)
This commit is contained in:
		| @@ -314,6 +314,64 @@ class TestVmap(mlx_tests.MLXTestCase): | ||||
|         expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0)) | ||||
|         self.assertTrue(mx.allclose(out, expected)) | ||||
|  | ||||
|     def test_vmap_svd(self): | ||||
|         a = mx.random.uniform(shape=(3, 4, 2)) | ||||
|  | ||||
|         cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu) | ||||
|  | ||||
|         # Vmap over the first axis (this is already supported natively by the primitive). | ||||
|         Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a) | ||||
|         self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1])) | ||||
|         self.assertEqual(Ss.shape, (a.shape[0], a.shape[2])) | ||||
|         self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2])) | ||||
|  | ||||
|         for i in range(a.shape[0]): | ||||
|             M = a[i] | ||||
|             U, S, Vt = Us[i], Ss[i], Vts[i] | ||||
|             self.assertTrue( | ||||
|                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7) | ||||
|             ) | ||||
|  | ||||
|         # Vmap over the second axis. | ||||
|         Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a) | ||||
|         self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0])) | ||||
|         self.assertEqual(Ss.shape, (a.shape[1], a.shape[2])) | ||||
|         self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2])) | ||||
|  | ||||
|         for i in range(a.shape[1]): | ||||
|             M = a[:, i, :] | ||||
|             U, S, Vt = Us[i], Ss[i], Vts[i] | ||||
|             self.assertTrue( | ||||
|                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7) | ||||
|             ) | ||||
|  | ||||
|     def test_vmap_inverse(self): | ||||
|         a = mx.random.uniform(shape=(3, 4, 4)) | ||||
|  | ||||
|         cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu) | ||||
|  | ||||
|         # Vmap over the first axis (this is already supported natively by the primitive). | ||||
|         invs = mx.vmap(cpu_inv, in_axes=(0,))(a) | ||||
|  | ||||
|         for i in range(a.shape[0]): | ||||
|             self.assertTrue( | ||||
|                 mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=0, atol=1e-5) | ||||
|             ) | ||||
|  | ||||
|         a = mx.random.uniform(shape=(4, 3, 4)) | ||||
|  | ||||
|         # Without vmapping, each input matrix is not square. | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.eval(cpu_inv(a)) | ||||
|  | ||||
|         # Vmap over the second axis. | ||||
|         invs = mx.vmap(cpu_inv, in_axes=(1,))(a) | ||||
|  | ||||
|         for i in range(a.shape[1]): | ||||
|             self.assertTrue( | ||||
|                 mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5) | ||||
|             ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 nicolov
					nicolov