diff --git a/mlx/backend/cpu/eig.cpp b/mlx/backend/cpu/eig.cpp index 0d1f95a57..fb63eefdb 100644 --- a/mlx/backend/cpu/eig.cpp +++ b/mlx/backend/cpu/eig.cpp @@ -12,6 +12,167 @@ namespace mlx::core { namespace { +template +complex64_t to_complex(T r, T i) { + return {static_cast(r), static_cast(i)}; +} + +template +struct EigWork {}; + +template +struct EigWork< + T, + typename std::enable_if::value>::type> { + using O = complex64_t; + + char jobl; + char jobr; + int N; + int lwork; + int info; + std::vector buffers; + + EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors) + : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) { + T work; + int n_vecs_l = compute_eigenvectors ? N_ : 1; + int n_vecs_r = 1; + geev( + &jobl, + &jobr, + &N, + nullptr, + &N, + nullptr, + nullptr, + nullptr, + &n_vecs_l, + nullptr, + &n_vecs_r, + &work, + &lwork, + &info); + lwork = static_cast(work); + + buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2)); + if (compute_eigenvectors) { + buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2)); + } + buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); + } + + void run(T* a, O* values, O* vectors) { + auto eig_tmp = static_cast(buffers[0].buffer.raw_ptr()); + T* vec_tmp = nullptr; + if (vectors) { + vec_tmp = static_cast(buffers[1].buffer.raw_ptr()); + } + auto work = static_cast(buffers.back().buffer.raw_ptr()); + + int n_vecs_l = vectors ? N : 1; + int n_vecs_r = 1; + geev( + &jobl, + &jobr, + &N, + a, + &N, + eig_tmp, + eig_tmp + N, + vectors ? vec_tmp : nullptr, + &n_vecs_l, + nullptr, + &n_vecs_r, + work, + &lwork, + &info); + + for (int i = 0; i < N; ++i) { + values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]); + } + + if (vectors) { + for (int i = 0; i < N; ++i) { + if (values[i].imag() != 0) { + for (int j = 0; j < N; ++j) { + vectors[i * N + j] = + to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]); + vectors[(i + 1) * N + j] = + to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]); + } + i += 1; + } else { + for (int j = 0; j < N; ++j) { + vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0)); + } + } + } + } + } +}; + +template <> +struct EigWork> { + using T = std::complex; + using R = float; + using O = T; + + char jobl; + char jobr; + int N; + int lwork; + int lrwork; + int info; + std::vector buffers; + + EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors) + : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) { + T work; + R rwork; + int n_vecs_l = compute_eigenvectors ? N_ : 1; + int n_vecs_r = 1; + geev( + &jobl, + &jobr, + &N, + nullptr, + &N, + nullptr, + nullptr, + &n_vecs_l, + nullptr, + &n_vecs_r, + &work, + &lwork, + &rwork, + &info); + lwork = static_cast(work.real()); + buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); + buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork)); + } + + void run(T* a, T* values, T* vectors) { + int n_vecs_l = vectors ? N : 1; + int n_vecs_r = 1; + geev( + &jobl, + &jobr, + &N, + a, + &N, + values, + vectors, + &n_vecs_l, + nullptr, + &n_vecs_r, + static_cast(buffers[0].buffer.raw_ptr()), + &lwork, + static_cast(buffers[1].buffer.raw_ptr()), + &info); + } +}; + template void eig_impl( array& a, @@ -19,101 +180,39 @@ void eig_impl( array& values, bool compute_eigenvectors, Stream stream) { - using OT = std::complex; auto a_ptr = a.data(); - auto eig_ptr = values.data(); + auto val_ptr = values.data(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); encoder.set_output_array(values); - OT* vec_ptr = nullptr; + complex64_t* vec_ptr = nullptr; if (compute_eigenvectors) { encoder.set_output_array(vectors); - vec_ptr = vectors.data(); + vec_ptr = vectors.data(); } encoder.dispatch([a_ptr, + val_ptr, vec_ptr, - eig_ptr, compute_eigenvectors, N = vectors.shape(-1), size = vectors.size()]() mutable { - // Work query char jobr = 'N'; char jobl = compute_eigenvectors ? 'V' : 'N'; - int n_vecs_r = 1; - int n_vecs_l = compute_eigenvectors ? N : 1; - int lwork = -1; - int info; - { - T work; - geev( - &jobl, - &jobr, - &N, - nullptr, - &N, - nullptr, - nullptr, - nullptr, - &n_vecs_l, - nullptr, - &n_vecs_r, - &work, - &lwork, - &info); - lwork = static_cast(work); - } - auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)}; - auto vec_tmp_data = - array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)}; - auto eig_tmp = static_cast(eig_tmp_data.buffer.raw_ptr()); - auto vec_tmp = static_cast(vec_tmp_data.buffer.raw_ptr()); - auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; + EigWork work(jobl, jobr, N, compute_eigenvectors); + for (size_t i = 0; i < size / (N * N); ++i) { - geev( - &jobl, - &jobr, - &N, - a_ptr, - &N, - eig_tmp, - eig_tmp + N, - vec_tmp, - &n_vecs_l, - nullptr, - &n_vecs_r, - static_cast(work_buf.buffer.raw_ptr()), - &lwork, - &info); - for (int i = 0; i < N; ++i) { - eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]}; - } + work.run(a_ptr, val_ptr, vec_ptr); + a_ptr += N * N; + val_ptr += N; if (vec_ptr) { - for (int i = 0; i < N; ++i) { - if (eig_ptr[i].imag() != 0) { - // This vector and the next are a pair - for (int j = 0; j < N; ++j) { - vec_ptr[i * N + j] = { - vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]}; - vec_ptr[(i + 1) * N + j] = { - vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]}; - } - i += 1; - } else { - for (int j = 0; j < N; ++j) { - vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0}; - } - } - } vec_ptr += N * N; } - a_ptr += N * N; - eig_ptr += N; - if (info != 0) { + if (work.info != 0) { std::stringstream msg; msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code " - << info; + << work.info; throw std::runtime_error(msg.str()); } } @@ -165,8 +264,17 @@ void Eig::eval_cpu( case float32: eig_impl(a_copy, vectors, values, compute_eigenvectors_, stream()); break; + case float64: + eig_impl( + a_copy, vectors, values, compute_eigenvectors_, stream()); + break; + case complex64: + eig_impl>( + a_copy, vectors, values, compute_eigenvectors_, stream()); + break; default: - throw std::runtime_error("[Eig::eval_cpu] only supports float32."); + throw std::runtime_error( + "[Eig::eval_cpu] only supports float32, float64, or complex64."); } } diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h index ce735f26c..1c3ba1a80 100644 --- a/mlx/backend/cpu/lapack.h +++ b/mlx/backend/cpu/lapack.h @@ -45,9 +45,7 @@ INSTANTIATE_LAPACK_REAL(geqrf) INSTANTIATE_LAPACK_REAL(orgqr) INSTANTIATE_LAPACK_REAL(syevd) -INSTANTIATE_LAPACK_REAL(geev) INSTANTIATE_LAPACK_REAL(potrf) -INSTANTIATE_LAPACK_REAL(gesdd) INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_REAL(trtri) @@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri) } INSTANTIATE_LAPACK_COMPLEX(heevd) + +#define INSTANTIATE_LAPACK_ALL(FUNC) \ + template \ + void FUNC(Args... args) { \ + if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(s##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(d##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_ALL(geev) +INSTANTIATE_LAPACK_ALL(gesdd) diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp index 1fc94c382..ca01a0a65 100644 --- a/mlx/backend/cpu/svd.cpp +++ b/mlx/backend/cpu/svd.cpp @@ -8,6 +8,183 @@ namespace mlx::core { +template +struct SVDWork {}; + +template +struct SVDWork< + T, + typename std::enable_if::value>::type> { + using R = T; + + int N; + int M; + int K; + int lda; + int ldu; + int ldvt; + char jobz; + std::vector buffers; + int lwork; + + SVDWork(int N, int M, int K, char jobz) + : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) { + T workspace_dimension = 0; + + // Will contain the indices of eigenvectors that failed to converge (not + // used here but required by lapack). + buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K)); + + int lwork_query = -1; + int info; + + // Compute workspace size. + gesdd( + /* jobz = */ &jobz, + // M and N are swapped since lapack expects column-major. + /* m = */ &N, + /* n = */ &M, + /* a = */ nullptr, + /* lda = */ &lda, + /* s = */ nullptr, + /* u = */ nullptr, + /* ldu = */ &ldu, + /* vt = */ nullptr, + /* ldvt = */ &ldvt, + /* work = */ &workspace_dimension, + /* lwork = */ &lwork_query, + /* iwork = */ static_cast(buffers[0].buffer.raw_ptr()), + /* info = */ &info); + + if (info != 0) { + std::stringstream ss; + ss << "[SVD::eval_cpu] workspace calculation failed with code " << info; + throw std::runtime_error(ss.str()); + } + + lwork = workspace_dimension; + buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); + } + + void run(T* a, R* s, T* u, T* vt) { + int info; + gesdd( + /* jobz = */ &jobz, + // M and N are swapped since lapack expects column-major. + /* m = */ &N, + /* n = */ &M, + /* a = */ a, + /* lda = */ &lda, + /* s = */ s, + // According to the identity above, lapack will write Vᵀᵀ as U. + /* u = */ u, + /* ldu = */ &ldu, + // According to the identity above, lapack will write Uᵀ as Vᵀ. + /* vt = */ vt, + /* ldvt = */ &ldvt, + /* work = */ static_cast(buffers[1].buffer.raw_ptr()), + /* lwork = */ &lwork, + /* iwork = */ static_cast(buffers[0].buffer.raw_ptr()), + /* info = */ &info); + + if (info != 0) { + std::stringstream ss; + ss << "svd_impl: sgesvdx_ failed with code " << info; + throw std::runtime_error(ss.str()); + } + } +}; + +template <> +struct SVDWork> { + using T = std::complex; + using R = float; + + int N; + int M; + int K; + int lda; + int ldu; + int ldvt; + char jobz; + std::vector buffers; + int lwork; + + SVDWork(int N, int M, int K, char jobz) + : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) { + T workspace_dimension = 0; + + // Will contain the indices of eigenvectors that failed to converge (not + // used here but required by lapack). + buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K)); + + const int lrwork = + jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K); + buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork)); + + int lwork_query = -1; + int work_query = -1; + int info; + + // Compute workspace size. + gesdd( + /* jobz = */ &jobz, + // M and N are swapped since lapack expects column-major. + /* m = */ &N, + /* n = */ &M, + /* a = */ nullptr, + /* lda = */ &lda, + /* s = */ nullptr, + /* u = */ nullptr, + /* ldu = */ &ldu, + /* vt = */ nullptr, + /* ldvt = */ &ldvt, + /* work = */ &workspace_dimension, + /* lwork = */ &lwork_query, + /* rwork = */ static_cast(buffers[1].buffer.raw_ptr()), + /* iwork = */ static_cast(buffers[0].buffer.raw_ptr()), + /* info = */ &info); + + if (info != 0) { + std::stringstream ss; + ss << "[SVD::eval_cpu] workspace calculation failed with code " << info; + throw std::runtime_error(ss.str()); + } + + lwork = workspace_dimension.real(); + buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); + } + + void run(T* a, R* s, T* u, T* vt) { + int info; + gesdd( + /* jobz = */ &jobz, + // M and N are swapped since lapack expects column-major. + /* m = */ &N, + /* n = */ &M, + /* a = */ a, + /* lda = */ &lda, + /* s = */ s, + // According to the identity above, lapack will write Vᵀᵀ as U. + /* u = */ u, + /* ldu = */ &ldu, + // According to the identity above, lapack will write Uᵀ as Vᵀ. + /* vt = */ vt, + /* ldvt = */ &ldvt, + /* work = */ static_cast(buffers[2].buffer.raw_ptr()), + /* lwork = */ &lwork, + /* rwork = */ static_cast(buffers[1].buffer.raw_ptr()), + /* iwork = */ static_cast(buffers[0].buffer.raw_ptr()), + /* info = */ &info); + + if (info != 0) { + std::stringstream ss; + ss << "svd_impl: sgesvdx_ failed with code " << info; + throw std::runtime_error(ss.str()); + } + } +}; + template void svd_impl( const array& a, @@ -27,6 +204,8 @@ void svd_impl( const int N = a.shape(-1); const int K = std::min(M, N); + using R = typename SVDWork::R; + size_t num_matrices = a.size() / (M * N); // lapack clobbers the input, so we have to make a copy. @@ -42,7 +221,7 @@ void svd_impl( encoder.set_input_array(a); auto in_ptr = in.data(); T* u_ptr; - T* s_ptr; + R* s_ptr; T* vt_ptr; if (compute_uv) { @@ -58,7 +237,7 @@ void svd_impl( encoder.set_output_array(s); encoder.set_output_array(vt); - s_ptr = s.data(); + s_ptr = s.data(); u_ptr = u.data(); vt_ptr = vt.data(); } else { @@ -68,96 +247,26 @@ void svd_impl( encoder.set_output_array(s); - s_ptr = s.data(); + s_ptr = s.data(); u_ptr = nullptr; vt_ptr = nullptr; } encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() { - // A of shape M x N. The leading dimension is N since lapack receives Aᵀ. - const int lda = N; - // U of shape M x M. (N x N in lapack). - const int ldu = N; - // Vᵀ of shape N x N. (M x M in lapack). - const int ldvt = M; - - auto jobz = (u_ptr) ? "A" : "N"; - - T workspace_dimension = 0; - - // Will contain the indices of eigenvectors that failed to converge (not - // used here but required by lapack). - auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)}; - - static const int lwork_query = -1; - - int info; - - // Compute workspace size. - gesdd( - /* jobz = */ jobz, - // M and N are swapped since lapack expects column-major. - /* m = */ &N, - /* n = */ &M, - /* a = */ nullptr, - /* lda = */ &lda, - /* s = */ nullptr, - /* u = */ nullptr, - /* ldu = */ &ldu, - /* vt = */ nullptr, - /* ldvt = */ &ldvt, - /* work = */ &workspace_dimension, - /* lwork = */ &lwork_query, - /* iwork = */ static_cast(iwork.buffer.raw_ptr()), - /* info = */ &info); - - if (info != 0) { - std::stringstream ss; - ss << "[SVD::eval_cpu] workspace calculation failed with code " << info; - throw std::runtime_error(ss.str()); - } - - const int lwork = workspace_dimension; - auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)}; - + auto jobz = (u_ptr) ? 'A' : 'N'; + SVDWork svd_work(N, M, K, jobz); // Loop over matrices. for (int i = 0; i < num_matrices; i++) { - gesdd( - /* jobz = */ jobz, - // M and N are swapped since lapack expects column-major. - /* m = */ &N, - /* n = */ &M, - /* a = */ in_ptr + M * N * i, - /* lda = */ &lda, - /* s = */ s_ptr + K * i, - // According to the identity above, lapack will write Vᵀᵀ as U. - /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr, - /* ldu = */ &ldu, - // According to the identity above, lapack will write Uᵀ as Vᵀ. - /* vt = */ u_ptr ? u_ptr + M * M * i : nullptr, - /* ldvt = */ &ldvt, - /* work = */ static_cast(scratch.buffer.raw_ptr()), - /* lwork = */ &lwork, - /* iwork = */ static_cast(iwork.buffer.raw_ptr()), - /* info = */ &info); - - if (info != 0) { - std::stringstream ss; - ss << "svd_impl: sgesvdx_ failed with code " << info; - throw std::runtime_error(ss.str()); - } + svd_work.run( + in_ptr + M * N * i, + s_ptr + K * i, + vt_ptr ? vt_ptr + N * N * i : nullptr, + u_ptr ? u_ptr + M * M * i : nullptr); } }); encoder.add_temporary(in); } -template -void compute_svd( - const array& a, - bool compute_uv, - std::vector& outputs, - Stream stream) {} - void SVD::eval_cpu( const std::vector& inputs, std::vector& outputs) { @@ -168,9 +277,12 @@ void SVD::eval_cpu( case float64: svd_impl(inputs[0], outputs, compute_uv_, stream()); break; + case complex64: + svd_impl>(inputs[0], outputs, compute_uv_, stream()); + break; default: throw std::runtime_error( - "[SVD::eval_cpu] only supports float32 or float64."); + "[SVD::eval_cpu] only supports float32, float64, or complex64."); } } diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index afa8e447a..7ac080dab 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -250,7 +250,7 @@ std::pair qr(const array& a, StreamOrDevice s /* = {} */) { std::vector svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) { check_cpu_stream(s, "[linalg::svd]"); - check_float(a.dtype(), "[linalg::svd]"); + check_float_or_complex(a.dtype(), "[linalg::svd]"); if (a.ndim() < 2) { std::ostringstream msg; @@ -268,10 +268,12 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) { s_shape.pop_back(); s_shape[rank - 2] = std::min(m, n); + auto s_dtype = a.dtype() == complex64 ? float32 : a.dtype(); + if (!compute_uv) { return {array( std::move(s_shape), - a.dtype(), + s_dtype, std::make_shared(to_stream(s), compute_uv), {a})}; } @@ -286,7 +288,7 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) { return array::make_arrays( {u_shape, s_shape, vt_shape}, - {a.dtype(), a.dtype(), a.dtype()}, + {a.dtype(), s_dtype, a.dtype()}, std::make_shared(to_stream(s), compute_uv), {a}); } @@ -703,4 +705,4 @@ array solve_triangular( return matmul(a_inv, b, s); } -} // namespace mlx::core::linalg +} // namespace mlx::core::linalg \ No newline at end of file diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 81a43ed7f..8e2444f20 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -168,6 +168,42 @@ class TestLinalg(mlx_tests.MLXTestCase): ) ) + # Test float64 - use CPU stream since float64 is not supported on GPU + with mx.stream(mx.cpu): + A_f64 = mx.array( + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64 + ) + U_f64, S_f64, Vt_f64 = mx.linalg.svd(A_f64, compute_uv=True) + mx.eval(U_f64, S_f64, Vt_f64) + self.assertTrue( + mx.allclose( + U_f64[:, : len(S_f64)] @ mx.diag(S_f64) @ Vt_f64, + A_f64, + rtol=1e-5, + atol=1e-7, + ) + ) + self.assertEqual(S_f64.dtype, mx.float64) + + # Test complex64 - use CPU stream since complex64 is not supported on GPU + with mx.stream(mx.cpu): + A_c64 = mx.array( + [[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=mx.complex64 + ) + U_c64, S_c64, Vt_c64 = mx.linalg.svd(A_c64, compute_uv=True) + mx.eval(U_c64, S_c64, Vt_c64) + self.assertTrue( + mx.allclose( + U_c64[:, : len(S_c64)] @ mx.diag(S_c64) @ Vt_c64, + A_c64, + rtol=1e-5, + atol=1e-7, + ) + ) + self.assertEqual(S_c64.dtype, mx.float32) + self.assertEqual(U_c64.dtype, mx.complex64) + self.assertEqual(Vt_c64.dtype, mx.complex64) + def test_inverse(self): A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32) A_inv = mx.linalg.inv(A, stream=mx.cpu) @@ -342,6 +378,43 @@ class TestLinalg(mlx_tests.MLXTestCase): A_np = np.random.randn(3, n, n).astype(np.float32) check_eigs_and_vecs(A_np) + # Test float64 - use CPU stream since float64 is not supported on GPU + with mx.stream(mx.cpu): + A_np_f64 = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float64) + A_f64 = mx.array(A_np_f64, dtype=mx.float64) + eig_vals_f64, eig_vecs_f64 = mx.linalg.eig(A_f64) + mx.eval(eig_vals_f64, eig_vecs_f64) + self.assertTrue( + mx.allclose( + A_f64 @ eig_vecs_f64, + eig_vals_f64[..., None, :] * eig_vecs_f64, + rtol=1e-5, + atol=1e-5, + ) + ) + # Eigenvalues should be complex64 (output dtype) + self.assertEqual(eig_vals_f64.dtype, mx.complex64) + self.assertEqual(eig_vecs_f64.dtype, mx.complex64) + + # Test complex64 input - use CPU stream since complex64 is not supported on GPU + with mx.stream(mx.cpu): + A_np_c64 = np.array( + [[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=np.complex64 + ) + A_c64 = mx.array(A_np_c64, dtype=mx.complex64) + eig_vals_c64, eig_vecs_c64 = mx.linalg.eig(A_c64) + mx.eval(eig_vals_c64, eig_vecs_c64) + self.assertTrue( + mx.allclose( + A_c64 @ eig_vecs_c64, + eig_vals_c64[..., None, :] * eig_vecs_c64, + rtol=1e-5, + atol=1e-5, + ) + ) + self.assertEqual(eig_vals_c64.dtype, mx.complex64) + self.assertEqual(eig_vecs_c64.dtype, mx.complex64) + # Test error cases with self.assertRaises(ValueError): mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array