Add SVD primitive (#809)

Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
This commit is contained in:
nicolov
2024-03-12 20:30:11 +01:00
committed by GitHub
parent ffb19df3c0
commit d0c544a868
13 changed files with 318 additions and 1 deletions

View File

@@ -120,6 +120,22 @@ class TestLinalg(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out, mx.eye(2), rtol=1e-5, atol=1e-7))
self.assertTrue(mx.allclose(mx.tril(r, -1), mx.zeros_like(r)))
def test_svd_decomposition(self):
A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
)
# Multiple matrices
B = A + 10.0
AB = mx.stack([A, B])
Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
)
if __name__ == "__main__":
unittest.main()