mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Double for lapack (#1904)
* double for lapack ops * add double support for lapack ops
This commit is contained in:
parent
28b8079e30
commit
7d042f17fe
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
void cholesky_impl(const array& a, array& factor, bool upper) {
|
void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||||
// Lapack uses the column-major convention. We take advantage of the fact that
|
// Lapack uses the column-major convention. We take advantage of the fact that
|
||||||
// the matrix should be symmetric:
|
// the matrix should be symmetric:
|
||||||
@ -28,13 +29,12 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
|||||||
const int N = a.shape(-1);
|
const int N = a.shape(-1);
|
||||||
const size_t num_matrices = a.size() / (N * N);
|
const size_t num_matrices = a.size() / (N * N);
|
||||||
|
|
||||||
float* matrix = factor.data<float>();
|
T* matrix = factor.data<T>();
|
||||||
|
|
||||||
for (int i = 0; i < num_matrices; i++) {
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
// Compute Cholesky factorization.
|
// Compute Cholesky factorization.
|
||||||
int info;
|
int info;
|
||||||
MLX_LAPACK_FUNC(spotrf)
|
potrf<T>(
|
||||||
(
|
|
||||||
/* uplo = */ &uplo,
|
/* uplo = */ &uplo,
|
||||||
/* n = */ &N,
|
/* n = */ &N,
|
||||||
/* a = */ matrix,
|
/* a = */ matrix,
|
||||||
@ -65,10 +65,17 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
|
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||||
if (inputs[0].dtype() != float32) {
|
switch (inputs[0].dtype()) {
|
||||||
throw std::runtime_error("[Cholesky::eval] only supports float32.");
|
case float32:
|
||||||
|
cholesky_impl<float>(inputs[0], output, upper_);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
cholesky_impl<double>(inputs[0], output, upper_);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[Cholesky::eval_cpu] only supports float32 or float64.");
|
||||||
}
|
}
|
||||||
cholesky_impl(inputs[0], output, upper_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -11,30 +11,58 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void ssyevd(
|
template <typename T>
|
||||||
char jobz,
|
void eigh_impl(
|
||||||
char uplo,
|
array& vectors,
|
||||||
float* a,
|
array& values,
|
||||||
int N,
|
const std::string& uplo,
|
||||||
float* w,
|
bool compute_eigenvectors) {
|
||||||
float* work,
|
auto vec_ptr = vectors.data<T>();
|
||||||
int lwork,
|
auto eig_ptr = values.data<T>();
|
||||||
int* iwork,
|
|
||||||
int liwork) {
|
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||||
|
auto N = vectors.shape(-1);
|
||||||
|
|
||||||
|
// Work query
|
||||||
|
int lwork = -1;
|
||||||
|
int liwork = -1;
|
||||||
int info;
|
int info;
|
||||||
MLX_LAPACK_FUNC(ssyevd)
|
{
|
||||||
(
|
T work;
|
||||||
/* jobz = */ &jobz,
|
int iwork;
|
||||||
/* uplo = */ &uplo,
|
syevd<T>(
|
||||||
/* n = */ &N,
|
&jobz,
|
||||||
/* a = */ a,
|
uplo.c_str(),
|
||||||
/* lda = */ &N,
|
&N,
|
||||||
/* w = */ w,
|
nullptr,
|
||||||
/* work = */ work,
|
&N,
|
||||||
/* lwork = */ &lwork,
|
nullptr,
|
||||||
/* iwork = */ iwork,
|
&work,
|
||||||
/* liwork = */ &liwork,
|
&lwork,
|
||||||
/* info = */ &info);
|
&iwork,
|
||||||
|
&liwork,
|
||||||
|
&info);
|
||||||
|
lwork = static_cast<int>(work);
|
||||||
|
liwork = iwork;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||||
|
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||||
|
for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
|
||||||
|
syevd<T>(
|
||||||
|
&jobz,
|
||||||
|
uplo.c_str(),
|
||||||
|
&N,
|
||||||
|
vec_ptr,
|
||||||
|
&N,
|
||||||
|
eig_ptr,
|
||||||
|
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||||
|
&lwork,
|
||||||
|
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||||
|
&liwork,
|
||||||
|
&info);
|
||||||
|
vec_ptr += N * N;
|
||||||
|
eig_ptr += N;
|
||||||
if (info != 0) {
|
if (info != 0) {
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||||
@ -42,6 +70,7 @@ void ssyevd(
|
|||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -80,39 +109,16 @@ void Eigh::eval_cpu(
|
|||||||
}
|
}
|
||||||
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
|
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
|
||||||
}
|
}
|
||||||
|
switch (a.dtype()) {
|
||||||
auto vec_ptr = vectors.data<float>();
|
case float32:
|
||||||
auto eig_ptr = values.data<float>();
|
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
|
||||||
|
break;
|
||||||
char jobz = compute_eigenvectors_ ? 'V' : 'N';
|
case float64:
|
||||||
auto N = a.shape(-1);
|
eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
|
||||||
|
break;
|
||||||
// Work query
|
default:
|
||||||
int lwork;
|
throw std::runtime_error(
|
||||||
int liwork;
|
"[Eigh::eval_cpu] only supports float32 or float64.");
|
||||||
{
|
|
||||||
float work;
|
|
||||||
int iwork;
|
|
||||||
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
|
|
||||||
lwork = static_cast<int>(work);
|
|
||||||
liwork = iwork;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
|
||||||
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
|
||||||
for (size_t i = 0; i < a.size() / (N * N); ++i) {
|
|
||||||
ssyevd(
|
|
||||||
jobz,
|
|
||||||
uplo_[0],
|
|
||||||
vec_ptr,
|
|
||||||
N,
|
|
||||||
eig_ptr,
|
|
||||||
static_cast<float*>(work_buf.buffer.raw_ptr()),
|
|
||||||
lwork,
|
|
||||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
|
||||||
liwork);
|
|
||||||
vec_ptr += N * N;
|
|
||||||
eig_ptr += N;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,44 +5,33 @@
|
|||||||
#include "mlx/backend/cpu/lapack.h"
|
#include "mlx/backend/cpu/lapack.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
|
|
||||||
int info;
|
|
||||||
MLX_LAPACK_FUNC(strtri)
|
|
||||||
(
|
|
||||||
/* uplo = */ &uplo,
|
|
||||||
/* diag = */ &diag,
|
|
||||||
/* N = */ &N,
|
|
||||||
/* a = */ matrix,
|
|
||||||
/* lda = */ &N,
|
|
||||||
/* info = */ &info);
|
|
||||||
return info;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
void general_inv(array& inv, int N, int i) {
|
void general_inv(array& inv, int N, int i) {
|
||||||
int info;
|
int info;
|
||||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||||
// Compute LU factorization.
|
// Compute LU factorization.
|
||||||
sgetrf_(
|
getrf<T>(
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* n = */ &N,
|
/* n = */ &N,
|
||||||
/* a = */ inv.data<float>() + N * N * i,
|
/* a = */ inv.data<T>() + N * N * i,
|
||||||
/* lda = */ &N,
|
/* lda = */ &N,
|
||||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||||
/* info = */ &info);
|
/* info = */ &info);
|
||||||
|
|
||||||
if (info != 0) {
|
if (info != 0) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "inverse_impl: LU factorization failed with error code " << info;
|
ss << "[Inverse::eval_cpu] LU factorization failed with error code "
|
||||||
|
<< info;
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
static const int lwork_query = -1;
|
static const int lwork_query = -1;
|
||||||
float workspace_size = 0;
|
T workspace_size = 0;
|
||||||
|
|
||||||
// Compute workspace size.
|
// Compute workspace size.
|
||||||
sgetri_(
|
getri<T>(
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* a = */ nullptr,
|
/* a = */ nullptr,
|
||||||
/* lda = */ &N,
|
/* lda = */ &N,
|
||||||
@ -53,36 +42,44 @@ void general_inv(array& inv, int N, int i) {
|
|||||||
|
|
||||||
if (info != 0) {
|
if (info != 0) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "inverse_impl: LU workspace calculation failed with error code "
|
ss << "[Inverse::eval_cpu] LU workspace calculation failed with error code "
|
||||||
<< info;
|
<< info;
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
const int lwork = workspace_size;
|
const int lwork = workspace_size;
|
||||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||||
|
|
||||||
// Compute inverse.
|
// Compute inverse.
|
||||||
sgetri_(
|
getri<T>(
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* a = */ inv.data<float>() + N * N * i,
|
/* a = */ inv.data<T>() + N * N * i,
|
||||||
/* lda = */ &N,
|
/* lda = */ &N,
|
||||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
||||||
/* lwork = */ &lwork,
|
/* lwork = */ &lwork,
|
||||||
/* info = */ &info);
|
/* info = */ &info);
|
||||||
|
|
||||||
if (info != 0) {
|
if (info != 0) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "inverse_impl: inversion failed with error code " << info;
|
ss << "[Inverse::eval_cpu] inversion failed with error code " << info;
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
void tri_inv(array& inv, int N, int i, bool upper) {
|
void tri_inv(array& inv, int N, int i, bool upper) {
|
||||||
const char uplo = upper ? 'L' : 'U';
|
const char uplo = upper ? 'L' : 'U';
|
||||||
const char diag = 'N';
|
const char diag = 'N';
|
||||||
float* data = inv.data<float>() + N * N * i;
|
T* data = inv.data<T>() + N * N * i;
|
||||||
int info = strtri_wrapper(uplo, diag, data, N);
|
int info;
|
||||||
|
trtri<T>(
|
||||||
|
/* uplo = */ &uplo,
|
||||||
|
/* diag = */ &diag,
|
||||||
|
/* N = */ &N,
|
||||||
|
/* a = */ data,
|
||||||
|
/* lda = */ &N,
|
||||||
|
/* info = */ &info);
|
||||||
|
|
||||||
// zero out the other triangle
|
// zero out the other triangle
|
||||||
if (upper) {
|
if (upper) {
|
||||||
@ -99,11 +96,13 @@ void tri_inv(array& inv, int N, int i, bool upper) {
|
|||||||
|
|
||||||
if (info != 0) {
|
if (info != 0) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "inverse_impl: triangular inversion failed with error code " << info;
|
ss << "[Inverse::eval_cpu] triangular inversion failed with error code "
|
||||||
|
<< info;
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||||
// Lapack uses the column-major convention. We take advantage of the following
|
// Lapack uses the column-major convention. We take advantage of the following
|
||||||
// identity to avoid transposing (see
|
// identity to avoid transposing (see
|
||||||
@ -118,18 +117,25 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
|||||||
|
|
||||||
for (int i = 0; i < num_matrices; i++) {
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
if (tri) {
|
if (tri) {
|
||||||
tri_inv(inv, N, i, upper);
|
tri_inv<T>(inv, N, i, upper);
|
||||||
} else {
|
} else {
|
||||||
general_inv(inv, N, i);
|
general_inv<T>(inv, N, i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
|
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||||
if (inputs[0].dtype() != float32) {
|
switch (inputs[0].dtype()) {
|
||||||
throw std::runtime_error("[Inverse::eval] only supports float32.");
|
case float32:
|
||||||
|
inverse_impl<float>(inputs[0], output, tri_, upper_);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
inverse_impl<double>(inputs[0], output, tri_, upper_);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[Inverse::eval_cpu] only supports float32 or float64.");
|
||||||
}
|
}
|
||||||
inverse_impl(inputs[0], output, tri_, upper_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -31,3 +31,22 @@
|
|||||||
#define MLX_LAPACK_FUNC(f) f##_
|
#define MLX_LAPACK_FUNC(f) f##_
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
|
||||||
|
template <typename T, typename... Args> \
|
||||||
|
void FUNC(Args... args) { \
|
||||||
|
if constexpr (std::is_same_v<T, float>) { \
|
||||||
|
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} else if constexpr (std::is_same_v<T, double>) { \
|
||||||
|
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_LAPACK_TYPES(geqrf)
|
||||||
|
INSTANTIATE_LAPACK_TYPES(orgqr)
|
||||||
|
INSTANTIATE_LAPACK_TYPES(syevd)
|
||||||
|
INSTANTIATE_LAPACK_TYPES(potrf)
|
||||||
|
INSTANTIATE_LAPACK_TYPES(gesvdx)
|
||||||
|
INSTANTIATE_LAPACK_TYPES(getrf)
|
||||||
|
INSTANTIATE_LAPACK_TYPES(getri)
|
||||||
|
INSTANTIATE_LAPACK_TYPES(trtri)
|
||||||
|
@ -9,11 +9,8 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void lu_factor_impl(
|
template <typename T>
|
||||||
const array& a,
|
void luf_impl(const array& a, array& lu, array& pivots, array& row_indices) {
|
||||||
array& lu,
|
|
||||||
array& pivots,
|
|
||||||
array& row_indices) {
|
|
||||||
int M = a.shape(-2);
|
int M = a.shape(-2);
|
||||||
int N = a.shape(-1);
|
int N = a.shape(-1);
|
||||||
|
|
||||||
@ -31,7 +28,7 @@ void lu_factor_impl(
|
|||||||
copy_inplace(
|
copy_inplace(
|
||||||
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
|
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
|
||||||
|
|
||||||
auto a_ptr = lu.data<float>();
|
auto a_ptr = lu.data<T>();
|
||||||
|
|
||||||
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
|
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
|
||||||
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
|
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
|
||||||
@ -42,8 +39,8 @@ void lu_factor_impl(
|
|||||||
size_t num_matrices = a.size() / (M * N);
|
size_t num_matrices = a.size() / (M * N);
|
||||||
for (size_t i = 0; i < num_matrices; ++i) {
|
for (size_t i = 0; i < num_matrices; ++i) {
|
||||||
// Compute LU factorization of A
|
// Compute LU factorization of A
|
||||||
MLX_LAPACK_FUNC(sgetrf)
|
getrf<T>(
|
||||||
(/* m */ &M,
|
/* m */ &M,
|
||||||
/* n */ &N,
|
/* n */ &N,
|
||||||
/* a */ a_ptr,
|
/* a */ a_ptr,
|
||||||
/* lda */ &M,
|
/* lda */ &M,
|
||||||
@ -86,7 +83,17 @@ void LUF::eval_cpu(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
switch (inputs[0].dtype()) {
|
||||||
|
case float32:
|
||||||
|
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[LUF::eval_cpu] only supports float32 or float64.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -7,36 +7,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct lpack;
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct lpack<float> {
|
|
||||||
static void xgeqrf(
|
|
||||||
const int* m,
|
|
||||||
const int* n,
|
|
||||||
float* a,
|
|
||||||
const int* lda,
|
|
||||||
float* tau,
|
|
||||||
float* work,
|
|
||||||
const int* lwork,
|
|
||||||
int* info) {
|
|
||||||
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
|
|
||||||
}
|
|
||||||
static void xorgqr(
|
|
||||||
const int* m,
|
|
||||||
const int* n,
|
|
||||||
const int* k,
|
|
||||||
float* a,
|
|
||||||
const int* lda,
|
|
||||||
const float* tau,
|
|
||||||
float* work,
|
|
||||||
const int* lwork,
|
|
||||||
int* info) {
|
|
||||||
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void qrf_impl(const array& a, array& q, array& r) {
|
void qrf_impl(const array& a, array& q, array& r) {
|
||||||
const int M = a.shape(-2);
|
const int M = a.shape(-2);
|
||||||
@ -48,7 +18,7 @@ void qrf_impl(const array& a, array& q, array& r) {
|
|||||||
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
|
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
|
||||||
|
|
||||||
// Copy A to inplace input and make it col-contiguous
|
// Copy A to inplace input and make it col-contiguous
|
||||||
array in(a.shape(), float32, nullptr, {});
|
array in(a.shape(), a.dtype(), nullptr, {});
|
||||||
auto flags = in.flags();
|
auto flags = in.flags();
|
||||||
|
|
||||||
// Copy the input to be column contiguous
|
// Copy the input to be column contiguous
|
||||||
@ -66,8 +36,7 @@ void qrf_impl(const array& a, array& q, array& r) {
|
|||||||
int info;
|
int info;
|
||||||
|
|
||||||
// Compute workspace size
|
// Compute workspace size
|
||||||
lpack<T>::xgeqrf(
|
geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
|
||||||
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
|
|
||||||
|
|
||||||
// Update workspace size
|
// Update workspace size
|
||||||
lwork = optimal_work;
|
lwork = optimal_work;
|
||||||
@ -76,10 +45,10 @@ void qrf_impl(const array& a, array& q, array& r) {
|
|||||||
// Loop over matrices
|
// Loop over matrices
|
||||||
for (int i = 0; i < num_matrices; ++i) {
|
for (int i = 0; i < num_matrices; ++i) {
|
||||||
// Solve
|
// Solve
|
||||||
lpack<T>::xgeqrf(
|
geqrf<T>(
|
||||||
&M,
|
&M,
|
||||||
&N,
|
&N,
|
||||||
in.data<float>() + M * N * i,
|
in.data<T>() + M * N * i,
|
||||||
&lda,
|
&lda,
|
||||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||||
static_cast<T*>(work.raw_ptr()),
|
static_cast<T*>(work.raw_ptr()),
|
||||||
@ -105,7 +74,7 @@ void qrf_impl(const array& a, array& q, array& r) {
|
|||||||
|
|
||||||
// Get work size
|
// Get work size
|
||||||
lwork = -1;
|
lwork = -1;
|
||||||
lpack<T>::xorgqr(
|
orgqr<T>(
|
||||||
&M,
|
&M,
|
||||||
&num_reflectors,
|
&num_reflectors,
|
||||||
&num_reflectors,
|
&num_reflectors,
|
||||||
@ -121,11 +90,11 @@ void qrf_impl(const array& a, array& q, array& r) {
|
|||||||
// Loop over matrices
|
// Loop over matrices
|
||||||
for (int i = 0; i < num_matrices; ++i) {
|
for (int i = 0; i < num_matrices; ++i) {
|
||||||
// Compute Q
|
// Compute Q
|
||||||
lpack<T>::xorgqr(
|
orgqr<T>(
|
||||||
&M,
|
&M,
|
||||||
&num_reflectors,
|
&num_reflectors,
|
||||||
&num_reflectors,
|
&num_reflectors,
|
||||||
in.data<float>() + M * N * i,
|
in.data<T>() + M * N * i,
|
||||||
&lda,
|
&lda,
|
||||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||||
static_cast<T*>(work.raw_ptr()),
|
static_cast<T*>(work.raw_ptr()),
|
||||||
@ -152,10 +121,17 @@ void qrf_impl(const array& a, array& q, array& r) {
|
|||||||
void QRF::eval_cpu(
|
void QRF::eval_cpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
if (!(inputs[0].dtype() == float32)) {
|
switch (inputs[0].dtype()) {
|
||||||
throw std::runtime_error("[QRF::eval] only supports float32.");
|
case float32:
|
||||||
}
|
|
||||||
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
|
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
qrf_impl<double>(inputs[0], outputs[0], outputs[1]);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[QRF::eval_cpu] only supports float32 or float64.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
void svd_impl(const array& a, array& u, array& s, array& vt) {
|
void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||||
// Lapack uses the column-major convention. To avoid having to transpose
|
// Lapack uses the column-major convention. To avoid having to transpose
|
||||||
// the input and then transpose the outputs, we swap the indices/sizes of the
|
// the input and then transpose the outputs, we swap the indices/sizes of the
|
||||||
@ -31,7 +32,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
|||||||
size_t num_matrices = a.size() / (M * N);
|
size_t num_matrices = a.size() / (M * N);
|
||||||
|
|
||||||
// lapack clobbers the input, so we have to make a copy.
|
// lapack clobbers the input, so we have to make a copy.
|
||||||
array in(a.shape(), float32, nullptr, {});
|
array in(a.shape(), a.dtype(), nullptr, {});
|
||||||
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||||
|
|
||||||
// Allocate outputs.
|
// Allocate outputs.
|
||||||
@ -45,7 +46,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
|||||||
|
|
||||||
// Will contain the number of singular values after the call has returned.
|
// Will contain the number of singular values after the call has returned.
|
||||||
int ns = 0;
|
int ns = 0;
|
||||||
float workspace_dimension = 0;
|
T workspace_dimension = 0;
|
||||||
|
|
||||||
// Will contain the indices of eigenvectors that failed to converge (not used
|
// Will contain the indices of eigenvectors that failed to converge (not used
|
||||||
// here but required by lapack).
|
// here but required by lapack).
|
||||||
@ -54,13 +55,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
|||||||
static const int lwork_query = -1;
|
static const int lwork_query = -1;
|
||||||
|
|
||||||
static const int ignored_int = 0;
|
static const int ignored_int = 0;
|
||||||
static const float ignored_float = 0;
|
static const T ignored_float = 0;
|
||||||
|
|
||||||
int info;
|
int info;
|
||||||
|
|
||||||
// Compute workspace size.
|
// Compute workspace size.
|
||||||
MLX_LAPACK_FUNC(sgesvdx)
|
gesvdx<T>(
|
||||||
(
|
|
||||||
/* jobu = */ job_u,
|
/* jobu = */ job_u,
|
||||||
/* jobvt = */ job_vt,
|
/* jobvt = */ job_vt,
|
||||||
/* range = */ range,
|
/* range = */ range,
|
||||||
@ -86,51 +86,50 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
|||||||
|
|
||||||
if (info != 0) {
|
if (info != 0) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
|
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
const int lwork = workspace_dimension;
|
const int lwork = workspace_dimension;
|
||||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||||
|
|
||||||
// Loop over matrices.
|
// Loop over matrices.
|
||||||
for (int i = 0; i < num_matrices; i++) {
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
MLX_LAPACK_FUNC(sgesvdx)
|
gesvdx<T>(
|
||||||
(
|
|
||||||
/* jobu = */ job_u,
|
/* jobu = */ job_u,
|
||||||
/* jobvt = */ job_vt,
|
/* jobvt = */ job_vt,
|
||||||
/* range = */ range,
|
/* range = */ range,
|
||||||
// M and N are swapped since lapack expects column-major.
|
// M and N are swapped since lapack expects column-major.
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* n = */ &M,
|
/* n = */ &M,
|
||||||
/* a = */ in.data<float>() + M * N * i,
|
/* a = */ in.data<T>() + M * N * i,
|
||||||
/* lda = */ &lda,
|
/* lda = */ &lda,
|
||||||
/* vl = */ &ignored_float,
|
/* vl = */ &ignored_float,
|
||||||
/* vu = */ &ignored_float,
|
/* vu = */ &ignored_float,
|
||||||
/* il = */ &ignored_int,
|
/* il = */ &ignored_int,
|
||||||
/* iu = */ &ignored_int,
|
/* iu = */ &ignored_int,
|
||||||
/* ns = */ &ns,
|
/* ns = */ &ns,
|
||||||
/* s = */ s.data<float>() + K * i,
|
/* s = */ s.data<T>() + K * i,
|
||||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||||
/* u = */ vt.data<float>() + N * N * i,
|
/* u = */ vt.data<T>() + N * N * i,
|
||||||
/* ldu = */ &ldu,
|
/* ldu = */ &ldu,
|
||||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||||
/* vt = */ u.data<float>() + M * M * i,
|
/* vt = */ u.data<T>() + M * M * i,
|
||||||
/* ldvt = */ &ldvt,
|
/* ldvt = */ &ldvt,
|
||||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
||||||
/* lwork = */ &lwork,
|
/* lwork = */ &lwork,
|
||||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
||||||
/* info = */ &info);
|
/* info = */ &info);
|
||||||
|
|
||||||
if (info != 0) {
|
if (info != 0) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
ss << "[SVD::eval_cpu] failed with code " << info;
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ns != K) {
|
if (ns != K) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
ss << "[SVD::eval_cpu] expected " << K << " singular values, but " << ns
|
||||||
<< " were computed.";
|
<< " were computed.";
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}
|
}
|
||||||
@ -140,10 +139,17 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
|||||||
void SVD::eval_cpu(
|
void SVD::eval_cpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
if (!(inputs[0].dtype() == float32)) {
|
switch (inputs[0].dtype()) {
|
||||||
throw std::runtime_error("[SVD::eval] only supports float32.");
|
case float32:
|
||||||
|
svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[SVD::eval_cpu] only supports float32 or float64.");
|
||||||
}
|
}
|
||||||
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -18,6 +18,14 @@ void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {
|
|||||||
"Explicitly pass a CPU stream to run it.");
|
"Explicitly pass a CPU stream to run it.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void check_float(Dtype dtype, const std::string& prefix) {
|
||||||
|
if (dtype != float32 && dtype != float64) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << prefix << " Arrays must have type float32 or float64. "
|
||||||
|
<< "Received array with type " << dtype << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Dtype at_least_float(const Dtype& d) {
|
Dtype at_least_float(const Dtype& d) {
|
||||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||||
@ -184,12 +192,8 @@ array norm(
|
|||||||
|
|
||||||
std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
check_cpu_stream(s, "[linalg::qr]");
|
check_cpu_stream(s, "[linalg::qr]");
|
||||||
if (a.dtype() != float32) {
|
check_float(a.dtype(), "[linalg::qr]");
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[linalg::qr] Arrays must type float32. Received array "
|
|
||||||
<< "with type " << a.dtype() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
|
msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
|
||||||
@ -212,12 +216,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
check_cpu_stream(s, "[linalg::svd]");
|
check_cpu_stream(s, "[linalg::svd]");
|
||||||
if (a.dtype() != float32) {
|
check_float(a.dtype(), "[linalg::svd]");
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[linalg::svd] Input array must have type float32. Received array "
|
|
||||||
<< "with type " << a.dtype() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
|
msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
|
||||||
@ -251,12 +251,8 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
|
array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
|
||||||
check_cpu_stream(s, "[linalg::inv]");
|
check_cpu_stream(s, "[linalg::inv]");
|
||||||
if (a.dtype() != float32) {
|
check_float(a.dtype(), "[linalg::inv]");
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[linalg::inv] Arrays must type float32. Received array "
|
|
||||||
<< "with type " << a.dtype() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
|
msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
|
||||||
@ -292,13 +288,7 @@ array cholesky(
|
|||||||
bool upper /* = false */,
|
bool upper /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
check_cpu_stream(s, "[linalg::cholesky]");
|
check_cpu_stream(s, "[linalg::cholesky]");
|
||||||
if (a.dtype() != float32) {
|
check_float(a.dtype(), "[linalg::cholesky]");
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[linalg::cholesky] Arrays must type float32. Received array "
|
|
||||||
<< "with type " << a.dtype() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
|
msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
|
||||||
@ -321,12 +311,8 @@ array cholesky(
|
|||||||
|
|
||||||
array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
check_cpu_stream(s, "[linalg::pinv]");
|
check_cpu_stream(s, "[linalg::pinv]");
|
||||||
if (a.dtype() != float32) {
|
check_float(a.dtype(), "[linalg::pinv]");
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[linalg::pinv] Arrays must type float32. Received array "
|
|
||||||
<< "with type " << a.dtype() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
|
msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
|
||||||
@ -368,12 +354,7 @@ array cholesky_inv(
|
|||||||
bool upper /* = false */,
|
bool upper /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
check_cpu_stream(s, "[linalg::cholesky_inv]");
|
check_cpu_stream(s, "[linalg::cholesky_inv]");
|
||||||
if (L.dtype() != float32) {
|
check_float(L.dtype(), "[linalg::cholesky_inv]");
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[linalg::cholesky_inv] Arrays must type float32. Received array "
|
|
||||||
<< "with type " << L.dtype() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (L.ndim() < 2) {
|
if (L.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -474,12 +455,7 @@ void validate_eigh(
|
|||||||
const StreamOrDevice& stream,
|
const StreamOrDevice& stream,
|
||||||
const std::string fname) {
|
const std::string fname) {
|
||||||
check_cpu_stream(stream, fname);
|
check_cpu_stream(stream, fname);
|
||||||
if (a.dtype() != float32) {
|
check_float(a.dtype(), fname);
|
||||||
std::ostringstream msg;
|
|
||||||
msg << fname << " Arrays must have type float32. Received array "
|
|
||||||
<< "with type " << a.dtype() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -524,12 +500,7 @@ void validate_lu(
|
|||||||
const StreamOrDevice& stream,
|
const StreamOrDevice& stream,
|
||||||
const std::string& fname) {
|
const std::string& fname) {
|
||||||
check_cpu_stream(stream, fname);
|
check_cpu_stream(stream, fname);
|
||||||
if (a.dtype() != float32) {
|
check_float(a.dtype(), fname);
|
||||||
std::ostringstream msg;
|
|
||||||
msg << fname << " Arrays must type float32. Received array "
|
|
||||||
<< "with type " << a.dtype() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -627,10 +598,12 @@ void validate_solve(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
if (out_type != float32) {
|
if (out_type != float32 && out_type != float64) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << fname << " Input arrays must promote to float32. Received arrays "
|
msg << fname
|
||||||
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
|
<< " Input arrays must promote to float32 or float64. "
|
||||||
|
" Received arrays with type "
|
||||||
|
<< a.dtype() << " and " << b.dtype() << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -44,6 +44,8 @@ std::string buffer_format(const mx::array& a) {
|
|||||||
return "f";
|
return "f";
|
||||||
case mx::bfloat16:
|
case mx::bfloat16:
|
||||||
return "B";
|
return "B";
|
||||||
|
case mx::float64:
|
||||||
|
return "d";
|
||||||
case mx::complex64:
|
case mx::complex64:
|
||||||
return "Zf\0";
|
return "Zf\0";
|
||||||
default: {
|
default: {
|
||||||
|
@ -152,6 +152,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
|
|||||||
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
|
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
|
||||||
case mx::float32:
|
case mx::float32:
|
||||||
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||||
|
case mx::float64:
|
||||||
|
return mlx_to_nd_array_impl<double, NDParams...>(a);
|
||||||
case mx::complex64:
|
case mx::complex64:
|
||||||
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
||||||
default:
|
default:
|
||||||
|
@ -183,6 +183,115 @@ class TestDouble(mlx_tests.MLXTestCase):
|
|||||||
c = a + b
|
c = a + b
|
||||||
self.assertEqual(c.dtype, mx.float64)
|
self.assertEqual(c.dtype, mx.float64)
|
||||||
|
|
||||||
|
def test_lapack(self):
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
# QRF
|
||||||
|
A = mx.array([[2.0, 3.0], [1.0, 2.0]], dtype=mx.float64)
|
||||||
|
Q, R = mx.linalg.qr(A)
|
||||||
|
out = Q @ R
|
||||||
|
self.assertTrue(mx.allclose(out, A))
|
||||||
|
out = Q.T @ Q
|
||||||
|
self.assertTrue(mx.allclose(out, mx.eye(2)))
|
||||||
|
self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))
|
||||||
|
self.assertEqual(Q.dtype, mx.float64)
|
||||||
|
self.assertEqual(R.dtype, mx.float64)
|
||||||
|
|
||||||
|
# SVD
|
||||||
|
A = mx.array(
|
||||||
|
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
|
||||||
|
)
|
||||||
|
U, S, Vt = mx.linalg.svd(A)
|
||||||
|
self.assertTrue(mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A))
|
||||||
|
|
||||||
|
# Inverse
|
||||||
|
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
|
||||||
|
A_inv = mx.linalg.inv(A)
|
||||||
|
self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0])))
|
||||||
|
|
||||||
|
# Tri inv
|
||||||
|
A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float64)
|
||||||
|
B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float64)
|
||||||
|
AB = mx.stack([A, B])
|
||||||
|
invs = mx.linalg.tri_inv(AB, upper=False)
|
||||||
|
for M, M_inv in zip(AB, invs):
|
||||||
|
self.assertTrue(mx.allclose(M @ M_inv, mx.eye(M.shape[0])))
|
||||||
|
|
||||||
|
# Cholesky
|
||||||
|
sqrtA = mx.array(
|
||||||
|
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float64
|
||||||
|
)
|
||||||
|
A = sqrtA.T @ sqrtA / 81
|
||||||
|
L = mx.linalg.cholesky(A)
|
||||||
|
U = mx.linalg.cholesky(A, upper=True)
|
||||||
|
self.assertTrue(mx.allclose(L @ L.T, A))
|
||||||
|
self.assertTrue(mx.allclose(U.T @ U, A))
|
||||||
|
|
||||||
|
# Psueod inverse
|
||||||
|
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
|
||||||
|
A_plus = mx.linalg.pinv(A)
|
||||||
|
self.assertTrue(mx.allclose(A @ A_plus @ A, A))
|
||||||
|
|
||||||
|
# Eigh
|
||||||
|
def check_eigs_and_vecs(A_np, kwargs={}):
|
||||||
|
A = mx.array(A_np, dtype=mx.float64)
|
||||||
|
eig_vals, eig_vecs = mx.linalg.eigh(A, **kwargs)
|
||||||
|
eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)
|
||||||
|
self.assertTrue(np.allclose(eig_vals, eig_vals_np))
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs)
|
||||||
|
)
|
||||||
|
|
||||||
|
eig_vals_only = mx.linalg.eigvalsh(A, **kwargs)
|
||||||
|
self.assertTrue(mx.allclose(eig_vals, eig_vals_only))
|
||||||
|
|
||||||
|
# Test a simple 2x2 symmetric matrix
|
||||||
|
A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float64)
|
||||||
|
check_eigs_and_vecs(A_np)
|
||||||
|
|
||||||
|
# Test a larger random symmetric matrix
|
||||||
|
n = 5
|
||||||
|
np.random.seed(1)
|
||||||
|
A_np = np.random.randn(n, n).astype(np.float64)
|
||||||
|
A_np = (A_np + A_np.T) / 2
|
||||||
|
check_eigs_and_vecs(A_np)
|
||||||
|
|
||||||
|
# Test with upper triangle
|
||||||
|
check_eigs_and_vecs(A_np, {"UPLO": "U"})
|
||||||
|
|
||||||
|
# LU factorization
|
||||||
|
# Test 3x3 matrix
|
||||||
|
a = mx.array(
|
||||||
|
[[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]], dtype=mx.float64
|
||||||
|
)
|
||||||
|
P, L, U = mx.linalg.lu(a)
|
||||||
|
self.assertTrue(mx.allclose(L[P, :] @ U, a))
|
||||||
|
|
||||||
|
# Solve triangular
|
||||||
|
# Test lower triangular matrix
|
||||||
|
a = mx.array(
|
||||||
|
[[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]], dtype=mx.float64
|
||||||
|
)
|
||||||
|
b = mx.array([8.0, 14.0, 3.0], dtype=mx.float64)
|
||||||
|
|
||||||
|
result = mx.linalg.solve_triangular(a, b, upper=False)
|
||||||
|
expected = np.linalg.solve(np.array(a), np.array(b))
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
# Test upper triangular matrix
|
||||||
|
a = mx.array(
|
||||||
|
[[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]], dtype=mx.float64
|
||||||
|
)
|
||||||
|
b = mx.array([13.0, 33.0, 18.0], dtype=mx.float64)
|
||||||
|
|
||||||
|
result = mx.linalg.solve_triangular(a, b, upper=True)
|
||||||
|
expected = np.linalg.solve(np.array(a), np.array(b))
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
def test_conversion(self):
|
||||||
|
a = mx.array([1.0, 2.0], mx.float64)
|
||||||
|
b = np.array(a)
|
||||||
|
self.assertTrue(np.array_equal(a, b))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user