Add float64 Eig and complex64 SVD/Eig support (Fixes #2708) (#2737)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Harsh Sutaria
2025-11-22 09:51:36 -05:00
committed by GitHub
parent d5f61a93fa
commit 618c87af8c
5 changed files with 471 additions and 161 deletions

View File

@@ -250,7 +250,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
std::vector<array>
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::svd]");
check_float(a.dtype(), "[linalg::svd]");
check_float_or_complex(a.dtype(), "[linalg::svd]");
if (a.ndim() < 2) {
std::ostringstream msg;
@@ -268,10 +268,12 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n);
auto s_dtype = a.dtype() == complex64 ? float32 : a.dtype();
if (!compute_uv) {
return {array(
std::move(s_shape),
a.dtype(),
s_dtype,
std::make_shared<SVD>(to_stream(s), compute_uv),
{a})};
}
@@ -286,7 +288,7 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
return array::make_arrays(
{u_shape, s_shape, vt_shape},
{a.dtype(), a.dtype(), a.dtype()},
{a.dtype(), s_dtype, a.dtype()},
std::make_shared<SVD>(to_stream(s), compute_uv),
{a});
}
@@ -703,4 +705,4 @@ array solve_triangular(
return matmul(a_inv, b, s);
}
} // namespace mlx::core::linalg
} // namespace mlx::core::linalg