mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Double for lapack (#1904)
* double for lapack ops * add double support for lapack ops
This commit is contained in:
parent
28b8079e30
commit
7d042f17fe
@ -8,6 +8,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
// Lapack uses the column-major convention. We take advantage of the fact that
|
||||
// the matrix should be symmetric:
|
||||
@ -28,13 +29,12 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
|
||||
float* matrix = factor.data<float>();
|
||||
T* matrix = factor.data<T>();
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
// Compute Cholesky factorization.
|
||||
int info;
|
||||
MLX_LAPACK_FUNC(spotrf)
|
||||
(
|
||||
potrf<T>(
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ matrix,
|
||||
@ -65,10 +65,17 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
}
|
||||
|
||||
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Cholesky::eval] only supports float32.");
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
cholesky_impl<float>(inputs[0], output, upper_);
|
||||
break;
|
||||
case float64:
|
||||
cholesky_impl<double>(inputs[0], output, upper_);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Cholesky::eval_cpu] only supports float32 or float64.");
|
||||
}
|
||||
cholesky_impl(inputs[0], output, upper_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -11,35 +11,64 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void ssyevd(
|
||||
char jobz,
|
||||
char uplo,
|
||||
float* a,
|
||||
int N,
|
||||
float* w,
|
||||
float* work,
|
||||
int lwork,
|
||||
int* iwork,
|
||||
int liwork) {
|
||||
template <typename T>
|
||||
void eigh_impl(
|
||||
array& vectors,
|
||||
array& values,
|
||||
const std::string& uplo,
|
||||
bool compute_eigenvectors) {
|
||||
auto vec_ptr = vectors.data<T>();
|
||||
auto eig_ptr = values.data<T>();
|
||||
|
||||
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||
auto N = vectors.shape(-1);
|
||||
|
||||
// Work query
|
||||
int lwork = -1;
|
||||
int liwork = -1;
|
||||
int info;
|
||||
MLX_LAPACK_FUNC(ssyevd)
|
||||
(
|
||||
/* jobz = */ &jobz,
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ a,
|
||||
/* lda = */ &N,
|
||||
/* w = */ w,
|
||||
/* work = */ work,
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ iwork,
|
||||
/* liwork = */ &liwork,
|
||||
/* info = */ &info);
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
uplo.c_str(),
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
uplo.c_str(),
|
||||
&N,
|
||||
vec_ptr,
|
||||
&N,
|
||||
eig_ptr,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,39 +109,16 @@ void Eigh::eval_cpu(
|
||||
}
|
||||
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
|
||||
}
|
||||
|
||||
auto vec_ptr = vectors.data<float>();
|
||||
auto eig_ptr = values.data<float>();
|
||||
|
||||
char jobz = compute_eigenvectors_ ? 'V' : 'N';
|
||||
auto N = a.shape(-1);
|
||||
|
||||
// Work query
|
||||
int lwork;
|
||||
int liwork;
|
||||
{
|
||||
float work;
|
||||
int iwork;
|
||||
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < a.size() / (N * N); ++i) {
|
||||
ssyevd(
|
||||
jobz,
|
||||
uplo_[0],
|
||||
vec_ptr,
|
||||
N,
|
||||
eig_ptr,
|
||||
static_cast<float*>(work_buf.buffer.raw_ptr()),
|
||||
lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
liwork);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
switch (a.dtype()) {
|
||||
case float32:
|
||||
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
|
||||
break;
|
||||
case float64:
|
||||
eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Eigh::eval_cpu] only supports float32 or float64.");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,44 +5,33 @@
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
|
||||
int info;
|
||||
MLX_LAPACK_FUNC(strtri)
|
||||
(
|
||||
/* uplo = */ &uplo,
|
||||
/* diag = */ &diag,
|
||||
/* N = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
return info;
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
void general_inv(array& inv, int N, int i) {
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
// Compute LU factorization.
|
||||
sgetrf_(
|
||||
getrf<T>(
|
||||
/* m = */ &N,
|
||||
/* n = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* a = */ inv.data<T>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU factorization failed with error code " << info;
|
||||
ss << "[Inverse::eval_cpu] LU factorization failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
static const int lwork_query = -1;
|
||||
float workspace_size = 0;
|
||||
T workspace_size = 0;
|
||||
|
||||
// Compute workspace size.
|
||||
sgetri_(
|
||||
getri<T>(
|
||||
/* m = */ &N,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &N,
|
||||
@ -53,36 +42,44 @@ void general_inv(array& inv, int N, int i) {
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU workspace calculation failed with error code "
|
||||
ss << "[Inverse::eval_cpu] LU workspace calculation failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
sgetri_(
|
||||
getri<T>(
|
||||
/* m = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* a = */ inv.data<T>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: inversion failed with error code " << info;
|
||||
ss << "[Inverse::eval_cpu] inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void tri_inv(array& inv, int N, int i, bool upper) {
|
||||
const char uplo = upper ? 'L' : 'U';
|
||||
const char diag = 'N';
|
||||
float* data = inv.data<float>() + N * N * i;
|
||||
int info = strtri_wrapper(uplo, diag, data, N);
|
||||
T* data = inv.data<T>() + N * N * i;
|
||||
int info;
|
||||
trtri<T>(
|
||||
/* uplo = */ &uplo,
|
||||
/* diag = */ &diag,
|
||||
/* N = */ &N,
|
||||
/* a = */ data,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
|
||||
// zero out the other triangle
|
||||
if (upper) {
|
||||
@ -99,11 +96,13 @@ void tri_inv(array& inv, int N, int i, bool upper) {
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: triangular inversion failed with error code " << info;
|
||||
ss << "[Inverse::eval_cpu] triangular inversion failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||
// Lapack uses the column-major convention. We take advantage of the following
|
||||
// identity to avoid transposing (see
|
||||
@ -118,18 +117,25 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
if (tri) {
|
||||
tri_inv(inv, N, i, upper);
|
||||
tri_inv<T>(inv, N, i, upper);
|
||||
} else {
|
||||
general_inv(inv, N, i);
|
||||
general_inv<T>(inv, N, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Inverse::eval] only supports float32.");
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
inverse_impl<float>(inputs[0], output, tri_, upper_);
|
||||
break;
|
||||
case float64:
|
||||
inverse_impl<double>(inputs[0], output, tri_, upper_);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Inverse::eval_cpu] only supports float32 or float64.");
|
||||
}
|
||||
inverse_impl(inputs[0], output, tri_, upper_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -31,3 +31,22 @@
|
||||
#define MLX_LAPACK_FUNC(f) f##_
|
||||
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, float>) { \
|
||||
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
|
||||
} else if constexpr (std::is_same_v<T, double>) { \
|
||||
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_TYPES(geqrf)
|
||||
INSTANTIATE_LAPACK_TYPES(orgqr)
|
||||
INSTANTIATE_LAPACK_TYPES(syevd)
|
||||
INSTANTIATE_LAPACK_TYPES(potrf)
|
||||
INSTANTIATE_LAPACK_TYPES(gesvdx)
|
||||
INSTANTIATE_LAPACK_TYPES(getrf)
|
||||
INSTANTIATE_LAPACK_TYPES(getri)
|
||||
INSTANTIATE_LAPACK_TYPES(trtri)
|
||||
|
@ -9,11 +9,8 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void lu_factor_impl(
|
||||
const array& a,
|
||||
array& lu,
|
||||
array& pivots,
|
||||
array& row_indices) {
|
||||
template <typename T>
|
||||
void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
|
||||
int M = a.shape(-2);
|
||||
int N = a.shape(-1);
|
||||
|
||||
@ -31,7 +28,7 @@ void lu_factor_impl(
|
||||
copy_inplace(
|
||||
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
|
||||
|
||||
auto a_ptr = lu.data<float>();
|
||||
auto a_ptr = lu.data<T>();
|
||||
|
||||
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
|
||||
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
|
||||
@ -42,13 +39,13 @@ void lu_factor_impl(
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
for (size_t i = 0; i < num_matrices; ++i) {
|
||||
// Compute LU factorization of A
|
||||
MLX_LAPACK_FUNC(sgetrf)
|
||||
(/* m */ &M,
|
||||
/* n */ &N,
|
||||
/* a */ a_ptr,
|
||||
/* lda */ &M,
|
||||
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
|
||||
/* info */ &info);
|
||||
getrf<T>(
|
||||
/* m */ &M,
|
||||
/* n */ &N,
|
||||
/* a */ a_ptr,
|
||||
/* lda */ &M,
|
||||
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
|
||||
/* info */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
@ -86,7 +83,17 @@ void LUF::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
break;
|
||||
case float64:
|
||||
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[LUF::eval_cpu] only supports float32 or float64.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -7,36 +7,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
struct lpack;
|
||||
|
||||
template <>
|
||||
struct lpack<float> {
|
||||
static void xgeqrf(
|
||||
const int* m,
|
||||
const int* n,
|
||||
float* a,
|
||||
const int* lda,
|
||||
float* tau,
|
||||
float* work,
|
||||
const int* lwork,
|
||||
int* info) {
|
||||
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
|
||||
}
|
||||
static void xorgqr(
|
||||
const int* m,
|
||||
const int* n,
|
||||
const int* k,
|
||||
float* a,
|
||||
const int* lda,
|
||||
const float* tau,
|
||||
float* work,
|
||||
const int* lwork,
|
||||
int* info) {
|
||||
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void qrf_impl(const array& a, array& q, array& r) {
|
||||
const int M = a.shape(-2);
|
||||
@ -48,7 +18,7 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
|
||||
|
||||
// Copy A to inplace input and make it col-contiguous
|
||||
array in(a.shape(), float32, nullptr, {});
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
auto flags = in.flags();
|
||||
|
||||
// Copy the input to be column contiguous
|
||||
@ -66,8 +36,7 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
int info;
|
||||
|
||||
// Compute workspace size
|
||||
lpack<T>::xgeqrf(
|
||||
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
|
||||
geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
|
||||
|
||||
// Update workspace size
|
||||
lwork = optimal_work;
|
||||
@ -76,10 +45,10 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Solve
|
||||
lpack<T>::xgeqrf(
|
||||
geqrf<T>(
|
||||
&M,
|
||||
&N,
|
||||
in.data<float>() + M * N * i,
|
||||
in.data<T>() + M * N * i,
|
||||
&lda,
|
||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||
static_cast<T*>(work.raw_ptr()),
|
||||
@ -105,7 +74,7 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
|
||||
// Get work size
|
||||
lwork = -1;
|
||||
lpack<T>::xorgqr(
|
||||
orgqr<T>(
|
||||
&M,
|
||||
&num_reflectors,
|
||||
&num_reflectors,
|
||||
@ -121,11 +90,11 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Compute Q
|
||||
lpack<T>::xorgqr(
|
||||
orgqr<T>(
|
||||
&M,
|
||||
&num_reflectors,
|
||||
&num_reflectors,
|
||||
in.data<float>() + M * N * i,
|
||||
in.data<T>() + M * N * i,
|
||||
&lda,
|
||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||
static_cast<T*>(work.raw_ptr()),
|
||||
@ -152,10 +121,17 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
void QRF::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (!(inputs[0].dtype() == float32)) {
|
||||
throw std::runtime_error("[QRF::eval] only supports float32.");
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
|
||||
break;
|
||||
case float64:
|
||||
qrf_impl<double>(inputs[0], outputs[0], outputs[1]);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[QRF::eval_cpu] only supports float32 or float64.");
|
||||
}
|
||||
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -7,6 +7,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
// Lapack uses the column-major convention. To avoid having to transpose
|
||||
// the input and then transpose the outputs, we swap the indices/sizes of the
|
||||
@ -31,7 +32,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// lapack clobbers the input, so we have to make a copy.
|
||||
array in(a.shape(), float32, nullptr, {});
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
// Allocate outputs.
|
||||
@ -45,7 +46,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
int ns = 0;
|
||||
float workspace_dimension = 0;
|
||||
T workspace_dimension = 0;
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not used
|
||||
// here but required by lapack).
|
||||
@ -54,13 +55,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
static const int lwork_query = -1;
|
||||
|
||||
static const int ignored_int = 0;
|
||||
static const float ignored_float = 0;
|
||||
static const T ignored_float = 0;
|
||||
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
MLX_LAPACK_FUNC(sgesvdx)
|
||||
(
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
@ -86,51 +86,50 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
|
||||
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_dimension;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
MLX_LAPACK_FUNC(sgesvdx)
|
||||
(
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ in.data<float>() + M * N * i,
|
||||
/* a = */ in.data<T>() + M * N * i,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ s.data<float>() + K * i,
|
||||
/* s = */ s.data<T>() + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt.data<float>() + N * N * i,
|
||||
/* u = */ vt.data<T>() + N * N * i,
|
||||
/* ldu = */ &ldu,
|
||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||
/* vt = */ u.data<float>() + M * M * i,
|
||||
/* vt = */ u.data<T>() + M * M * i,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
ss << "[SVD::eval_cpu] failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
if (ns != K) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
||||
ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
|
||||
<< " were computed.";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
@ -140,10 +139,17 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
void SVD::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (!(inputs[0].dtype() == float32)) {
|
||||
throw std::runtime_error("[SVD::eval] only supports float32.");
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
break;
|
||||
case float64:
|
||||
svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[SVD::eval_cpu] only supports float32 or float64.");
|
||||
}
|
||||
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -18,6 +18,14 @@ void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {
|
||||
"Explicitly pass a CPU stream to run it.");
|
||||
}
|
||||
}
|
||||
void check_float(Dtype dtype, const std::string& prefix) {
|
||||
if (dtype != float32 && dtype != float64) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Arrays must have type float32 or float64. "
|
||||
<< "Received array with type " << dtype << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
@ -184,12 +192,8 @@ array norm(
|
||||
|
||||
std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::qr]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::qr] Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), "[linalg::qr]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
|
||||
@ -212,12 +216,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
||||
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::svd]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::svd] Input array must have type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), "[linalg::svd]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
|
||||
@ -251,12 +251,8 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
|
||||
check_cpu_stream(s, "[linalg::inv]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::inv] Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), "[linalg::inv]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
|
||||
@ -292,13 +288,7 @@ array cholesky(
|
||||
bool upper /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::cholesky]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::cholesky] Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
check_float(a.dtype(), "[linalg::cholesky]");
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
|
||||
@ -321,12 +311,8 @@ array cholesky(
|
||||
|
||||
array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::pinv]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::pinv] Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), "[linalg::pinv]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
|
||||
@ -368,12 +354,7 @@ array cholesky_inv(
|
||||
bool upper /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::cholesky_inv]");
|
||||
if (L.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::cholesky_inv] Arrays must type float32. Received array "
|
||||
<< "with type " << L.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(L.dtype(), "[linalg::cholesky_inv]");
|
||||
|
||||
if (L.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@ -474,12 +455,7 @@ void validate_eigh(
|
||||
const StreamOrDevice& stream,
|
||||
const std::string fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Arrays must have type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), fname);
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@ -524,12 +500,7 @@ void validate_lu(
|
||||
const StreamOrDevice& stream,
|
||||
const std::string& fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), fname);
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@ -627,10 +598,12 @@ void validate_solve(
|
||||
}
|
||||
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (out_type != float32) {
|
||||
if (out_type != float32 && out_type != float64) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Input arrays must promote to float32. Received arrays "
|
||||
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
|
||||
msg << fname
|
||||
<< " Input arrays must promote to float32 or float64. "
|
||||
" Received arrays with type "
|
||||
<< a.dtype() << " and " << b.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
@ -44,6 +44,8 @@ std::string buffer_format(const mx::array& a) {
|
||||
return "f";
|
||||
case mx::bfloat16:
|
||||
return "B";
|
||||
case mx::float64:
|
||||
return "d";
|
||||
case mx::complex64:
|
||||
return "Zf\0";
|
||||
default: {
|
||||
|
@ -152,6 +152,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
|
||||
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
|
||||
case mx::float32:
|
||||
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||
case mx::float64:
|
||||
return mlx_to_nd_array_impl<double, NDParams...>(a);
|
||||
case mx::complex64:
|
||||
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
||||
default:
|
||||
|
@ -183,6 +183,115 @@ class TestDouble(mlx_tests.MLXTestCase):
|
||||
c = a + b
|
||||
self.assertEqual(c.dtype, mx.float64)
|
||||
|
||||
def test_lapack(self):
|
||||
with mx.stream(mx.cpu):
|
||||
# QRF
|
||||
A = mx.array([[2.0, 3.0], [1.0, 2.0]], dtype=mx.float64)
|
||||
Q, R = mx.linalg.qr(A)
|
||||
out = Q @ R
|
||||
self.assertTrue(mx.allclose(out, A))
|
||||
out = Q.T @ Q
|
||||
self.assertTrue(mx.allclose(out, mx.eye(2)))
|
||||
self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))
|
||||
self.assertEqual(Q.dtype, mx.float64)
|
||||
self.assertEqual(R.dtype, mx.float64)
|
||||
|
||||
# SVD
|
||||
A = mx.array(
|
||||
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
|
||||
)
|
||||
U, S, Vt = mx.linalg.svd(A)
|
||||
self.assertTrue(mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A))
|
||||
|
||||
# Inverse
|
||||
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
|
||||
A_inv = mx.linalg.inv(A)
|
||||
self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0])))
|
||||
|
||||
# Tri inv
|
||||
A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float64)
|
||||
B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float64)
|
||||
AB = mx.stack([A, B])
|
||||
invs = mx.linalg.tri_inv(AB, upper=False)
|
||||
for M, M_inv in zip(AB, invs):
|
||||
self.assertTrue(mx.allclose(M @ M_inv, mx.eye(M.shape[0])))
|
||||
|
||||
# Cholesky
|
||||
sqrtA = mx.array(
|
||||
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float64
|
||||
)
|
||||
A = sqrtA.T @ sqrtA / 81
|
||||
L = mx.linalg.cholesky(A)
|
||||
U = mx.linalg.cholesky(A, upper=True)
|
||||
self.assertTrue(mx.allclose(L @ L.T, A))
|
||||
self.assertTrue(mx.allclose(U.T @ U, A))
|
||||
|
||||
# Psueod inverse
|
||||
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
|
||||
A_plus = mx.linalg.pinv(A)
|
||||
self.assertTrue(mx.allclose(A @ A_plus @ A, A))
|
||||
|
||||
# Eigh
|
||||
def check_eigs_and_vecs(A_np, kwargs={}):
|
||||
A = mx.array(A_np, dtype=mx.float64)
|
||||
eig_vals, eig_vecs = mx.linalg.eigh(A, **kwargs)
|
||||
eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)
|
||||
self.assertTrue(np.allclose(eig_vals, eig_vals_np))
|
||||
self.assertTrue(
|
||||
mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs)
|
||||
)
|
||||
|
||||
eig_vals_only = mx.linalg.eigvalsh(A, **kwargs)
|
||||
self.assertTrue(mx.allclose(eig_vals, eig_vals_only))
|
||||
|
||||
# Test a simple 2x2 symmetric matrix
|
||||
A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float64)
|
||||
check_eigs_and_vecs(A_np)
|
||||
|
||||
# Test a larger random symmetric matrix
|
||||
n = 5
|
||||
np.random.seed(1)
|
||||
A_np = np.random.randn(n, n).astype(np.float64)
|
||||
A_np = (A_np + A_np.T) / 2
|
||||
check_eigs_and_vecs(A_np)
|
||||
|
||||
# Test with upper triangle
|
||||
check_eigs_and_vecs(A_np, {"UPLO": "U"})
|
||||
|
||||
# LU factorization
|
||||
# Test 3x3 matrix
|
||||
a = mx.array(
|
||||
[[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]], dtype=mx.float64
|
||||
)
|
||||
P, L, U = mx.linalg.lu(a)
|
||||
self.assertTrue(mx.allclose(L[P, :] @ U, a))
|
||||
|
||||
# Solve triangular
|
||||
# Test lower triangular matrix
|
||||
a = mx.array(
|
||||
[[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]], dtype=mx.float64
|
||||
)
|
||||
b = mx.array([8.0, 14.0, 3.0], dtype=mx.float64)
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=False)
|
||||
expected = np.linalg.solve(np.array(a), np.array(b))
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test upper triangular matrix
|
||||
a = mx.array(
|
||||
[[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]], dtype=mx.float64
|
||||
)
|
||||
b = mx.array([13.0, 33.0, 18.0], dtype=mx.float64)
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=True)
|
||||
expected = np.linalg.solve(np.array(a), np.array(b))
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
def test_conversion(self):
|
||||
a = mx.array([1.0, 2.0], mx.float64)
|
||||
b = np.array(a)
|
||||
self.assertTrue(np.array_equal(a, b))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user