From 7d042f17fe61b01bf7c08c28de272c2b8c27b3f0 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 25 Feb 2025 11:39:36 -0800
Subject: [PATCH] Double for lapack (#1904)

* double for lapack ops

* add double support for lapack ops
---
 mlx/backend/cpu/cholesky.cpp |  19 ++++-
 mlx/backend/cpu/eigh.cpp     | 128 ++++++++++++++++++-----------------
 mlx/backend/cpu/inverse.cpp  |  70 ++++++++++---------
 mlx/backend/cpu/lapack.h     |  19 ++++++
 mlx/backend/cpu/luf.cpp      |  35 ++++++----
 mlx/backend/cpu/qrf.cpp      |  58 +++++-----------
 mlx/backend/cpu/svd.cpp      |  44 ++++++------
 mlx/linalg.cpp               |  77 +++++++--------------
 python/src/buffer.h          |   2 +
 python/src/convert.cpp       |   2 +
 python/tests/test_double.py  | 109 +++++++++++++++++++++++++++++
 11 files changed, 338 insertions(+), 225 deletions(-)

diff --git a/mlx/backend/cpu/cholesky.cpp b/mlx/backend/cpu/cholesky.cpp
index 33668159a..52a39c7c1 100644
--- a/mlx/backend/cpu/cholesky.cpp
+++ b/mlx/backend/cpu/cholesky.cpp
@@ -8,6 +8,7 @@
 
 namespace mlx::core {
 
+template <typename T>
 void cholesky_impl(const array& a, array& factor, bool upper) {
   // Lapack uses the column-major convention. We take advantage of the fact that
   // the matrix should be symmetric:
@@ -28,13 +29,12 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
   const int N = a.shape(-1);
   const size_t num_matrices = a.size() / (N * N);
 
-  float* matrix = factor.data<float>();
+  T* matrix = factor.data<T>();
 
   for (int i = 0; i < num_matrices; i++) {
     // Compute Cholesky factorization.
     int info;
-    MLX_LAPACK_FUNC(spotrf)
-    (
+    potrf<T>(
         /* uplo = */ &uplo,
         /* n = */ &N,
         /* a = */ matrix,
@@ -65,10 +65,17 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
 }
 
 void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
-  if (inputs[0].dtype() != float32) {
-    throw std::runtime_error("[Cholesky::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      cholesky_impl<float>(inputs[0], output, upper_);
+      break;
+    case float64:
+      cholesky_impl<double>(inputs[0], output, upper_);
+      break;
+    default:
+      throw std::runtime_error(
+          "[Cholesky::eval_cpu] only supports float32 or float64.");
   }
-  cholesky_impl(inputs[0], output, upper_);
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp
index be5e379f0..c9ec2875f 100644
--- a/mlx/backend/cpu/eigh.cpp
+++ b/mlx/backend/cpu/eigh.cpp
@@ -11,35 +11,64 @@ namespace mlx::core {
 
 namespace {
 
-void ssyevd(
-    char jobz,
-    char uplo,
-    float* a,
-    int N,
-    float* w,
-    float* work,
-    int lwork,
-    int* iwork,
-    int liwork) {
+template <typename T>
+void eigh_impl(
+    array& vectors,
+    array& values,
+    const std::string& uplo,
+    bool compute_eigenvectors) {
+  auto vec_ptr = vectors.data<T>();
+  auto eig_ptr = values.data<T>();
+
+  char jobz = compute_eigenvectors ? 'V' : 'N';
+  auto N = vectors.shape(-1);
+
+  // Work query
+  int lwork = -1;
+  int liwork = -1;
   int info;
-  MLX_LAPACK_FUNC(ssyevd)
-  (
-      /* jobz = */ &jobz,
-      /* uplo = */ &uplo,
-      /* n = */ &N,
-      /* a = */ a,
-      /* lda = */ &N,
-      /* w = */ w,
-      /* work = */ work,
-      /* lwork = */ &lwork,
-      /* iwork = */ iwork,
-      /* liwork = */ &liwork,
-      /* info = */ &info);
-  if (info != 0) {
-    std::stringstream msg;
-    msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
-        << info;
-    throw std::runtime_error(msg.str());
+  {
+    T work;
+    int iwork;
+    syevd<T>(
+        &jobz,
+        uplo.c_str(),
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        &work,
+        &lwork,
+        &iwork,
+        &liwork,
+        &info);
+    lwork = static_cast<int>(work);
+    liwork = iwork;
+  }
+
+  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
+  for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
+    syevd<T>(
+        &jobz,
+        uplo.c_str(),
+        &N,
+        vec_ptr,
+        &N,
+        eig_ptr,
+        static_cast<T*>(work_buf.buffer.raw_ptr()),
+        &lwork,
+        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
+        &liwork,
+        &info);
+    vec_ptr += N * N;
+    eig_ptr += N;
+    if (info != 0) {
+      std::stringstream msg;
+      msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
+          << info;
+      throw std::runtime_error(msg.str());
+    }
   }
 }
 
@@ -80,39 +109,16 @@ void Eigh::eval_cpu(
     }
     vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
   }
-
-  auto vec_ptr = vectors.data<float>();
-  auto eig_ptr = values.data<float>();
-
-  char jobz = compute_eigenvectors_ ? 'V' : 'N';
-  auto N = a.shape(-1);
-
-  // Work query
-  int lwork;
-  int liwork;
-  {
-    float work;
-    int iwork;
-    ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
-    lwork = static_cast<int>(work);
-    liwork = iwork;
-  }
-
-  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
-  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
-  for (size_t i = 0; i < a.size() / (N * N); ++i) {
-    ssyevd(
-        jobz,
-        uplo_[0],
-        vec_ptr,
-        N,
-        eig_ptr,
-        static_cast<float*>(work_buf.buffer.raw_ptr()),
-        lwork,
-        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
-        liwork);
-    vec_ptr += N * N;
-    eig_ptr += N;
+  switch (a.dtype()) {
+    case float32:
+      eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
+      break;
+    case float64:
+      eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
+      break;
+    default:
+      throw std::runtime_error(
+          "[Eigh::eval_cpu] only supports float32 or float64.");
   }
 }
 
diff --git a/mlx/backend/cpu/inverse.cpp b/mlx/backend/cpu/inverse.cpp
index 81cabb79d..aba038218 100644
--- a/mlx/backend/cpu/inverse.cpp
+++ b/mlx/backend/cpu/inverse.cpp
@@ -5,44 +5,33 @@
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/primitives.h"
 
-int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
-  int info;
-  MLX_LAPACK_FUNC(strtri)
-  (
-      /* uplo = */ &uplo,
-      /* diag = */ &diag,
-      /* N = */ &N,
-      /* a = */ matrix,
-      /* lda = */ &N,
-      /* info = */ &info);
-  return info;
-}
-
 namespace mlx::core {
 
+template <typename T>
 void general_inv(array& inv, int N, int i) {
   int info;
   auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
   // Compute LU factorization.
-  sgetrf_(
+  getrf<T>(
       /* m = */ &N,
       /* n = */ &N,
-      /* a = */ inv.data<float>() + N * N * i,
+      /* a = */ inv.data<T>() + N * N * i,
       /* lda = */ &N,
       /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
       /* info = */ &info);
   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: LU factorization failed with error code " << info;
+    ss << "[Inverse::eval_cpu] LU factorization failed with error code "
+       << info;
     throw std::runtime_error(ss.str());
   }
 
   static const int lwork_query = -1;
-  float workspace_size = 0;
+  T workspace_size = 0;
 
   // Compute workspace size.
-  sgetri_(
+  getri<T>(
       /* m = */ &N,
       /* a = */ nullptr,
       /* lda = */ &N,
@@ -53,36 +42,44 @@ void general_inv(array& inv, int N, int i) {
 
   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: LU workspace calculation failed with error code "
+    ss << "[Inverse::eval_cpu] LU workspace calculation failed with error code "
        << info;
     throw std::runtime_error(ss.str());
   }
 
   const int lwork = workspace_size;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
 
   // Compute inverse.
-  sgetri_(
+  getri<T>(
       /* m = */ &N,
-      /* a = */ inv.data<float>() + N * N * i,
+      /* a = */ inv.data<T>() + N * N * i,
       /* lda = */ &N,
       /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
-      /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
+      /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
       /* lwork = */ &lwork,
       /* info = */ &info);
 
   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: inversion failed with error code " << info;
+    ss << "[Inverse::eval_cpu] inversion failed with error code " << info;
     throw std::runtime_error(ss.str());
   }
 }
 
+template <typename T>
 void tri_inv(array& inv, int N, int i, bool upper) {
   const char uplo = upper ? 'L' : 'U';
   const char diag = 'N';
-  float* data = inv.data<float>() + N * N * i;
-  int info = strtri_wrapper(uplo, diag, data, N);
+  T* data = inv.data<T>() + N * N * i;
+  int info;
+  trtri<T>(
+      /* uplo = */ &uplo,
+      /* diag = */ &diag,
+      /* N = */ &N,
+      /* a = */ data,
+      /* lda = */ &N,
+      /* info = */ &info);
 
   // zero out the other triangle
   if (upper) {
@@ -99,11 +96,13 @@ void tri_inv(array& inv, int N, int i, bool upper) {
 
   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: triangular inversion failed with error code " << info;
+    ss << "[Inverse::eval_cpu] triangular inversion failed with error code "
+       << info;
     throw std::runtime_error(ss.str());
   }
 }
 
+template <typename T>
 void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
   // Lapack uses the column-major convention. We take advantage of the following
   // identity to avoid transposing (see
@@ -118,18 +117,25 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
 
   for (int i = 0; i < num_matrices; i++) {
     if (tri) {
-      tri_inv(inv, N, i, upper);
+      tri_inv<T>(inv, N, i, upper);
     } else {
-      general_inv(inv, N, i);
+      general_inv<T>(inv, N, i);
     }
   }
 }
 
 void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
-  if (inputs[0].dtype() != float32) {
-    throw std::runtime_error("[Inverse::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      inverse_impl<float>(inputs[0], output, tri_, upper_);
+      break;
+    case float64:
+      inverse_impl<double>(inputs[0], output, tri_, upper_);
+      break;
+    default:
+      throw std::runtime_error(
+          "[Inverse::eval_cpu] only supports float32 or float64.");
   }
-  inverse_impl(inputs[0], output, tri_, upper_);
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h
index dc262a0ff..2911c63f8 100644
--- a/mlx/backend/cpu/lapack.h
+++ b/mlx/backend/cpu/lapack.h
@@ -31,3 +31,22 @@
 #define MLX_LAPACK_FUNC(f) f##_
 
 #endif
+
+#define INSTANTIATE_LAPACK_TYPES(FUNC)                          \
+  template <typename T, typename... Args>                       \
+  void FUNC(Args... args) {                                     \
+    if constexpr (std::is_same_v<T, float>) {                   \
+      MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...);    \
+    } else if constexpr (std::is_same_v<T, double>) {           \
+      MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...);    \
+    }                                                           \
+  }
+
+INSTANTIATE_LAPACK_TYPES(geqrf)
+INSTANTIATE_LAPACK_TYPES(orgqr)
+INSTANTIATE_LAPACK_TYPES(syevd)
+INSTANTIATE_LAPACK_TYPES(potrf)
+INSTANTIATE_LAPACK_TYPES(gesvdx)
+INSTANTIATE_LAPACK_TYPES(getrf)
+INSTANTIATE_LAPACK_TYPES(getri)
+INSTANTIATE_LAPACK_TYPES(trtri)
diff --git a/mlx/backend/cpu/luf.cpp b/mlx/backend/cpu/luf.cpp
index e055f4cac..87de97c72 100644
--- a/mlx/backend/cpu/luf.cpp
+++ b/mlx/backend/cpu/luf.cpp
@@ -9,11 +9,8 @@
 
 namespace mlx::core {
 
-void lu_factor_impl(
-    const array& a,
-    array& lu,
-    array& pivots,
-    array& row_indices) {
+template <typename T>
+void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
   int M = a.shape(-2);
   int N = a.shape(-1);
 
@@ -31,7 +28,7 @@ void lu_factor_impl(
   copy_inplace(
       a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
 
-  auto a_ptr = lu.data<float>();
+  auto a_ptr = lu.data<T>();
 
   pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
   row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
@@ -42,13 +39,13 @@ void lu_factor_impl(
   size_t num_matrices = a.size() / (M * N);
   for (size_t i = 0; i < num_matrices; ++i) {
     // Compute LU factorization of A
-    MLX_LAPACK_FUNC(sgetrf)
-    (/* m */ &M,
-     /* n */ &N,
-     /* a */ a_ptr,
-     /* lda */ &M,
-     /* ipiv */ reinterpret_cast<int*>(pivots_ptr),
-     /* info */ &info);
+    getrf<T>(
+        /* m */ &M,
+        /* n */ &N,
+        /* a */ a_ptr,
+        /* lda */ &M,
+        /* ipiv */ reinterpret_cast<int*>(pivots_ptr),
+        /* info */ &info);
 
     if (info != 0) {
       std::stringstream ss;
@@ -86,7 +83,17 @@ void LUF::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() == 1);
-  lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
+  switch (inputs[0].dtype()) {
+    case float32:
+      luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    case float64:
+      luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    default:
+      throw std::runtime_error(
+          "[LUF::eval_cpu] only supports float32 or float64.");
+  }
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cpu/qrf.cpp b/mlx/backend/cpu/qrf.cpp
index d7caa8b68..537b63358 100644
--- a/mlx/backend/cpu/qrf.cpp
+++ b/mlx/backend/cpu/qrf.cpp
@@ -7,36 +7,6 @@
 
 namespace mlx::core {
 
-template <typename T>
-struct lpack;
-
-template <>
-struct lpack<float> {
-  static void xgeqrf(
-      const int* m,
-      const int* n,
-      float* a,
-      const int* lda,
-      float* tau,
-      float* work,
-      const int* lwork,
-      int* info) {
-    sgeqrf_(m, n, a, lda, tau, work, lwork, info);
-  }
-  static void xorgqr(
-      const int* m,
-      const int* n,
-      const int* k,
-      float* a,
-      const int* lda,
-      const float* tau,
-      float* work,
-      const int* lwork,
-      int* info) {
-    sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
-  }
-};
-
 template <typename T>
 void qrf_impl(const array& a, array& q, array& r) {
   const int M = a.shape(-2);
@@ -48,7 +18,7 @@ void qrf_impl(const array& a, array& q, array& r) {
       allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
 
   // Copy A to inplace input and make it col-contiguous
-  array in(a.shape(), float32, nullptr, {});
+  array in(a.shape(), a.dtype(), nullptr, {});
   auto flags = in.flags();
 
   // Copy the input to be column contiguous
@@ -66,8 +36,7 @@ void qrf_impl(const array& a, array& q, array& r) {
   int info;
 
   // Compute workspace size
-  lpack<T>::xgeqrf(
-      &M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
+  geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
 
   // Update workspace size
   lwork = optimal_work;
@@ -76,10 +45,10 @@ void qrf_impl(const array& a, array& q, array& r) {
   // Loop over matrices
   for (int i = 0; i < num_matrices; ++i) {
     // Solve
-    lpack<T>::xgeqrf(
+    geqrf<T>(
         &M,
         &N,
-        in.data<float>() + M * N * i,
+        in.data<T>() + M * N * i,
         &lda,
         static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
         static_cast<T*>(work.raw_ptr()),
@@ -105,7 +74,7 @@ void qrf_impl(const array& a, array& q, array& r) {
 
   // Get work size
   lwork = -1;
-  lpack<T>::xorgqr(
+  orgqr<T>(
       &M,
       &num_reflectors,
       &num_reflectors,
@@ -121,11 +90,11 @@ void qrf_impl(const array& a, array& q, array& r) {
   // Loop over matrices
   for (int i = 0; i < num_matrices; ++i) {
     // Compute Q
-    lpack<T>::xorgqr(
+    orgqr<T>(
        &M,
        &num_reflectors,
        &num_reflectors,
-        in.data<float>() + M * N * i,
+        in.data<T>() + M * N * i,
        &lda,
        static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
        static_cast<T*>(work.raw_ptr()),
@@ -152,10 +121,17 @@ void qrf_impl(const array& a, array& q, array& r) {
 void QRF::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (!(inputs[0].dtype() == float32)) {
-    throw std::runtime_error("[QRF::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
+      break;
+    case float64:
+      qrf_impl<double>(inputs[0], outputs[0], outputs[1]);
+      break;
+    default:
+      throw std::runtime_error(
+          "[QRF::eval_cpu] only supports float32 or float64.");
   }
-  qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp
index f18ab4f91..33a30d843 100644
--- a/mlx/backend/cpu/svd.cpp
+++ b/mlx/backend/cpu/svd.cpp
@@ -7,6 +7,7 @@
 
 namespace mlx::core {
 
+template <typename T>
 void svd_impl(const array& a, array& u, array& s, array& vt) {
   // Lapack uses the column-major convention. To avoid having to transpose
   // the input and then transpose the outputs, we swap the indices/sizes of the
@@ -31,7 +32,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   size_t num_matrices = a.size() / (M * N);
 
   // lapack clobbers the input, so we have to make a copy.
-  array in(a.shape(), float32, nullptr, {});
+  array in(a.shape(), a.dtype(), nullptr, {});
   copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
 
   // Allocate outputs.
@@ -45,7 +46,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   // Will contain the number of singular values after the call has returned.
   int ns = 0;
 
-  float workspace_dimension = 0;
+  T workspace_dimension = 0;
 
   // Will contain the indices of eigenvectors that failed to converge (not used
   // here but required by lapack).
@@ -54,13 +55,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
 
   static const int lwork_query = -1;
   static const int ignored_int = 0;
-  static const float ignored_float = 0;
+  static const T ignored_float = 0;
 
   int info;
 
   // Compute workspace size.
-  MLX_LAPACK_FUNC(sgesvdx)
-  (
+  gesvdx<T>(
       /* jobu = */ job_u,
       /* jobvt = */ job_vt,
       /* range = */ range,
@@ -86,51 +86,50 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
 
   if (info != 0) {
     std::stringstream ss;
-    ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
+    ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
     throw std::runtime_error(ss.str());
   }
 
   const int lwork = workspace_dimension;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
 
   // Loop over matrices.
   for (int i = 0; i < num_matrices; i++) {
-    MLX_LAPACK_FUNC(sgesvdx)
-    (
+    gesvdx<T>(
        /* jobu = */ job_u,
        /* jobvt = */ job_vt,
        /* range = */ range,
        // M and N are swapped since lapack expects column-major.
        /* m = */ &N,
        /* n = */ &M,
-        /* a = */ in.data<float>() + M * N * i,
+        /* a = */ in.data<T>() + M * N * i,
        /* lda = */ &lda,
        /* vl = */ &ignored_float,
        /* vu = */ &ignored_float,
        /* il = */ &ignored_int,
        /* iu = */ &ignored_int,
        /* ns = */ &ns,
-        /* s = */ s.data<float>() + K * i,
+        /* s = */ s.data<T>() + K * i,
        // According to the identity above, lapack will write Vᵀᵀ as U.
-        /* u = */ vt.data<float>() + N * N * i,
+        /* u = */ vt.data<T>() + N * N * i,
        /* ldu = */ &ldu,
        // According to the identity above, lapack will write Uᵀ as Vᵀ.
-        /* vt = */ u.data<float>() + M * M * i,
+        /* vt = */ u.data<T>() + M * M * i,
        /* ldvt = */ &ldvt,
-        /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
+        /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
        /* lwork = */ &lwork,
        /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
        /* info = */ &info);
 
     if (info != 0) {
       std::stringstream ss;
-      ss << "svd_impl: sgesvdx_ failed with code " << info;
+      ss << "[SVD::eval_cpu] failed with code " << info;
       throw std::runtime_error(ss.str());
     }
 
     if (ns != K) {
       std::stringstream ss;
-      ss << "svd_impl: expected " << K << " singular values, but " << ns
+      ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
          << " were computed.";
       throw std::runtime_error(ss.str());
     }
@@ -140,10 +139,17 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
 void SVD::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (!(inputs[0].dtype() == float32)) {
-    throw std::runtime_error("[SVD::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    case float64:
+      svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    default:
+      throw std::runtime_error(
+          "[SVD::eval_cpu] only supports float32 or float64.");
   }
-  svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
 }
 
 } // namespace mlx::core
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 01aa9b7ff..356d39626 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -18,6 +18,14 @@ void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {
         "Explicitly pass a CPU stream to run it.");
   }
 }
+void check_float(Dtype dtype, const std::string& prefix) {
+  if (dtype != float32 && dtype != float64) {
+    std::ostringstream msg;
+    msg << prefix << " Arrays must have type float32 or float64. "
+        << "Received array with type " << dtype << ".";
+    throw std::invalid_argument(msg.str());
+  }
+}
 
 Dtype at_least_float(const Dtype& d) {
   return issubdtype(d, inexact) ? d : promote_types(d, float32);
@@ -184,12 +192,8 @@ array norm(
 
 std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::qr]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::qr] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::qr]");
+
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
@@ -212,12 +216,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
 
 std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::svd]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::svd] Input array must have type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::svd]");
+
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
@@ -251,12 +251,8 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
 
 array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
   check_cpu_stream(s, "[linalg::inv]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::inv] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::inv]");
+
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
@@ -292,13 +288,7 @@ array cholesky(
     bool upper /* = false */,
     StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::cholesky]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::cholesky] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
+  check_float(a.dtype(), "[linalg::cholesky]");
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
@@ -321,12 +311,8 @@ array cholesky(
 
 array pinv(const array& a, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::pinv]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::pinv] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::pinv]");
+
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
@@ -368,12 +354,7 @@ array cholesky_inv(
     bool upper /* = false */,
     StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::cholesky_inv]");
-  if (L.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::cholesky_inv] Arrays must type float32. Received array "
-        << "with type " << L.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(L.dtype(), "[linalg::cholesky_inv]");
 
   if (L.ndim() < 2) {
     std::ostringstream msg;
@@ -474,12 +455,7 @@ void validate_eigh(
     const StreamOrDevice& stream,
     const std::string fname) {
   check_cpu_stream(stream, fname);
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << fname << " Arrays must have type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), fname);
 
   if (a.ndim() < 2) {
     std::ostringstream msg;
@@ -524,12 +500,7 @@ void validate_lu(
     const StreamOrDevice& stream,
     const std::string& fname) {
   check_cpu_stream(stream, fname);
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << fname << " Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), fname);
 
   if (a.ndim() < 2) {
     std::ostringstream msg;
@@ -627,10 +598,12 @@ void validate_solve(
   }
 
   auto out_type = promote_types(a.dtype(), b.dtype());
-  if (out_type != float32) {
+  if (out_type != float32 && out_type != float64) {
     std::ostringstream msg;
-    msg << fname << " Input arrays must promote to float32. Received arrays "
-        << "with type " << a.dtype() << " and " << b.dtype() << ".";
+    msg << fname
+        << " Input arrays must promote to float32 or float64. "
+           " Received arrays with type "
+        << a.dtype() << " and " << b.dtype() << ".";
     throw std::invalid_argument(msg.str());
   }
 }
diff --git a/python/src/buffer.h b/python/src/buffer.h
index cca832686..272a91888 100644
--- a/python/src/buffer.h
+++ b/python/src/buffer.h
@@ -44,6 +44,8 @@ std::string buffer_format(const mx::array& a) {
       return "f";
     case mx::bfloat16:
       return "B";
+    case mx::float64:
+      return "d";
     case mx::complex64:
       return "Zf\0";
     default: {
diff --git a/python/src/convert.cpp b/python/src/convert.cpp
index b88a5832a..00f8395fc 100644
--- a/python/src/convert.cpp
+++ b/python/src/convert.cpp
@@ -152,6 +152,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
       throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
     case mx::float32:
       return mlx_to_nd_array_impl<float, NDParams...>(a);
+    case mx::float64:
+      return mlx_to_nd_array_impl<double, NDParams...>(a);
     case mx::complex64:
       return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
     default:
diff --git a/python/tests/test_double.py b/python/tests/test_double.py
index 8de3f3cea..10fce0db1 100644
--- a/python/tests/test_double.py
+++ b/python/tests/test_double.py
@@ -183,6 +183,115 @@ class TestDouble(mlx_tests.MLXTestCase):
         c = a + b
         self.assertEqual(c.dtype, mx.float64)
 
+    def test_lapack(self):
+        with mx.stream(mx.cpu):
+            # QRF
+            A = mx.array([[2.0, 3.0], [1.0, 2.0]], dtype=mx.float64)
+            Q, R = mx.linalg.qr(A)
+            out = Q @ R
+            self.assertTrue(mx.allclose(out, A))
+            out = Q.T @ Q
+            self.assertTrue(mx.allclose(out, mx.eye(2)))
+            self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))
+            self.assertEqual(Q.dtype, mx.float64)
+            self.assertEqual(R.dtype, mx.float64)
+
+            # SVD
+            A = mx.array(
+                [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
+            )
+            U, S, Vt = mx.linalg.svd(A)
+            self.assertTrue(mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A))
+
+            # Inverse
+            A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
+            A_inv = mx.linalg.inv(A)
+            self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0])))
+
+            # Tri inv
+            A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float64)
+            B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float64)
+            AB = mx.stack([A, B])
+            invs = mx.linalg.tri_inv(AB, upper=False)
+            for M, M_inv in zip(AB, invs):
+                self.assertTrue(mx.allclose(M @ M_inv, mx.eye(M.shape[0])))
+
+            # Cholesky
+            sqrtA = mx.array(
+                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float64
+            )
+            A = sqrtA.T @ sqrtA / 81
+            L = mx.linalg.cholesky(A)
+            U = mx.linalg.cholesky(A, upper=True)
+            self.assertTrue(mx.allclose(L @ L.T, A))
+            self.assertTrue(mx.allclose(U.T @ U, A))
+
+            # Pseudo inverse
+            A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
+            A_plus = mx.linalg.pinv(A)
+            self.assertTrue(mx.allclose(A @ A_plus @ A, A))
+
+            # Eigh
+            def check_eigs_and_vecs(A_np, kwargs={}):
+                A = mx.array(A_np, dtype=mx.float64)
+                eig_vals, eig_vecs = mx.linalg.eigh(A, **kwargs)
+                eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)
+                self.assertTrue(np.allclose(eig_vals, eig_vals_np))
+                self.assertTrue(
+                    mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs)
+                )
+
+                eig_vals_only = mx.linalg.eigvalsh(A, **kwargs)
+                self.assertTrue(mx.allclose(eig_vals, eig_vals_only))
+
+            # Test a simple 2x2 symmetric matrix
+            A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float64)
+            check_eigs_and_vecs(A_np)
+
+            # Test a larger random symmetric matrix
+            n = 5
+            np.random.seed(1)
+            A_np = np.random.randn(n, n).astype(np.float64)
+            A_np = (A_np + A_np.T) / 2
+            check_eigs_and_vecs(A_np)
+
+            # Test with upper triangle
+            check_eigs_and_vecs(A_np, {"UPLO": "U"})
+
+            # LU factorization
+            # Test 3x3 matrix
+            a = mx.array(
+                [[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]], dtype=mx.float64
+            )
+            P, L, U = mx.linalg.lu(a)
+            self.assertTrue(mx.allclose(L[P, :] @ U, a))
+
+            # Solve triangular
+            # Test lower triangular matrix
+            a = mx.array(
+                [[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]], dtype=mx.float64
+            )
+            b = mx.array([8.0, 14.0, 3.0], dtype=mx.float64)
+
+            result = mx.linalg.solve_triangular(a, b, upper=False)
+            expected = np.linalg.solve(np.array(a), np.array(b))
+            self.assertTrue(np.allclose(result, expected))
+
+            # Test upper triangular matrix
+            a = mx.array(
+                [[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]], dtype=mx.float64
+            )
+            b = mx.array([13.0, 33.0, 18.0], dtype=mx.float64)
+
+            result = mx.linalg.solve_triangular(a, b, upper=True)
+            expected = np.linalg.solve(np.array(a), np.array(b))
+            self.assertTrue(np.allclose(result, expected))
+
+    def test_conversion(self):
+        a = mx.array([1.0, 2.0], mx.float64)
+        b = np.array(a)
+        self.assertTrue(np.array_equal(a, b))
+
 
 if __name__ == "__main__":
     unittest.main()
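
For reference, a minimal usage sketch (not part of the patch, assuming an MLX build that includes this change) of the float64 CPU path exercised by test_lapack above:

import mlx.core as mx

# float64 linalg runs only on the CPU backend, so pin the stream explicitly.
with mx.stream(mx.cpu):
    A = mx.array([[4.0, 2.0], [2.0, 3.0]], dtype=mx.float64)

    # These now dispatch to the d-prefixed LAPACK routines (dpotrf, dgetrf/dgetri)
    # through the INSTANTIATE_LAPACK_TYPES wrappers added in lapack.h.
    L = mx.linalg.cholesky(A)
    A_inv = mx.linalg.inv(A)

    print(L.dtype)                            # mlx.core.float64
    print(mx.allclose(L @ L.T, A))            # True
    print(mx.allclose(A @ A_inv, mx.eye(2)))  # True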