Double for lapack (#1904)

* double for lapack ops

* add double support for lapack ops
This commit is contained in:
Awni Hannun
2025-02-25 11:39:36 -08:00
committed by GitHub
parent 28b8079e30
commit 7d042f17fe
11 changed files with 338 additions and 225 deletions

View File

@@ -7,6 +7,7 @@
namespace mlx::core {
template <typename T>
void svd_impl(const array& a, array& u, array& s, array& vt) {
// Lapack uses the column-major convention. To avoid having to transpose
// the input and then transpose the outputs, we swap the indices/sizes of the
@@ -31,7 +32,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), float32, nullptr, {});
array in(a.shape(), a.dtype(), nullptr, {});
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
// Allocate outputs.
@@ -45,7 +46,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
// Will contain the number of singular values after the call has returned.
int ns = 0;
float workspace_dimension = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not used
// here but required by lapack).
@@ -54,13 +55,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
static const int lwork_query = -1;
static const int ignored_int = 0;
static const float ignored_float = 0;
static const T ignored_float = 0;
int info;
// Compute workspace size.
MLX_LAPACK_FUNC(sgesvdx)
(
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
@@ -86,51 +86,50 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
MLX_LAPACK_FUNC(sgesvdx)
(
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in.data<float>() + M * N * i,
/* a = */ in.data<T>() + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s.data<float>() + K * i,
/* s = */ s.data<T>() + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt.data<float>() + N * N * i,
/* u = */ vt.data<T>() + N * N * i,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u.data<float>() + M * M * i,
/* vt = */ u.data<T>() + M * M * i,
/* ldvt = */ &ldvt,
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
ss << "[SVD::eval_cpu] failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
@@ -140,10 +139,17 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[SVD::eval] only supports float32.");
switch (inputs[0].dtype()) {
case float32:
svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
break;
case float64:
svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
break;
default:
throw std::runtime_error(
"[SVD::eval_cpu] only supports float32 or float64.");
}
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
} // namespace mlx::core