mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add float64 Eig and complex64 SVD/Eig support (Fixes #2708) (#2737)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -250,7 +250,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
std::vector<array>
|
||||
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::svd]");
|
||||
check_float(a.dtype(), "[linalg::svd]");
|
||||
check_float_or_complex(a.dtype(), "[linalg::svd]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@@ -268,10 +268,12 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
||||
s_shape.pop_back();
|
||||
s_shape[rank - 2] = std::min(m, n);
|
||||
|
||||
auto s_dtype = a.dtype() == complex64 ? float32 : a.dtype();
|
||||
|
||||
if (!compute_uv) {
|
||||
return {array(
|
||||
std::move(s_shape),
|
||||
a.dtype(),
|
||||
s_dtype,
|
||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||
{a})};
|
||||
}
|
||||
@@ -286,7 +288,7 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
||||
|
||||
return array::make_arrays(
|
||||
{u_shape, s_shape, vt_shape},
|
||||
{a.dtype(), a.dtype(), a.dtype()},
|
||||
{a.dtype(), s_dtype, a.dtype()},
|
||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||
{a});
|
||||
}
|
||||
@@ -703,4 +705,4 @@ array solve_triangular(
|
||||
return matmul(a_inv, b, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
} // namespace mlx::core::linalg
|
||||
Reference in New Issue
Block a user