From cea93696109e24a5605516c5a4b48c030c15003a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 18 Aug 2025 15:07:59 -0700 Subject: [PATCH] fix lapack svd (#2515) --- mlx/backend/cpu/lapack.h | 2 +- mlx/backend/cpu/svd.cpp | 38 ++++++-------------------------------- 2 files changed, 7 insertions(+), 33 deletions(-) diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h index b242093ff..ce735f26c 100644 --- a/mlx/backend/cpu/lapack.h +++ b/mlx/backend/cpu/lapack.h @@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr) INSTANTIATE_LAPACK_REAL(syevd) INSTANTIATE_LAPACK_REAL(geev) INSTANTIATE_LAPACK_REAL(potrf) -INSTANTIATE_LAPACK_REAL(gesvdx) +INSTANTIATE_LAPACK_REAL(gesdd) INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_REAL(trtri) diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp index 08ad444e1..6e57eb401 100644 --- a/mlx/backend/cpu/svd.cpp +++ b/mlx/backend/cpu/svd.cpp @@ -81,9 +81,7 @@ void svd_impl( // Vᵀ of shape N x N. (M x M in lapack). const int ldvt = M; - auto job_u = (u_ptr) ? "V" : "N"; - auto job_vt = (u_ptr) ? "V" : "N"; - static constexpr auto range = "A"; + auto jobz = (u_ptr) ? "A" : "N"; // Will contain the number of singular values after the call has returned. int ns = 0; @@ -91,30 +89,20 @@ void svd_impl( // Will contain the indices of eigenvectors that failed to converge (not // used here but required by lapack). - auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)}; + auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)}; static const int lwork_query = -1; - static const int ignored_int = 0; - static const T ignored_float = 0; - int info; // Compute workspace size. - gesvdx( - /* jobu = */ job_u, - /* jobvt = */ job_vt, - /* range = */ range, + gesdd( + /* jobz = */ jobz, // M and N are swapped since lapack expects column-major. /* m = */ &N, /* n = */ &M, /* a = */ nullptr, /* lda = */ &lda, - /* vl = */ &ignored_float, - /* vu = */ &ignored_float, - /* il = */ &ignored_int, - /* iu = */ &ignored_int, - /* ns = */ &ns, /* s = */ nullptr, /* u = */ nullptr, /* ldu = */ &ldu, @@ -136,20 +124,13 @@ void svd_impl( // Loop over matrices. for (int i = 0; i < num_matrices; i++) { - gesvdx( - /* jobu = */ job_u, - /* jobvt = */ job_vt, - /* range = */ range, + gesdd( + /* jobz = */ jobz, // M and N are swapped since lapack expects column-major. /* m = */ &N, /* n = */ &M, /* a = */ in_ptr + M * N * i, /* lda = */ &lda, - /* vl = */ &ignored_float, - /* vu = */ &ignored_float, - /* il = */ &ignored_int, - /* iu = */ &ignored_int, - /* ns = */ &ns, /* s = */ s_ptr + K * i, // According to the identity above, lapack will write Vᵀᵀ as U. /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr, @@ -167,13 +148,6 @@ void svd_impl( ss << "svd_impl: sgesvdx_ failed with code " << info; throw std::runtime_error(ss.str()); } - - if (ns != K) { - std::stringstream ss; - ss << "svd_impl: expected " << K << " singular values, but " << ns - << " were computed."; - throw std::runtime_error(ss.str()); - } } }); encoder.add_temporary(in);