Double for lapack (#1904)

* double for lapack ops

* add double support for lapack ops
Awni Hannun 2025-02-25 11:39:36 -08:00 committed by GitHub
parent 28b8079e30
commit 7d042f17fe
11 changed files with 338 additions and 225 deletions

View File

@@ -8,6 +8,7 @@
 namespace mlx::core {

+template <typename T>
 void cholesky_impl(const array& a, array& factor, bool upper) {
   // Lapack uses the column-major convention. We take advantage of the fact that
   // the matrix should be symmetric:
@@ -28,13 +29,12 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
   const int N = a.shape(-1);
   const size_t num_matrices = a.size() / (N * N);

-  float* matrix = factor.data<float>();
+  T* matrix = factor.data<T>();

   for (int i = 0; i < num_matrices; i++) {
     // Compute Cholesky factorization.
     int info;
-    MLX_LAPACK_FUNC(spotrf)
-    (
+    potrf<T>(
         /* uplo = */ &uplo,
         /* n = */ &N,
         /* a = */ matrix,
@@ -65,10 +65,17 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
 }

 void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
-  if (inputs[0].dtype() != float32) {
-    throw std::runtime_error("[Cholesky::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      cholesky_impl<float>(inputs[0], output, upper_);
+      break;
+    case float64:
+      cholesky_impl<double>(inputs[0], output, upper_);
+      break;
+    default:
+      throw std::runtime_error(
+          "[Cholesky::eval_cpu] only supports float32 or float64.");
   }
-  cholesky_impl(inputs[0], output, upper_);
 }

 } // namespace mlx::core

View File

@@ -11,30 +11,58 @@ namespace mlx::core {
 namespace {

-void ssyevd(
-    char jobz,
-    char uplo,
-    float* a,
-    int N,
-    float* w,
-    float* work,
-    int lwork,
-    int* iwork,
-    int liwork) {
+template <typename T>
+void eigh_impl(
+    array& vectors,
+    array& values,
+    const std::string& uplo,
+    bool compute_eigenvectors) {
+  auto vec_ptr = vectors.data<T>();
+  auto eig_ptr = values.data<T>();
+
+  char jobz = compute_eigenvectors ? 'V' : 'N';
+  auto N = vectors.shape(-1);
+
+  // Work query
+  int lwork = -1;
+  int liwork = -1;
   int info;
-  MLX_LAPACK_FUNC(ssyevd)
-  (
-      /* jobz = */ &jobz,
-      /* uplo = */ &uplo,
-      /* n = */ &N,
-      /* a = */ a,
-      /* lda = */ &N,
-      /* w = */ w,
-      /* work = */ work,
-      /* lwork = */ &lwork,
-      /* iwork = */ iwork,
-      /* liwork = */ &liwork,
-      /* info = */ &info);
+  {
+    T work;
+    int iwork;
+    syevd<T>(
+        &jobz,
+        uplo.c_str(),
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        &work,
+        &lwork,
+        &iwork,
+        &liwork,
+        &info);
+    lwork = static_cast<int>(work);
+    liwork = iwork;
+  }
+
+  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
+  for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
+    syevd<T>(
+        &jobz,
+        uplo.c_str(),
+        &N,
+        vec_ptr,
+        &N,
+        eig_ptr,
+        static_cast<T*>(work_buf.buffer.raw_ptr()),
+        &lwork,
+        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
+        &liwork,
+        &info);
+    vec_ptr += N * N;
+    eig_ptr += N;
     if (info != 0) {
       std::stringstream msg;
       msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
@@ -42,6 +70,7 @@ void ssyevd(
       throw std::runtime_error(msg.str());
     }
   }
+}

 } // namespace

@@ -80,39 +109,16 @@
     }
     vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
   }
-
-  auto vec_ptr = vectors.data<float>();
-  auto eig_ptr = values.data<float>();
-
-  char jobz = compute_eigenvectors_ ? 'V' : 'N';
-  auto N = a.shape(-1);
-
-  // Work query
-  int lwork;
-  int liwork;
-  {
-    float work;
-    int iwork;
-    ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
-    lwork = static_cast<int>(work);
-    liwork = iwork;
-  }
-
-  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
-  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
-  for (size_t i = 0; i < a.size() / (N * N); ++i) {
-    ssyevd(
-        jobz,
-        uplo_[0],
-        vec_ptr,
-        N,
-        eig_ptr,
-        static_cast<float*>(work_buf.buffer.raw_ptr()),
-        lwork,
-        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
-        liwork);
-    vec_ptr += N * N;
-    eig_ptr += N;
+  switch (a.dtype()) {
+    case float32:
+      eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
+      break;
+    case float64:
+      eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
+      break;
+    default:
+      throw std::runtime_error(
+          "[Eigh::eval_cpu] only supports float32 or float64.");
   }
 }

View File

@@ -5,44 +5,33 @@
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/primitives.h"

-int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
-  int info;
-  MLX_LAPACK_FUNC(strtri)
-  (
-      /* uplo = */ &uplo,
-      /* diag = */ &diag,
-      /* N = */ &N,
-      /* a = */ matrix,
-      /* lda = */ &N,
-      /* info = */ &info);
-  return info;
-}
-
 namespace mlx::core {

+template <typename T>
 void general_inv(array& inv, int N, int i) {
   int info;
   auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
   // Compute LU factorization.
-  sgetrf_(
+  getrf<T>(
       /* m = */ &N,
       /* n = */ &N,
-      /* a = */ inv.data<float>() + N * N * i,
+      /* a = */ inv.data<T>() + N * N * i,
       /* lda = */ &N,
       /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
       /* info = */ &info);
   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: LU factorization failed with error code " << info;
+    ss << "[Inverse::eval_cpu] LU factorization failed with error code "
+       << info;
     throw std::runtime_error(ss.str());
   }
   static const int lwork_query = -1;
-  float workspace_size = 0;
+  T workspace_size = 0;
   // Compute workspace size.
-  sgetri_(
+  getri<T>(
       /* m = */ &N,
       /* a = */ nullptr,
       /* lda = */ &N,
@@ -53,36 +42,44 @@ void general_inv(array& inv, int N, int i) {
   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: LU workspace calculation failed with error code "
+    ss << "[Inverse::eval_cpu] LU workspace calculation failed with error code "
        << info;
     throw std::runtime_error(ss.str());
   }
   const int lwork = workspace_size;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
   // Compute inverse.
-  sgetri_(
+  getri<T>(
       /* m = */ &N,
-      /* a = */ inv.data<float>() + N * N * i,
+      /* a = */ inv.data<T>() + N * N * i,
       /* lda = */ &N,
       /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
-      /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
+      /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
       /* lwork = */ &lwork,
       /* info = */ &info);
   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: inversion failed with error code " << info;
+    ss << "[Inverse::eval_cpu] inversion failed with error code " << info;
     throw std::runtime_error(ss.str());
   }
 }

+template <typename T>
 void tri_inv(array& inv, int N, int i, bool upper) {
   const char uplo = upper ? 'L' : 'U';
   const char diag = 'N';
-  float* data = inv.data<float>() + N * N * i;
-  int info = strtri_wrapper(uplo, diag, data, N);
+  T* data = inv.data<T>() + N * N * i;
+  int info;
+  trtri<T>(
+      /* uplo = */ &uplo,
+      /* diag = */ &diag,
+      /* N = */ &N,
+      /* a = */ data,
+      /* lda = */ &N,
+      /* info = */ &info);

   // zero out the other triangle
   if (upper) {
@@ -99,11 +96,13 @@ void tri_inv(array& inv, int N, int i, bool upper) {
   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: triangular inversion failed with error code " << info;
+    ss << "[Inverse::eval_cpu] triangular inversion failed with error code "
+       << info;
     throw std::runtime_error(ss.str());
   }
 }

+template <typename T>
 void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
   // Lapack uses the column-major convention. We take advantage of the following
   // identity to avoid transposing (see
@@ -118,18 +117,25 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
   for (int i = 0; i < num_matrices; i++) {
     if (tri) {
-      tri_inv(inv, N, i, upper);
+      tri_inv<T>(inv, N, i, upper);
     } else {
-      general_inv(inv, N, i);
+      general_inv<T>(inv, N, i);
     }
   }
 }

 void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
-  if (inputs[0].dtype() != float32) {
-    throw std::runtime_error("[Inverse::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      inverse_impl<float>(inputs[0], output, tri_, upper_);
+      break;
+    case float64:
+      inverse_impl<double>(inputs[0], output, tri_, upper_);
+      break;
+    default:
+      throw std::runtime_error(
+          "[Inverse::eval_cpu] only supports float32 or float64.");
   }
-  inverse_impl(inputs[0], output, tri_, upper_);
 }

 } // namespace mlx::core

View File

@@ -31,3 +31,22 @@
 #define MLX_LAPACK_FUNC(f) f##_
 #endif

+#define INSTANTIATE_LAPACK_TYPES(FUNC)                          \
+  template <typename T, typename... Args>                       \
+  void FUNC(Args... args) {                                     \
+    if constexpr (std::is_same_v<T, float>) {                   \
+      MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...);    \
+    } else if constexpr (std::is_same_v<T, double>) {           \
+      MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...);    \
+    }                                                            \
+  }
+
+INSTANTIATE_LAPACK_TYPES(geqrf)
+INSTANTIATE_LAPACK_TYPES(orgqr)
+INSTANTIATE_LAPACK_TYPES(syevd)
+INSTANTIATE_LAPACK_TYPES(potrf)
+INSTANTIATE_LAPACK_TYPES(gesvdx)
+INSTANTIATE_LAPACK_TYPES(getrf)
+INSTANTIATE_LAPACK_TYPES(getri)
+INSTANTIATE_LAPACK_TYPES(trtri)
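
Note: as a minimal standalone sketch (not part of the diff), the expansion of INSTANTIATE_LAPACK_TYPES produces a templated wrapper whose `if constexpr` branches pick the `s`- or `d`-prefixed LAPACK symbol at compile time, so `potrf<double>(...)` lowers to a direct call to `dpotrf_(...)`. The toy functions `spotrf_`/`dpotrf_` below are stand-ins defined only for illustration; the real header forwards to the actual LAPACK symbols.

#include <iostream>
#include <type_traits>
#include <utility>

// Toy stand-ins for the real LAPACK symbols (illustration only).
void spotrf_(const char* uplo, const int* n) {
  std::cout << "spotrf_ (float), uplo = " << *uplo << ", n = " << *n << "\n";
}
void dpotrf_(const char* uplo, const int* n) {
  std::cout << "dpotrf_ (double), uplo = " << *uplo << ", n = " << *n << "\n";
}

// Same shape as the INSTANTIATE_LAPACK_TYPES expansion for `potrf`:
// the branch is resolved at compile time from the template argument.
template <typename T, typename... Args>
void potrf(Args... args) {
  if constexpr (std::is_same_v<T, float>) {
    spotrf_(std::forward<Args>(args)...);
  } else if constexpr (std::is_same_v<T, double>) {
    dpotrf_(std::forward<Args>(args)...);
  }
}

int main() {
  const char uplo = 'L';
  const int n = 4;
  potrf<float>(&uplo, &n);   // dispatches to spotrf_
  potrf<double>(&uplo, &n);  // dispatches to dpotrf_
}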

View File

@@ -9,11 +9,8 @@
 namespace mlx::core {

-void lu_factor_impl(
-    const array& a,
-    array& lu,
-    array& pivots,
-    array& row_indices) {
+template <typename T>
+void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
   int M = a.shape(-2);
   int N = a.shape(-1);
@@ -31,7 +28,7 @@ void lu_factor_impl(
   copy_inplace(
       a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);

-  auto a_ptr = lu.data<float>();
+  auto a_ptr = lu.data<T>();

   pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
   row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
@@ -42,8 +39,8 @@ void lu_factor_impl(
   size_t num_matrices = a.size() / (M * N);
   for (size_t i = 0; i < num_matrices; ++i) {
     // Compute LU factorization of A
-    MLX_LAPACK_FUNC(sgetrf)
-    (/* m */ &M,
+    getrf<T>(
+        /* m */ &M,
         /* n */ &N,
         /* a */ a_ptr,
         /* lda */ &M,
@@ -86,7 +83,17 @@ void LUF::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() == 1);
-  lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
+  switch (inputs[0].dtype()) {
+    case float32:
+      luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    case float64:
+      luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    default:
+      throw std::runtime_error(
+          "[LUF::eval_cpu] only supports float32 or float64.");
+  }
 }

 } // namespace mlx::core

View File

@@ -7,36 +7,6 @@
 namespace mlx::core {

-template <typename T>
-struct lpack;
-
-template <>
-struct lpack<float> {
-  static void xgeqrf(
-      const int* m,
-      const int* n,
-      float* a,
-      const int* lda,
-      float* tau,
-      float* work,
-      const int* lwork,
-      int* info) {
-    sgeqrf_(m, n, a, lda, tau, work, lwork, info);
-  }
-  static void xorgqr(
-      const int* m,
-      const int* n,
-      const int* k,
-      float* a,
-      const int* lda,
-      const float* tau,
-      float* work,
-      const int* lwork,
-      int* info) {
-    sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
-  }
-};
-
 template <typename T>
 void qrf_impl(const array& a, array& q, array& r) {
   const int M = a.shape(-2);
@@ -48,7 +18,7 @@ void qrf_impl(const array& a, array& q, array& r) {
       allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);

   // Copy A to inplace input and make it col-contiguous
-  array in(a.shape(), float32, nullptr, {});
+  array in(a.shape(), a.dtype(), nullptr, {});
   auto flags = in.flags();

   // Copy the input to be column contiguous
@@ -66,8 +36,7 @@ void qrf_impl(const array& a, array& q, array& r) {
   int info;

   // Compute workspace size
-  lpack<T>::xgeqrf(
-      &M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
+  geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);

   // Update workspace size
   lwork = optimal_work;
@@ -76,10 +45,10 @@ void qrf_impl(const array& a, array& q, array& r) {
   // Loop over matrices
   for (int i = 0; i < num_matrices; ++i) {
     // Solve
-    lpack<T>::xgeqrf(
+    geqrf<T>(
         &M,
         &N,
-        in.data<float>() + M * N * i,
+        in.data<T>() + M * N * i,
         &lda,
         static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
         static_cast<T*>(work.raw_ptr()),
@@ -105,7 +74,7 @@ void qrf_impl(const array& a, array& q, array& r) {
   // Get work size
   lwork = -1;
-  lpack<T>::xorgqr(
+  orgqr<T>(
       &M,
       &num_reflectors,
       &num_reflectors,
@@ -121,11 +90,11 @@ void qrf_impl(const array& a, array& q, array& r) {
   // Loop over matrices
   for (int i = 0; i < num_matrices; ++i) {
     // Compute Q
-    lpack<T>::xorgqr(
+    orgqr<T>(
         &M,
         &num_reflectors,
         &num_reflectors,
-        in.data<float>() + M * N * i,
+        in.data<T>() + M * N * i,
         &lda,
         static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
         static_cast<T*>(work.raw_ptr()),
@@ -152,10 +121,17 @@ void qrf_impl(const array& a, array& q, array& r) {
 void QRF::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (!(inputs[0].dtype() == float32)) {
-    throw std::runtime_error("[QRF::eval] only supports float32.");
-  }
-  qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
+  switch (inputs[0].dtype()) {
+    case float32:
+      qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
+      break;
+    case float64:
+      qrf_impl<double>(inputs[0], outputs[0], outputs[1]);
+      break;
+    default:
+      throw std::runtime_error(
+          "[QRF::eval_cpu] only supports float32 or float64.");
+  }
 }

 } // namespace mlx::core

View File

@@ -7,6 +7,7 @@
 namespace mlx::core {

+template <typename T>
 void svd_impl(const array& a, array& u, array& s, array& vt) {
   // Lapack uses the column-major convention. To avoid having to transpose
   // the input and then transpose the outputs, we swap the indices/sizes of the
@@ -31,7 +32,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   size_t num_matrices = a.size() / (M * N);

   // lapack clobbers the input, so we have to make a copy.
-  array in(a.shape(), float32, nullptr, {});
+  array in(a.shape(), a.dtype(), nullptr, {});
   copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

   // Allocate outputs.
@@ -45,7 +46,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   // Will contain the number of singular values after the call has returned.
   int ns = 0;

-  float workspace_dimension = 0;
+  T workspace_dimension = 0;

   // Will contain the indices of eigenvectors that failed to converge (not used
   // here but required by lapack).
@@ -54,13 +55,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   static const int lwork_query = -1;

   static const int ignored_int = 0;
-  static const float ignored_float = 0;
+  static const T ignored_float = 0;

   int info;

   // Compute workspace size.
-  MLX_LAPACK_FUNC(sgesvdx)
-  (
+  gesvdx<T>(
       /* jobu = */ job_u,
       /* jobvt = */ job_vt,
       /* range = */ range,
@@ -86,51 +86,50 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   if (info != 0) {
     std::stringstream ss;
-    ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
+    ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
     throw std::runtime_error(ss.str());
   }

   const int lwork = workspace_dimension;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};

   // Loop over matrices.
   for (int i = 0; i < num_matrices; i++) {
-    MLX_LAPACK_FUNC(sgesvdx)
-    (
+    gesvdx<T>(
         /* jobu = */ job_u,
         /* jobvt = */ job_vt,
         /* range = */ range,
         // M and N are swapped since lapack expects column-major.
         /* m = */ &N,
         /* n = */ &M,
-        /* a = */ in.data<float>() + M * N * i,
+        /* a = */ in.data<T>() + M * N * i,
         /* lda = */ &lda,
         /* vl = */ &ignored_float,
         /* vu = */ &ignored_float,
        /* il = */ &ignored_int,
        /* iu = */ &ignored_int,
        /* ns = */ &ns,
-        /* s = */ s.data<float>() + K * i,
+        /* s = */ s.data<T>() + K * i,
         // According to the identity above, lapack will write Vᵀᵀ as U.
-        /* u = */ vt.data<float>() + N * N * i,
+        /* u = */ vt.data<T>() + N * N * i,
         /* ldu = */ &ldu,
         // According to the identity above, lapack will write Uᵀ as Vᵀ.
-        /* vt = */ u.data<float>() + M * M * i,
+        /* vt = */ u.data<T>() + M * M * i,
         /* ldvt = */ &ldvt,
-        /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
+        /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
         /* lwork = */ &lwork,
         /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
         /* info = */ &info);

     if (info != 0) {
       std::stringstream ss;
-      ss << "svd_impl: sgesvdx_ failed with code " << info;
+      ss << "[SVD::eval_cpu] failed with code " << info;
       throw std::runtime_error(ss.str());
     }

     if (ns != K) {
       std::stringstream ss;
-      ss << "svd_impl: expected " << K << " singular values, but " << ns
+      ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
          << " were computed.";
       throw std::runtime_error(ss.str());
     }
@@ -140,10 +139,17 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
 void SVD::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (!(inputs[0].dtype() == float32)) {
-    throw std::runtime_error("[SVD::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    case float64:
+      svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    default:
+      throw std::runtime_error(
+          "[SVD::eval_cpu] only supports float32 or float64.");
   }
-  svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
 }

 } // namespace mlx::core

View File

@@ -18,6 +18,14 @@ void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {
         "Explicitly pass a CPU stream to run it.");
   }
 }

+void check_float(Dtype dtype, const std::string& prefix) {
+  if (dtype != float32 && dtype != float64) {
+    std::ostringstream msg;
+    msg << prefix << " Arrays must have type float32 or float64. "
+        << "Received array with type " << dtype << ".";
+    throw std::invalid_argument(msg.str());
+  }
+}

 Dtype at_least_float(const Dtype& d) {
   return issubdtype(d, inexact) ? d : promote_types(d, float32);
@@ -184,12 +192,8 @@ array norm(
 std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::qr]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::qr] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::qr]");
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
@@ -212,12 +216,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
 std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::svd]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::svd] Input array must have type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::svd]");
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
@@ -251,12 +251,8 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
 array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
   check_cpu_stream(s, "[linalg::inv]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::inv] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::inv]");
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
@@ -292,13 +288,7 @@ array cholesky(
     bool upper /* = false */,
     StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::cholesky]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::cholesky] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::cholesky]");
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
@@ -321,12 +311,8 @@ array cholesky(
 array pinv(const array& a, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::pinv]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::pinv] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::pinv]");
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
@@ -368,12 +354,7 @@ array cholesky_inv(
     bool upper /* = false */,
     StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::cholesky_inv]");
-  if (L.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::cholesky_inv] Arrays must type float32. Received array "
-        << "with type " << L.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(L.dtype(), "[linalg::cholesky_inv]");
   if (L.ndim() < 2) {
     std::ostringstream msg;
@@ -474,12 +455,7 @@ void validate_eigh(
     const StreamOrDevice& stream,
     const std::string fname) {
   check_cpu_stream(stream, fname);
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << fname << " Arrays must have type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), fname);
   if (a.ndim() < 2) {
     std::ostringstream msg;
@@ -524,12 +500,7 @@ void validate_lu(
     const StreamOrDevice& stream,
     const std::string& fname) {
   check_cpu_stream(stream, fname);
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << fname << " Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), fname);
   if (a.ndim() < 2) {
     std::ostringstream msg;
@@ -627,10 +598,12 @@ void validate_solve(
   }

   auto out_type = promote_types(a.dtype(), b.dtype());
-  if (out_type != float32) {
+  if (out_type != float32 && out_type != float64) {
     std::ostringstream msg;
-    msg << fname << " Input arrays must promote to float32. Received arrays "
-        << "with type " << a.dtype() << " and " << b.dtype() << ".";
+    msg << fname
+        << " Input arrays must promote to float32 or float64. "
+           " Received arrays with type "
+        << a.dtype() << " and " << b.dtype() << ".";
     throw std::invalid_argument(msg.str());
   }
 }

View File

@@ -44,6 +44,8 @@ std::string buffer_format(const mx::array& a) {
       return "f";
     case mx::bfloat16:
       return "B";
+    case mx::float64:
+      return "d";
     case mx::complex64:
       return "Zf\0";
     default: {

View File

@@ -152,6 +152,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
       throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
     case mx::float32:
       return mlx_to_nd_array_impl<float, NDParams...>(a);
+    case mx::float64:
+      return mlx_to_nd_array_impl<double, NDParams...>(a);
     case mx::complex64:
       return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
     default:

View File

@@ -183,6 +183,115 @@ class TestDouble(mlx_tests.MLXTestCase):
         c = a + b
         self.assertEqual(c.dtype, mx.float64)

+    def test_lapack(self):
+        with mx.stream(mx.cpu):
+            # QRF
+            A = mx.array([[2.0, 3.0], [1.0, 2.0]], dtype=mx.float64)
+            Q, R = mx.linalg.qr(A)
+            out = Q @ R
+            self.assertTrue(mx.allclose(out, A))
+            out = Q.T @ Q
+            self.assertTrue(mx.allclose(out, mx.eye(2)))
+            self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))
+            self.assertEqual(Q.dtype, mx.float64)
+            self.assertEqual(R.dtype, mx.float64)
+
+            # SVD
+            A = mx.array(
+                [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
+            )
+            U, S, Vt = mx.linalg.svd(A)
+            self.assertTrue(mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A))
+
+            # Inverse
+            A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
+            A_inv = mx.linalg.inv(A)
+            self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0])))
+
+            # Tri inv
+            A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float64)
+            B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float64)
+            AB = mx.stack([A, B])
+            invs = mx.linalg.tri_inv(AB, upper=False)
+            for M, M_inv in zip(AB, invs):
+                self.assertTrue(mx.allclose(M @ M_inv, mx.eye(M.shape[0])))
+
+            # Cholesky
+            sqrtA = mx.array(
+                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float64
+            )
+            A = sqrtA.T @ sqrtA / 81
+            L = mx.linalg.cholesky(A)
+            U = mx.linalg.cholesky(A, upper=True)
+            self.assertTrue(mx.allclose(L @ L.T, A))
+            self.assertTrue(mx.allclose(U.T @ U, A))
+
+            # Pseudo inverse
+            A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
+            A_plus = mx.linalg.pinv(A)
+            self.assertTrue(mx.allclose(A @ A_plus @ A, A))
+
+            # Eigh
+            def check_eigs_and_vecs(A_np, kwargs={}):
+                A = mx.array(A_np, dtype=mx.float64)
+                eig_vals, eig_vecs = mx.linalg.eigh(A, **kwargs)
+                eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)
+                self.assertTrue(np.allclose(eig_vals, eig_vals_np))
+                self.assertTrue(
+                    mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs)
+                )
+
+                eig_vals_only = mx.linalg.eigvalsh(A, **kwargs)
+                self.assertTrue(mx.allclose(eig_vals, eig_vals_only))
+
+            # Test a simple 2x2 symmetric matrix
+            A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float64)
+            check_eigs_and_vecs(A_np)
+
+            # Test a larger random symmetric matrix
+            n = 5
+            np.random.seed(1)
+            A_np = np.random.randn(n, n).astype(np.float64)
+            A_np = (A_np + A_np.T) / 2
+            check_eigs_and_vecs(A_np)
+
+            # Test with upper triangle
+            check_eigs_and_vecs(A_np, {"UPLO": "U"})
+
+            # LU factorization
+            # Test 3x3 matrix
+            a = mx.array(
+                [[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]], dtype=mx.float64
+            )
+            P, L, U = mx.linalg.lu(a)
+            self.assertTrue(mx.allclose(L[P, :] @ U, a))
+
+            # Solve triangular
+            # Test lower triangular matrix
+            a = mx.array(
+                [[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]], dtype=mx.float64
+            )
+            b = mx.array([8.0, 14.0, 3.0], dtype=mx.float64)
+            result = mx.linalg.solve_triangular(a, b, upper=False)
+            expected = np.linalg.solve(np.array(a), np.array(b))
+            self.assertTrue(np.allclose(result, expected))
+
+            # Test upper triangular matrix
+            a = mx.array(
+                [[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]], dtype=mx.float64
+            )
+            b = mx.array([13.0, 33.0, 18.0], dtype=mx.float64)
+            result = mx.linalg.solve_triangular(a, b, upper=True)
+            expected = np.linalg.solve(np.array(a), np.array(b))
+            self.assertTrue(np.allclose(result, expected))
+
+    def test_conversion(self):
+        a = mx.array([1.0, 2.0], mx.float64)
+        b = np.array(a)
+        self.assertTrue(np.array_equal(a, b))
+

 if __name__ == "__main__":
     unittest.main()