fix lapack svd (#2515)

This commit is contained in:
Awni Hannun 2025-08-18 15:07:59 -07:00 committed by GitHub
parent e7c6e1db82
commit cea9369610
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 33 deletions

View File

@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd) INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev) INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf) INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesvdx) INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri) INSTANTIATE_LAPACK_REAL(trtri)

View File

@ -81,9 +81,7 @@ void svd_impl(
// Vᵀ of shape N x N. (M x M in lapack). // Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M; const int ldvt = M;
auto job_u = (u_ptr) ? "V" : "N"; auto jobz = (u_ptr) ? "A" : "N";
auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned. // Will contain the number of singular values after the call has returned.
int ns = 0; int ns = 0;
@ -91,30 +89,20 @@ void svd_impl(
// Will contain the indices of eigenvectors that failed to converge (not // Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack). // used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)}; auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
static const int lwork_query = -1; static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info; int info;
// Compute workspace size. // Compute workspace size.
gesvdx<T>( gesdd<T>(
/* jobu = */ job_u, /* jobz = */ jobz,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major. // M and N are swapped since lapack expects column-major.
/* m = */ &N, /* m = */ &N,
/* n = */ &M, /* n = */ &M,
/* a = */ nullptr, /* a = */ nullptr,
/* lda = */ &lda, /* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr, /* s = */ nullptr,
/* u = */ nullptr, /* u = */ nullptr,
/* ldu = */ &ldu, /* ldu = */ &ldu,
@ -136,20 +124,13 @@ void svd_impl(
// Loop over matrices. // Loop over matrices.
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
gesvdx<T>( gesdd<T>(
/* jobu = */ job_u, /* jobz = */ jobz,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major. // M and N are swapped since lapack expects column-major.
/* m = */ &N, /* m = */ &N,
/* n = */ &M, /* n = */ &M,
/* a = */ in_ptr + M * N * i, /* a = */ in_ptr + M * N * i,
/* lda = */ &lda, /* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i, /* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U. // According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr, /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
@ -167,13 +148,6 @@ void svd_impl(
ss << "svd_impl: sgesvdx_ failed with code " << info; ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
} }
}); });
encoder.add_temporary(in); encoder.add_temporary(in);