Mirror of https://github.com/ml-explore/mlx.git
	Double for lapack (#1904)
By Awni Hannun

* double for lapack ops
* add double support for lapack ops
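Every file in this diff follows the same two-step pattern: the float-only kernel becomes a function template over the scalar type, and each primitive's eval_cpu dispatches on the input dtype. A minimal sketch of that shape, with hypothetical kernel_impl and Primitive stand-ins (illustrative only, not the MLX source):

template <typename T>
void kernel_impl(const array& in, array& out);  // templated LAPACK-backed kernel

void Primitive::eval_cpu(const std::vector<array>& inputs, array& output) {
  // Dispatch on dtype; every other dtype is rejected up front.
  switch (inputs[0].dtype()) {
    case float32:
      kernel_impl<float>(inputs[0], output);
      break;
    case float64:
      kernel_impl<double>(inputs[0], output);
      break;
    default:
      throw std::runtime_error("only float32 or float64 supported");
  }
}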
@@ -8,6 +8,7 @@

 namespace mlx::core {

+template <typename T>
 void cholesky_impl(const array& a, array& factor, bool upper) {
   // Lapack uses the column-major convention. We take advantage of the fact that
   // the matrix should be symmetric:
@@ -28,13 +29,12 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
   const int N = a.shape(-1);
   const size_t num_matrices = a.size() / (N * N);

-  float* matrix = factor.data<float>();
+  T* matrix = factor.data<T>();

   for (int i = 0; i < num_matrices; i++) {
     // Compute Cholesky factorization.
     int info;
-    MLX_LAPACK_FUNC(spotrf)
-    (
+    potrf<T>(
         /* uplo = */ &uplo,
         /* n = */ &N,
         /* a = */ matrix,
@@ -65,10 +65,17 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
 }

 void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
-  if (inputs[0].dtype() != float32) {
-    throw std::runtime_error("[Cholesky::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      cholesky_impl<float>(inputs[0], output, upper_);
+      break;
+    case float64:
+      cholesky_impl<double>(inputs[0], output, upper_);
+      break;
+    default:
+      throw std::runtime_error(
+          "[Cholesky::eval_cpu] only supports float32 or float64.");
   }
-  cholesky_impl(inputs[0], output, upper_);
 }

 } // namespace mlx::core

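A note on the column-major comment above: LAPACK interprets the row-major buffer as its transpose, so potrf actually factors Aᵀ. Because the input here must be symmetric, Aᵀ = A and no transpose copy is needed; the only adjustment is flipping the caller's triangle flag, since the row-major upper triangle is the column-major lower one.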
@@ -11,35 +11,64 @@ namespace mlx::core {

 namespace {

-void ssyevd(
-    char jobz,
-    char uplo,
-    float* a,
-    int N,
-    float* w,
-    float* work,
-    int lwork,
-    int* iwork,
-    int liwork) {
+template <typename T>
+void eigh_impl(
+    array& vectors,
+    array& values,
+    const std::string& uplo,
+    bool compute_eigenvectors) {
+  auto vec_ptr = vectors.data<T>();
+  auto eig_ptr = values.data<T>();
+
+  char jobz = compute_eigenvectors ? 'V' : 'N';
+  auto N = vectors.shape(-1);
+
+  // Work query
+  int lwork = -1;
+  int liwork = -1;
   int info;
-  MLX_LAPACK_FUNC(ssyevd)
-  (
-      /* jobz = */ &jobz,
-      /* uplo = */ &uplo,
-      /* n = */ &N,
-      /* a = */ a,
-      /* lda = */ &N,
-      /* w = */ w,
-      /* work = */ work,
-      /* lwork = */ &lwork,
-      /* iwork = */ iwork,
-      /* liwork = */ &liwork,
-      /* info = */ &info);
-  if (info != 0) {
-    std::stringstream msg;
-    msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
-        << info;
-    throw std::runtime_error(msg.str());
+  {
+    T work;
+    int iwork;
+    syevd<T>(
+        &jobz,
+        uplo.c_str(),
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        &work,
+        &lwork,
+        &iwork,
+        &liwork,
+        &info);
+    lwork = static_cast<int>(work);
+    liwork = iwork;
   }
+
+  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
+  for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
+    syevd<T>(
+        &jobz,
+        uplo.c_str(),
+        &N,
+        vec_ptr,
+        &N,
+        eig_ptr,
+        static_cast<T*>(work_buf.buffer.raw_ptr()),
+        &lwork,
+        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
+        &liwork,
+        &info);
+    vec_ptr += N * N;
+    eig_ptr += N;
+    if (info != 0) {
+      std::stringstream msg;
+      msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
+          << info;
+      throw std::runtime_error(msg.str());
+    }
+  }
 }
@@ -80,39 +109,16 @@ void Eigh::eval_cpu(
     }
     vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
   }

-  auto vec_ptr = vectors.data<float>();
-  auto eig_ptr = values.data<float>();
-
-  char jobz = compute_eigenvectors_ ? 'V' : 'N';
-  auto N = a.shape(-1);
-
-  // Work query
-  int lwork;
-  int liwork;
-  {
-    float work;
-    int iwork;
-    ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
-    lwork = static_cast<int>(work);
-    liwork = iwork;
-  }
-
-  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
-  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
-  for (size_t i = 0; i < a.size() / (N * N); ++i) {
-    ssyevd(
-        jobz,
-        uplo_[0],
-        vec_ptr,
-        N,
-        eig_ptr,
-        static_cast<float*>(work_buf.buffer.raw_ptr()),
-        lwork,
-        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
-        liwork);
-    vec_ptr += N * N;
-    eig_ptr += N;
+  switch (a.dtype()) {
+    case float32:
+      eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
+      break;
+    case float64:
+      eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
+      break;
+    default:
+      throw std::runtime_error(
+          "[Eigh::eval_cpu] only supports float32 or float64.");
   }
 }

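Both syevd calls above use LAPACK's two-phase workspace-query convention: a first call with lwork = -1 computes nothing and only reports optimal buffer sizes, and a second call does the real work. Here is a standalone sketch of the idiom written directly against the raw Fortran symbol dsyevd_ (a sketch assuming the trailing-underscore naming that MLX_LAPACK_FUNC encodes; error handling elided):

#include <vector>

extern "C" void dsyevd_(
    const char* jobz, const char* uplo, const int* n, double* a,
    const int* lda, double* w, double* work, const int* lwork,
    int* iwork, const int* liwork, int* info);

void eigh_f64(double* a, double* w, int n) {
  char jobz = 'V', uplo = 'L';
  int info = 0, lwork = -1, liwork = -1;
  double work_query = 0;
  int iwork_query = 0;
  // Pass 1: lwork == -1 asks dsyevd to report optimal sizes only.
  dsyevd_(&jobz, &uplo, &n, a, &n, w, &work_query, &lwork,
          &iwork_query, &liwork, &info);
  lwork = static_cast<int>(work_query);
  liwork = iwork_query;
  std::vector<double> work(lwork);
  std::vector<int> iwork(liwork);
  // Pass 2: the actual decomposition with right-sized scratch buffers.
  dsyevd_(&jobz, &uplo, &n, a, &n, w, work.data(), &lwork,
          iwork.data(), &liwork, &info);
}

Doing the query once and reusing the scratch buffers across the batch loop, as eigh_impl does, avoids a per-matrix allocation.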
@@ -5,44 +5,33 @@
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/primitives.h"

-int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
-  int info;
-  MLX_LAPACK_FUNC(strtri)
-  (
-      /* uplo = */ &uplo,
-      /* diag = */ &diag,
-      /* N = */ &N,
-      /* a = */ matrix,
-      /* lda = */ &N,
-      /* info = */ &info);
-  return info;
-}
-
 namespace mlx::core {

+template <typename T>
 void general_inv(array& inv, int N, int i) {
   int info;
   auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
   // Compute LU factorization.
-  sgetrf_(
+  getrf<T>(
       /* m = */ &N,
       /* n = */ &N,
-      /* a = */ inv.data<float>() + N * N * i,
+      /* a = */ inv.data<T>() + N * N * i,
       /* lda = */ &N,
       /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
       /* info = */ &info);

   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: LU factorization failed with error code " << info;
+    ss << "[Inverse::eval_cpu] LU factorization failed with error code "
+       << info;
     throw std::runtime_error(ss.str());
   }

   static const int lwork_query = -1;
-  float workspace_size = 0;
+  T workspace_size = 0;

   // Compute workspace size.
-  sgetri_(
+  getri<T>(
       /* m = */ &N,
       /* a = */ nullptr,
       /* lda = */ &N,
@@ -53,36 +42,44 @@ void general_inv(array& inv, int N, int i) {

   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: LU workspace calculation failed with error code "
+    ss << "[Inverse::eval_cpu] LU workspace calculation failed with error code "
        << info;
     throw std::runtime_error(ss.str());
   }

   const int lwork = workspace_size;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};

   // Compute inverse.
-  sgetri_(
+  getri<T>(
       /* m = */ &N,
-      /* a = */ inv.data<float>() + N * N * i,
+      /* a = */ inv.data<T>() + N * N * i,
       /* lda = */ &N,
       /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
-      /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
+      /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
       /* lwork = */ &lwork,
       /* info = */ &info);

   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: inversion failed with error code " << info;
+    ss << "[Inverse::eval_cpu] inversion failed with error code " << info;
     throw std::runtime_error(ss.str());
   }
 }

+template <typename T>
 void tri_inv(array& inv, int N, int i, bool upper) {
   const char uplo = upper ? 'L' : 'U';
   const char diag = 'N';
-  float* data = inv.data<float>() + N * N * i;
-  int info = strtri_wrapper(uplo, diag, data, N);
+  T* data = inv.data<T>() + N * N * i;
+  int info;
+  trtri<T>(
+      /* uplo = */ &uplo,
+      /* diag = */ &diag,
+      /* N = */ &N,
+      /* a = */ data,
+      /* lda = */ &N,
+      /* info = */ &info);

   // zero out the other triangle
   if (upper) {
@@ -99,11 +96,13 @@ void tri_inv(array& inv, int N, int i, bool upper) {

   if (info != 0) {
     std::stringstream ss;
-    ss << "inverse_impl: triangular inversion failed with error code " << info;
+    ss << "[Inverse::eval_cpu] triangular inversion failed with error code "
+       << info;
     throw std::runtime_error(ss.str());
   }
 }

+template <typename T>
 void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
   // Lapack uses the column-major convention. We take advantage of the following
   // identity to avoid transposing (see
@@ -118,18 +117,25 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {

   for (int i = 0; i < num_matrices; i++) {
     if (tri) {
-      tri_inv(inv, N, i, upper);
+      tri_inv<T>(inv, N, i, upper);
     } else {
-      general_inv(inv, N, i);
+      general_inv<T>(inv, N, i);
     }
   }
 }

 void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
-  if (inputs[0].dtype() != float32) {
-    throw std::runtime_error("[Inverse::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      inverse_impl<float>(inputs[0], output, tri_, upper_);
+      break;
+    case float64:
+      inverse_impl<double>(inputs[0], output, tri_, upper_);
+      break;
+    default:
+      throw std::runtime_error(
+          "[Inverse::eval_cpu] only supports float32 or float64.");
   }
-  inverse_impl(inputs[0], output, tri_, upper_);
 }

 } // namespace mlx::core

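The identity the inverse_impl comment refers to: (Aᵀ)⁻¹ = (A⁻¹)ᵀ. LAPACK reads the row-major buffer as Aᵀ, so the getrf/getri pair computes (Aᵀ)⁻¹ in column-major layout, and that is exactly A⁻¹ in row-major layout; neither the input nor the output needs a transpose copy. The same transposed reading explains the flipped flag in tri_inv: the caller's upper triangle is LAPACK's lower triangle, hence uplo = upper ? 'L' : 'U'.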
@@ -31,3 +31,22 @@
 #define MLX_LAPACK_FUNC(f) f##_

 #endif
+
+#define INSTANTIATE_LAPACK_TYPES(FUNC)                       \
+  template <typename T, typename... Args>                    \
+  void FUNC(Args... args) {                                  \
+    if constexpr (std::is_same_v<T, float>) {                \
+      MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
+    } else if constexpr (std::is_same_v<T, double>) {        \
+      MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
+    }                                                        \
+  }
+
+INSTANTIATE_LAPACK_TYPES(geqrf)
+INSTANTIATE_LAPACK_TYPES(orgqr)
+INSTANTIATE_LAPACK_TYPES(syevd)
+INSTANTIATE_LAPACK_TYPES(potrf)
+INSTANTIATE_LAPACK_TYPES(gesvdx)
+INSTANTIATE_LAPACK_TYPES(getrf)
+INSTANTIATE_LAPACK_TYPES(getri)
+INSTANTIATE_LAPACK_TYPES(trtri)

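For a concrete view of what the new macro generates, here is roughly the expansion of INSTANTIATE_LAPACK_TYPES(potrf), assuming the trailing-underscore convention so that s##potrf becomes spotrf_:

template <typename T, typename... Args>
void potrf(Args... args) {
  // Compile-time dispatch: exactly one branch is instantiated per T.
  if constexpr (std::is_same_v<T, float>) {
    spotrf_(std::forward<Args>(args)...);   // single precision
  } else if constexpr (std::is_same_v<T, double>) {
    dpotrf_(std::forward<Args>(args)...);   // double precision
  }
}

Call sites can then write potrf<T>(&uplo, &N, matrix, &N, &info) for either scalar type. One subtlety: for any other T both branches drop out and the wrapper compiles to a no-op, which is why each eval_cpu rejects unsupported dtypes before ever reaching these wrappers.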
@@ -9,11 +9,8 @@

 namespace mlx::core {

-void lu_factor_impl(
-    const array& a,
-    array& lu,
-    array& pivots,
-    array& row_indices) {
+template <typename T>
+void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
   int M = a.shape(-2);
   int N = a.shape(-1);

@@ -31,7 +28,7 @@ void lu_factor_impl(
   copy_inplace(
       a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);

-  auto a_ptr = lu.data<float>();
+  auto a_ptr = lu.data<T>();

   pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
   row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
@@ -42,13 +39,13 @@ void lu_factor_impl(
   size_t num_matrices = a.size() / (M * N);
   for (size_t i = 0; i < num_matrices; ++i) {
     // Compute LU factorization of A
-    MLX_LAPACK_FUNC(sgetrf)
-    (/* m */ &M,
-     /* n */ &N,
-     /* a */ a_ptr,
-     /* lda */ &M,
-     /* ipiv */ reinterpret_cast<int*>(pivots_ptr),
-     /* info */ &info);
+    getrf<T>(
+        /* m */ &M,
+        /* n */ &N,
+        /* a */ a_ptr,
+        /* lda */ &M,
+        /* ipiv */ reinterpret_cast<int*>(pivots_ptr),
+        /* info */ &info);

     if (info != 0) {
       std::stringstream ss;
@@ -86,7 +83,17 @@ void LUF::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() == 1);
-  lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
+  switch (inputs[0].dtype()) {
+    case float32:
+      luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    case float64:
+      luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    default:
+      throw std::runtime_error(
+          "[LUF::eval_cpu] only supports float32 or float64.");
+  }
 }

 } // namespace mlx::core

@@ -7,36 +7,6 @@

 namespace mlx::core {

-template <typename T>
-struct lpack;
-
-template <>
-struct lpack<float> {
-  static void xgeqrf(
-      const int* m,
-      const int* n,
-      float* a,
-      const int* lda,
-      float* tau,
-      float* work,
-      const int* lwork,
-      int* info) {
-    sgeqrf_(m, n, a, lda, tau, work, lwork, info);
-  }
-  static void xorgqr(
-      const int* m,
-      const int* n,
-      const int* k,
-      float* a,
-      const int* lda,
-      const float* tau,
-      float* work,
-      const int* lwork,
-      int* info) {
-    sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
-  }
-};
-
 template <typename T>
 void qrf_impl(const array& a, array& q, array& r) {
   const int M = a.shape(-2);
@@ -48,7 +18,7 @@ void qrf_impl(const array& a, array& q, array& r) {
       allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);

   // Copy A to inplace input and make it col-contiguous
-  array in(a.shape(), float32, nullptr, {});
+  array in(a.shape(), a.dtype(), nullptr, {});
   auto flags = in.flags();

   // Copy the input to be column contiguous
@@ -66,8 +36,7 @@ void qrf_impl(const array& a, array& q, array& r) {
   int info;

   // Compute workspace size
-  lpack<T>::xgeqrf(
-      &M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
+  geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);

   // Update workspace size
   lwork = optimal_work;
@@ -76,10 +45,10 @@ void qrf_impl(const array& a, array& q, array& r) {
   // Loop over matrices
   for (int i = 0; i < num_matrices; ++i) {
     // Solve
-    lpack<T>::xgeqrf(
+    geqrf<T>(
         &M,
         &N,
-        in.data<float>() + M * N * i,
+        in.data<T>() + M * N * i,
         &lda,
         static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
         static_cast<T*>(work.raw_ptr()),
@@ -105,7 +74,7 @@ void qrf_impl(const array& a, array& q, array& r) {

   // Get work size
   lwork = -1;
-  lpack<T>::xorgqr(
+  orgqr<T>(
       &M,
       &num_reflectors,
       &num_reflectors,
@@ -121,11 +90,11 @@ void qrf_impl(const array& a, array& q, array& r) {
   // Loop over matrices
   for (int i = 0; i < num_matrices; ++i) {
     // Compute Q
-    lpack<T>::xorgqr(
+    orgqr<T>(
         &M,
         &num_reflectors,
         &num_reflectors,
-        in.data<float>() + M * N * i,
+        in.data<T>() + M * N * i,
         &lda,
         static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
         static_cast<T*>(work.raw_ptr()),
@@ -152,10 +121,17 @@ void qrf_impl(const array& a, array& q, array& r) {
 void QRF::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (!(inputs[0].dtype() == float32)) {
-    throw std::runtime_error("[QRF::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
+      break;
+    case float64:
+      qrf_impl<double>(inputs[0], outputs[0], outputs[1]);
+      break;
+    default:
+      throw std::runtime_error(
+          "[QRF::eval_cpu] only supports float32 or float64.");
   }
-  qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
 }

 } // namespace mlx::core

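Context for the two wrappers used above: geqrf returns the factorization in LAPACK's packed form, with R in the upper triangle of the output and the Householder vectors v₁…vₖ stored below the diagonal alongside their scaling factors in tau; orgqr then expands that packed form into an explicit Q = H₁H₂⋯Hₖ, where each Hᵢ = I − tauᵢ·vᵢvᵢᵀ. That is why qrf_impl runs the two routines in sequence, each with its own lwork = -1 workspace query.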
@@ -7,6 +7,7 @@

 namespace mlx::core {

+template <typename T>
 void svd_impl(const array& a, array& u, array& s, array& vt) {
   // Lapack uses the column-major convention. To avoid having to transpose
   // the input and then transpose the outputs, we swap the indices/sizes of the
@@ -31,7 +32,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   size_t num_matrices = a.size() / (M * N);

   // lapack clobbers the input, so we have to make a copy.
-  array in(a.shape(), float32, nullptr, {});
+  array in(a.shape(), a.dtype(), nullptr, {});
   copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

   // Allocate outputs.
@@ -45,7 +46,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {

   // Will contain the number of singular values after the call has returned.
   int ns = 0;
-  float workspace_dimension = 0;
+  T workspace_dimension = 0;

   // Will contain the indices of eigenvectors that failed to converge (not used
   // here but required by lapack).
@@ -54,13 +55,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   static const int lwork_query = -1;

   static const int ignored_int = 0;
-  static const float ignored_float = 0;
+  static const T ignored_float = 0;

   int info;

   // Compute workspace size.
-  MLX_LAPACK_FUNC(sgesvdx)
-  (
+  gesvdx<T>(
       /* jobu = */ job_u,
       /* jobvt = */ job_vt,
       /* range = */ range,
@@ -86,51 +86,50 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {

   if (info != 0) {
     std::stringstream ss;
-    ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
+    ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
     throw std::runtime_error(ss.str());
   }

   const int lwork = workspace_dimension;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};

   // Loop over matrices.
   for (int i = 0; i < num_matrices; i++) {
-    MLX_LAPACK_FUNC(sgesvdx)
-    (
+    gesvdx<T>(
        /* jobu = */ job_u,
        /* jobvt = */ job_vt,
        /* range = */ range,
        // M and N are swapped since lapack expects column-major.
        /* m = */ &N,
        /* n = */ &M,
-        /* a = */ in.data<float>() + M * N * i,
+        /* a = */ in.data<T>() + M * N * i,
        /* lda = */ &lda,
        /* vl = */ &ignored_float,
        /* vu = */ &ignored_float,
        /* il = */ &ignored_int,
        /* iu = */ &ignored_int,
        /* ns = */ &ns,
-        /* s = */ s.data<float>() + K * i,
+        /* s = */ s.data<T>() + K * i,
        // According to the identity above, lapack will write Vᵀᵀ as U.
-        /* u = */ vt.data<float>() + N * N * i,
+        /* u = */ vt.data<T>() + N * N * i,
        /* ldu = */ &ldu,
        // According to the identity above, lapack will write Uᵀ as Vᵀ.
-        /* vt = */ u.data<float>() + M * M * i,
+        /* vt = */ u.data<T>() + M * M * i,
        /* ldvt = */ &ldvt,
-        /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
+        /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
        /* lwork = */ &lwork,
        /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
        /* info = */ &info);

     if (info != 0) {
       std::stringstream ss;
-      ss << "svd_impl: sgesvdx_ failed with code " << info;
+      ss << "[SVD::eval_cpu] failed with code " << info;
       throw std::runtime_error(ss.str());
     }

     if (ns != K) {
       std::stringstream ss;
-      ss << "svd_impl: expected " << K << " singular values, but " << ns
+      ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
          << " were computed.";
       throw std::runtime_error(ss.str());
     }
@@ -140,10 +139,17 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
 void SVD::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (!(inputs[0].dtype() == float32)) {
-    throw std::runtime_error("[SVD::eval] only supports float32.");
+  switch (inputs[0].dtype()) {
+    case float32:
+      svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    case float64:
+      svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      break;
+    default:
+      throw std::runtime_error(
+          "[SVD::eval_cpu] only supports float32 or float64.");
   }
-  svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
 }

 } // namespace mlx::core

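The identity behind the swapped u/vt arguments above: if A = UΣVᵀ, then Aᵀ = VΣUᵀ. Since LAPACK reads the row-major buffer as Aᵀ, gesvdx reports V in its u output and Uᵀ in its vt output; writing those into vt.data<T>() and u.data<T>() respectively leaves both factors correctly laid out in row-major order with no transpose copies. The same reasoning is why m and n are passed swapped in the call.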
@@ -18,6 +18,14 @@ void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {
         "Explicitly pass a CPU stream to run it.");
   }
 }
+void check_float(Dtype dtype, const std::string& prefix) {
+  if (dtype != float32 && dtype != float64) {
+    std::ostringstream msg;
+    msg << prefix << " Arrays must have type float32 or float64. "
+        << "Received array with type " << dtype << ".";
+    throw std::invalid_argument(msg.str());
+  }
+}

 Dtype at_least_float(const Dtype& d) {
   return issubdtype(d, inexact) ? d : promote_types(d, float32);
@@ -184,12 +192,8 @@ array norm(

 std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::qr]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::qr] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::qr]");

   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
@@ -212,12 +216,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {

 std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::svd]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::svd] Input array must have type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::svd]");

   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
@@ -251,12 +251,8 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {

 array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
   check_cpu_stream(s, "[linalg::inv]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::inv] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::inv]");

   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
@@ -292,13 +288,7 @@ array cholesky(
     bool upper /* = false */,
     StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::cholesky]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::cholesky] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
+  check_float(a.dtype(), "[linalg::cholesky]");
   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
@@ -321,12 +311,8 @@

 array pinv(const array& a, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::pinv]");
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::pinv] Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), "[linalg::pinv]");

   if (a.ndim() < 2) {
     std::ostringstream msg;
     msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
@@ -368,12 +354,7 @@ array cholesky_inv(
     bool upper /* = false */,
     StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::cholesky_inv]");
-  if (L.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::cholesky_inv] Arrays must type float32. Received array "
-        << "with type " << L.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(L.dtype(), "[linalg::cholesky_inv]");

   if (L.ndim() < 2) {
     std::ostringstream msg;
@@ -474,12 +455,7 @@ void validate_eigh(
     const StreamOrDevice& stream,
     const std::string fname) {
   check_cpu_stream(stream, fname);
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << fname << " Arrays must have type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), fname);

   if (a.ndim() < 2) {
     std::ostringstream msg;
@@ -524,12 +500,7 @@ void validate_lu(
     const StreamOrDevice& stream,
     const std::string& fname) {
   check_cpu_stream(stream, fname);
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << fname << " Arrays must type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
+  check_float(a.dtype(), fname);

   if (a.ndim() < 2) {
     std::ostringstream msg;
@@ -627,10 +598,12 @@ void validate_solve(
   }

   auto out_type = promote_types(a.dtype(), b.dtype());
-  if (out_type != float32) {
+  if (out_type != float32 && out_type != float64) {
     std::ostringstream msg;
-    msg << fname << " Input arrays must promote to float32. Received arrays "
-        << "with type " << a.dtype() << " and " << b.dtype() << ".";
+    msg << fname
+        << " Input arrays must promote to float32 or float64. "
+           " Received arrays with type "
+        << a.dtype() << " and " << b.dtype() << ".";
     throw std::invalid_argument(msg.str());
   }
 }

@@ -44,6 +44,8 @@ std::string buffer_format(const mx::array& a) {
       return "f";
     case mx::bfloat16:
       return "B";
+    case mx::float64:
+      return "d";
     case mx::complex64:
       return "Zf\0";
     default: {

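For reference, "f" and "d" are the standard Python buffer-protocol format characters for a C float and double, so exposing float64 through the buffer protocol is what lets np.array(a) round-trip double arrays in the test_conversion test below.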
@@ -152,6 +152,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
       throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
     case mx::float32:
       return mlx_to_nd_array_impl<float, NDParams...>(a);
+    case mx::float64:
+      return mlx_to_nd_array_impl<double, NDParams...>(a);
     case mx::complex64:
       return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
     default:

@@ -183,6 +183,115 @@ class TestDouble(mlx_tests.MLXTestCase):
             c = a + b
             self.assertEqual(c.dtype, mx.float64)

+    def test_lapack(self):
+        with mx.stream(mx.cpu):
+            # QRF
+            A = mx.array([[2.0, 3.0], [1.0, 2.0]], dtype=mx.float64)
+            Q, R = mx.linalg.qr(A)
+            out = Q @ R
+            self.assertTrue(mx.allclose(out, A))
+            out = Q.T @ Q
+            self.assertTrue(mx.allclose(out, mx.eye(2)))
+            self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))
+            self.assertEqual(Q.dtype, mx.float64)
+            self.assertEqual(R.dtype, mx.float64)
+
+            # SVD
+            A = mx.array(
+                [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
+            )
+            U, S, Vt = mx.linalg.svd(A)
+            self.assertTrue(mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A))
+
+            # Inverse
+            A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
+            A_inv = mx.linalg.inv(A)
+            self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0])))
+
+            # Tri inv
+            A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float64)
+            B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float64)
+            AB = mx.stack([A, B])
+            invs = mx.linalg.tri_inv(AB, upper=False)
+            for M, M_inv in zip(AB, invs):
+                self.assertTrue(mx.allclose(M @ M_inv, mx.eye(M.shape[0])))
+
+            # Cholesky
+            sqrtA = mx.array(
+                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float64
+            )
+            A = sqrtA.T @ sqrtA / 81
+            L = mx.linalg.cholesky(A)
+            U = mx.linalg.cholesky(A, upper=True)
+            self.assertTrue(mx.allclose(L @ L.T, A))
+            self.assertTrue(mx.allclose(U.T @ U, A))
+
+            # Pseudo inverse
+            A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
+            A_plus = mx.linalg.pinv(A)
+            self.assertTrue(mx.allclose(A @ A_plus @ A, A))
+
+            # Eigh
+            def check_eigs_and_vecs(A_np, kwargs={}):
+                A = mx.array(A_np, dtype=mx.float64)
+                eig_vals, eig_vecs = mx.linalg.eigh(A, **kwargs)
+                eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)
+                self.assertTrue(np.allclose(eig_vals, eig_vals_np))
+                self.assertTrue(
+                    mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs)
+                )
+
+                eig_vals_only = mx.linalg.eigvalsh(A, **kwargs)
+                self.assertTrue(mx.allclose(eig_vals, eig_vals_only))
+
+            # Test a simple 2x2 symmetric matrix
+            A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float64)
+            check_eigs_and_vecs(A_np)
+
+            # Test a larger random symmetric matrix
+            n = 5
+            np.random.seed(1)
+            A_np = np.random.randn(n, n).astype(np.float64)
+            A_np = (A_np + A_np.T) / 2
+            check_eigs_and_vecs(A_np)
+
+            # Test with upper triangle
+            check_eigs_and_vecs(A_np, {"UPLO": "U"})
+
+            # LU factorization
+            # Test 3x3 matrix
+            a = mx.array(
+                [[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]], dtype=mx.float64
+            )
+            P, L, U = mx.linalg.lu(a)
+            self.assertTrue(mx.allclose(L[P, :] @ U, a))
+
+            # Solve triangular
+            # Test lower triangular matrix
+            a = mx.array(
+                [[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]], dtype=mx.float64
+            )
+            b = mx.array([8.0, 14.0, 3.0], dtype=mx.float64)
+
+            result = mx.linalg.solve_triangular(a, b, upper=False)
+            expected = np.linalg.solve(np.array(a), np.array(b))
+            self.assertTrue(np.allclose(result, expected))
+
+            # Test upper triangular matrix
+            a = mx.array(
+                [[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]], dtype=mx.float64
+            )
+            b = mx.array([13.0, 33.0, 18.0], dtype=mx.float64)
+
+            result = mx.linalg.solve_triangular(a, b, upper=True)
+            expected = np.linalg.solve(np.array(a), np.array(b))
+            self.assertTrue(np.allclose(result, expected))
+
     def test_conversion(self):
         a = mx.array([1.0, 2.0], mx.float64)
         b = np.array(a)
         self.assertTrue(np.array_equal(a, b))


 if __name__ == "__main__":
     unittest.main()