mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
fix lapack svd (#2515)
This commit is contained in:
parent
e7c6e1db82
commit
cea9369610
@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
|
|||||||
INSTANTIATE_LAPACK_REAL(syevd)
|
INSTANTIATE_LAPACK_REAL(syevd)
|
||||||
INSTANTIATE_LAPACK_REAL(geev)
|
INSTANTIATE_LAPACK_REAL(geev)
|
||||||
INSTANTIATE_LAPACK_REAL(potrf)
|
INSTANTIATE_LAPACK_REAL(potrf)
|
||||||
INSTANTIATE_LAPACK_REAL(gesvdx)
|
INSTANTIATE_LAPACK_REAL(gesdd)
|
||||||
INSTANTIATE_LAPACK_REAL(getrf)
|
INSTANTIATE_LAPACK_REAL(getrf)
|
||||||
INSTANTIATE_LAPACK_REAL(getri)
|
INSTANTIATE_LAPACK_REAL(getri)
|
||||||
INSTANTIATE_LAPACK_REAL(trtri)
|
INSTANTIATE_LAPACK_REAL(trtri)
|
||||||
|
@ -81,9 +81,7 @@ void svd_impl(
|
|||||||
// Vᵀ of shape N x N. (M x M in lapack).
|
// Vᵀ of shape N x N. (M x M in lapack).
|
||||||
const int ldvt = M;
|
const int ldvt = M;
|
||||||
|
|
||||||
auto job_u = (u_ptr) ? "V" : "N";
|
auto jobz = (u_ptr) ? "A" : "N";
|
||||||
auto job_vt = (u_ptr) ? "V" : "N";
|
|
||||||
static constexpr auto range = "A";
|
|
||||||
|
|
||||||
// Will contain the number of singular values after the call has returned.
|
// Will contain the number of singular values after the call has returned.
|
||||||
int ns = 0;
|
int ns = 0;
|
||||||
@ -91,30 +89,20 @@ void svd_impl(
|
|||||||
|
|
||||||
// Will contain the indices of eigenvectors that failed to converge (not
|
// Will contain the indices of eigenvectors that failed to converge (not
|
||||||
// used here but required by lapack).
|
// used here but required by lapack).
|
||||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
|
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
|
||||||
|
|
||||||
static const int lwork_query = -1;
|
static const int lwork_query = -1;
|
||||||
|
|
||||||
static const int ignored_int = 0;
|
|
||||||
static const T ignored_float = 0;
|
|
||||||
|
|
||||||
int info;
|
int info;
|
||||||
|
|
||||||
// Compute workspace size.
|
// Compute workspace size.
|
||||||
gesvdx<T>(
|
gesdd<T>(
|
||||||
/* jobu = */ job_u,
|
/* jobz = */ jobz,
|
||||||
/* jobvt = */ job_vt,
|
|
||||||
/* range = */ range,
|
|
||||||
// M and N are swapped since lapack expects column-major.
|
// M and N are swapped since lapack expects column-major.
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* n = */ &M,
|
/* n = */ &M,
|
||||||
/* a = */ nullptr,
|
/* a = */ nullptr,
|
||||||
/* lda = */ &lda,
|
/* lda = */ &lda,
|
||||||
/* vl = */ &ignored_float,
|
|
||||||
/* vu = */ &ignored_float,
|
|
||||||
/* il = */ &ignored_int,
|
|
||||||
/* iu = */ &ignored_int,
|
|
||||||
/* ns = */ &ns,
|
|
||||||
/* s = */ nullptr,
|
/* s = */ nullptr,
|
||||||
/* u = */ nullptr,
|
/* u = */ nullptr,
|
||||||
/* ldu = */ &ldu,
|
/* ldu = */ &ldu,
|
||||||
@ -136,20 +124,13 @@ void svd_impl(
|
|||||||
|
|
||||||
// Loop over matrices.
|
// Loop over matrices.
|
||||||
for (int i = 0; i < num_matrices; i++) {
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
gesvdx<T>(
|
gesdd<T>(
|
||||||
/* jobu = */ job_u,
|
/* jobz = */ jobz,
|
||||||
/* jobvt = */ job_vt,
|
|
||||||
/* range = */ range,
|
|
||||||
// M and N are swapped since lapack expects column-major.
|
// M and N are swapped since lapack expects column-major.
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* n = */ &M,
|
/* n = */ &M,
|
||||||
/* a = */ in_ptr + M * N * i,
|
/* a = */ in_ptr + M * N * i,
|
||||||
/* lda = */ &lda,
|
/* lda = */ &lda,
|
||||||
/* vl = */ &ignored_float,
|
|
||||||
/* vu = */ &ignored_float,
|
|
||||||
/* il = */ &ignored_int,
|
|
||||||
/* iu = */ &ignored_int,
|
|
||||||
/* ns = */ &ns,
|
|
||||||
/* s = */ s_ptr + K * i,
|
/* s = */ s_ptr + K * i,
|
||||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||||
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||||
@ -167,13 +148,6 @@ void svd_impl(
|
|||||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ns != K) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
|
||||||
<< " were computed.";
|
|
||||||
throw std::runtime_error(ss.str());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
encoder.add_temporary(in);
|
encoder.add_temporary(in);
|
||||||
|
Loading…
Reference in New Issue
Block a user