Add float64 Eig and complex64 SVD/Eig support (Fixes #2708) (#2737)

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Author: Harsh Sutaria
Date: 2025-11-22 09:51:36 -05:00
Committed by: GitHub
Parent: d5f61a93fa
Commit: 618c87af8c
5 changed files with 471 additions and 161 deletions
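
For quick reference, a minimal usage sketch of the dtype support this change adds, reusing the toy matrices from the tests below. Everything runs on the CPU stream because float64 and complex64 are not supported on the GPU.

import mlx.core as mx

with mx.stream(mx.cpu):
    # complex64 SVD: U and Vt keep the complex64 dtype; singular values are float32
    A = mx.array([[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=mx.complex64)
    U, S, Vt = mx.linalg.svd(A, compute_uv=True)

    # float64 eig: eigenvalues and eigenvectors are returned as complex64
    B = mx.array([[1.0, 1.0], [3.0, 4.0]], dtype=mx.float64)
    vals, vecs = mx.linalg.eig(B)

    mx.eval(U, S, Vt, vals, vecs)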


@@ -168,6 +168,42 @@ class TestLinalg(mlx_tests.MLXTestCase):
                )
            )

        # Test float64 - use CPU stream since float64 is not supported on GPU
        with mx.stream(mx.cpu):
            A_f64 = mx.array(
                [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
            )
            U_f64, S_f64, Vt_f64 = mx.linalg.svd(A_f64, compute_uv=True)
            mx.eval(U_f64, S_f64, Vt_f64)
            self.assertTrue(
                mx.allclose(
                    U_f64[:, : len(S_f64)] @ mx.diag(S_f64) @ Vt_f64,
                    A_f64,
                    rtol=1e-5,
                    atol=1e-7,
                )
            )
            self.assertEqual(S_f64.dtype, mx.float64)

        # Test complex64 - use CPU stream since complex64 is not supported on GPU
        with mx.stream(mx.cpu):
            A_c64 = mx.array(
                [[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=mx.complex64
            )
            U_c64, S_c64, Vt_c64 = mx.linalg.svd(A_c64, compute_uv=True)
            mx.eval(U_c64, S_c64, Vt_c64)
            self.assertTrue(
                mx.allclose(
                    U_c64[:, : len(S_c64)] @ mx.diag(S_c64) @ Vt_c64,
                    A_c64,
                    rtol=1e-5,
                    atol=1e-7,
                )
            )
            self.assertEqual(S_c64.dtype, mx.float32)
            self.assertEqual(U_c64.dtype, mx.complex64)
            self.assertEqual(Vt_c64.dtype, mx.complex64)

    def test_inverse(self):
        A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
        A_inv = mx.linalg.inv(A, stream=mx.cpu)
@@ -342,6 +378,43 @@ class TestLinalg(mlx_tests.MLXTestCase):
        A_np = np.random.randn(3, n, n).astype(np.float32)
        check_eigs_and_vecs(A_np)

        # Test float64 - use CPU stream since float64 is not supported on GPU
        with mx.stream(mx.cpu):
            A_np_f64 = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float64)
            A_f64 = mx.array(A_np_f64, dtype=mx.float64)
            eig_vals_f64, eig_vecs_f64 = mx.linalg.eig(A_f64)
            mx.eval(eig_vals_f64, eig_vecs_f64)
            self.assertTrue(
                mx.allclose(
                    A_f64 @ eig_vecs_f64,
                    eig_vals_f64[..., None, :] * eig_vecs_f64,
                    rtol=1e-5,
                    atol=1e-5,
                )
            )
            # Eigenvalues should be complex64 (output dtype)
            self.assertEqual(eig_vals_f64.dtype, mx.complex64)
            self.assertEqual(eig_vecs_f64.dtype, mx.complex64)

        # Test complex64 input - use CPU stream since complex64 is not supported on GPU
        with mx.stream(mx.cpu):
            A_np_c64 = np.array(
                [[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=np.complex64
            )
            A_c64 = mx.array(A_np_c64, dtype=mx.complex64)
            eig_vals_c64, eig_vecs_c64 = mx.linalg.eig(A_c64)
            mx.eval(eig_vals_c64, eig_vecs_c64)
            self.assertTrue(
                mx.allclose(
                    A_c64 @ eig_vecs_c64,
                    eig_vals_c64[..., None, :] * eig_vecs_c64,
                    rtol=1e-5,
                    atol=1e-5,
                )
            )
            self.assertEqual(eig_vals_c64.dtype, mx.complex64)
            self.assertEqual(eig_vecs_c64.dtype, mx.complex64)

        # Test error cases
        with self.assertRaises(ValueError):
            mx.linalg.eig(mx.array([1.0, 2.0]))  # 1D array
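
As an illustrative cross-check of the eig behavior exercised above (a sketch, not part of the test suite), the same 2x2 float64 matrix can be compared against NumPy; the eigenvalues come back as complex64 even for real input, matching the dtype assertions in the diff.

import numpy as np
import mlx.core as mx

with mx.stream(mx.cpu):
    A = mx.array([[1.0, 1.0], [3.0, 4.0]], dtype=mx.float64)
    vals, vecs = mx.linalg.eig(A)
    mx.eval(vals, vecs)

# Sort both sets of eigenvalues before comparing, since neither library
# guarantees a particular ordering.
np_vals, _ = np.linalg.eig(np.array([[1.0, 1.0], [3.0, 4.0]]))
assert np.allclose(np.sort(np.array(vals).real), np.sort(np_vals.real), atol=1e-5)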