mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add float64 Eig and complex64 SVD/Eig support (Fixes #2708) (#2737)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -12,6 +12,167 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
complex64_t to_complex(T r, T i) {
|
||||
return {static_cast<float>(r), static_cast<float>(i)};
|
||||
}
|
||||
|
||||
template <typename T, class Enable = void>
|
||||
struct EigWork {};
|
||||
|
||||
template <typename T>
|
||||
struct EigWork<
|
||||
T,
|
||||
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||
using O = complex64_t;
|
||||
|
||||
char jobl;
|
||||
char jobr;
|
||||
int N;
|
||||
int lwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
|
||||
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
|
||||
T work;
|
||||
int n_vecs_l = compute_eigenvectors ? N_ : 1;
|
||||
int n_vecs_r = 1;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
&work,
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
|
||||
if (compute_eigenvectors) {
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
|
||||
}
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
}
|
||||
|
||||
void run(T* a, O* values, O* vectors) {
|
||||
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
|
||||
T* vec_tmp = nullptr;
|
||||
if (vectors) {
|
||||
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
|
||||
}
|
||||
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
|
||||
|
||||
int n_vecs_l = vectors ? N : 1;
|
||||
int n_vecs_r = 1;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
a,
|
||||
&N,
|
||||
eig_tmp,
|
||||
eig_tmp + N,
|
||||
vectors ? vec_tmp : nullptr,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
work,
|
||||
&lwork,
|
||||
&info);
|
||||
|
||||
for (int i = 0; i < N; ++i) {
|
||||
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
|
||||
}
|
||||
|
||||
if (vectors) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
if (values[i].imag() != 0) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vectors[i * N + j] =
|
||||
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
|
||||
vectors[(i + 1) * N + j] =
|
||||
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
|
||||
}
|
||||
i += 1;
|
||||
} else {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct EigWork<std::complex<float>> {
|
||||
using T = std::complex<float>;
|
||||
using R = float;
|
||||
using O = T;
|
||||
|
||||
char jobl;
|
||||
char jobr;
|
||||
int N;
|
||||
int lwork;
|
||||
int lrwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
|
||||
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
|
||||
T work;
|
||||
R rwork;
|
||||
int n_vecs_l = compute_eigenvectors ? N_ : 1;
|
||||
int n_vecs_r = 1;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
nullptr,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
&work,
|
||||
&lwork,
|
||||
&rwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work.real());
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||
}
|
||||
|
||||
void run(T* a, T* values, T* vectors) {
|
||||
int n_vecs_l = vectors ? N : 1;
|
||||
int n_vecs_r = 1;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
a,
|
||||
&N,
|
||||
values,
|
||||
vectors,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||
&info);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void eig_impl(
|
||||
array& a,
|
||||
@@ -19,101 +180,39 @@ void eig_impl(
|
||||
array& values,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using OT = std::complex<T>;
|
||||
auto a_ptr = a.data<T>();
|
||||
auto eig_ptr = values.data<OT>();
|
||||
auto val_ptr = values.data<complex64_t>();
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(values);
|
||||
OT* vec_ptr = nullptr;
|
||||
complex64_t* vec_ptr = nullptr;
|
||||
if (compute_eigenvectors) {
|
||||
encoder.set_output_array(vectors);
|
||||
vec_ptr = vectors.data<OT>();
|
||||
vec_ptr = vectors.data<complex64_t>();
|
||||
}
|
||||
encoder.dispatch([a_ptr,
|
||||
val_ptr,
|
||||
vec_ptr,
|
||||
eig_ptr,
|
||||
compute_eigenvectors,
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
char jobr = 'N';
|
||||
char jobl = compute_eigenvectors ? 'V' : 'N';
|
||||
int n_vecs_r = 1;
|
||||
int n_vecs_l = compute_eigenvectors ? N : 1;
|
||||
int lwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
&work,
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
}
|
||||
|
||||
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
|
||||
auto vec_tmp_data =
|
||||
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
|
||||
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
||||
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
|
||||
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
a_ptr,
|
||||
&N,
|
||||
eig_tmp,
|
||||
eig_tmp + N,
|
||||
vec_tmp,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
|
||||
}
|
||||
work.run(a_ptr, val_ptr, vec_ptr);
|
||||
a_ptr += N * N;
|
||||
val_ptr += N;
|
||||
if (vec_ptr) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
if (eig_ptr[i].imag() != 0) {
|
||||
// This vector and the next are a pair
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {
|
||||
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
|
||||
vec_ptr[(i + 1) * N + j] = {
|
||||
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
|
||||
}
|
||||
i += 1;
|
||||
} else {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
|
||||
}
|
||||
}
|
||||
}
|
||||
vec_ptr += N * N;
|
||||
}
|
||||
a_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
if (work.info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
<< work.info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
|
||||
case float32:
|
||||
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||
break;
|
||||
case float64:
|
||||
eig_impl<double>(
|
||||
a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
eig_impl<std::complex<float>>(
|
||||
a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
|
||||
throw std::runtime_error(
|
||||
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -45,9 +45,7 @@
|
||||
INSTANTIATE_LAPACK_REAL(geqrf)
|
||||
INSTANTIATE_LAPACK_REAL(orgqr)
|
||||
INSTANTIATE_LAPACK_REAL(syevd)
|
||||
INSTANTIATE_LAPACK_REAL(geev)
|
||||
INSTANTIATE_LAPACK_REAL(potrf)
|
||||
INSTANTIATE_LAPACK_REAL(gesdd)
|
||||
INSTANTIATE_LAPACK_REAL(getrf)
|
||||
INSTANTIATE_LAPACK_REAL(getri)
|
||||
INSTANTIATE_LAPACK_REAL(trtri)
|
||||
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
||||
|
||||
#define INSTANTIATE_LAPACK_ALL(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, float>) { \
|
||||
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
|
||||
} else if constexpr (std::is_same_v<T, double>) { \
|
||||
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
|
||||
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
|
||||
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
|
||||
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
|
||||
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_ALL(geev)
|
||||
INSTANTIATE_LAPACK_ALL(gesdd)
|
||||
|
||||
@@ -8,6 +8,183 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T, class Enable = void>
|
||||
struct SVDWork {};
|
||||
|
||||
template <typename T>
|
||||
struct SVDWork<
|
||||
T,
|
||||
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||
using R = T;
|
||||
|
||||
int N;
|
||||
int M;
|
||||
int K;
|
||||
int lda;
|
||||
int ldu;
|
||||
int ldvt;
|
||||
char jobz;
|
||||
std::vector<array::Data> buffers;
|
||||
int lwork;
|
||||
|
||||
SVDWork(int N, int M, int K, char jobz)
|
||||
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
|
||||
T workspace_dimension = 0;
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
|
||||
|
||||
int lwork_query = -1;
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
gesdd<T>(
|
||||
/* jobz = */ &jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
/* vt = */ nullptr,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ &workspace_dimension,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
lwork = workspace_dimension;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
}
|
||||
|
||||
void run(T* a, R* s, T* u, T* vt) {
|
||||
int info;
|
||||
gesdd<T>(
|
||||
/* jobz = */ &jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ a,
|
||||
/* lda = */ &lda,
|
||||
/* s = */ s,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ u,
|
||||
/* ldu = */ &ldu,
|
||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||
/* vt = */ vt,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SVDWork<std::complex<float>> {
|
||||
using T = std::complex<float>;
|
||||
using R = float;
|
||||
|
||||
int N;
|
||||
int M;
|
||||
int K;
|
||||
int lda;
|
||||
int ldu;
|
||||
int ldvt;
|
||||
char jobz;
|
||||
std::vector<array::Data> buffers;
|
||||
int lwork;
|
||||
|
||||
SVDWork(int N, int M, int K, char jobz)
|
||||
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
|
||||
T workspace_dimension = 0;
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
|
||||
|
||||
const int lrwork =
|
||||
jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
|
||||
buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
|
||||
|
||||
int lwork_query = -1;
|
||||
int work_query = -1;
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
gesdd<T>(
|
||||
/* jobz = */ &jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
/* vt = */ nullptr,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ &workspace_dimension,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
|
||||
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
lwork = workspace_dimension.real();
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
}
|
||||
|
||||
void run(T* a, R* s, T* u, T* vt) {
|
||||
int info;
|
||||
gesdd<T>(
|
||||
/* jobz = */ &jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ a,
|
||||
/* lda = */ &lda,
|
||||
/* s = */ s,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ u,
|
||||
/* ldu = */ &ldu,
|
||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||
/* vt = */ vt,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
|
||||
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void svd_impl(
|
||||
const array& a,
|
||||
@@ -27,6 +204,8 @@ void svd_impl(
|
||||
const int N = a.shape(-1);
|
||||
const int K = std::min(M, N);
|
||||
|
||||
using R = typename SVDWork<T>::R;
|
||||
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// lapack clobbers the input, so we have to make a copy.
|
||||
@@ -42,7 +221,7 @@ void svd_impl(
|
||||
encoder.set_input_array(a);
|
||||
auto in_ptr = in.data<T>();
|
||||
T* u_ptr;
|
||||
T* s_ptr;
|
||||
R* s_ptr;
|
||||
T* vt_ptr;
|
||||
|
||||
if (compute_uv) {
|
||||
@@ -58,7 +237,7 @@ void svd_impl(
|
||||
encoder.set_output_array(s);
|
||||
encoder.set_output_array(vt);
|
||||
|
||||
s_ptr = s.data<T>();
|
||||
s_ptr = s.data<R>();
|
||||
u_ptr = u.data<T>();
|
||||
vt_ptr = vt.data<T>();
|
||||
} else {
|
||||
@@ -68,96 +247,26 @@ void svd_impl(
|
||||
|
||||
encoder.set_output_array(s);
|
||||
|
||||
s_ptr = s.data<T>();
|
||||
s_ptr = s.data<R>();
|
||||
u_ptr = nullptr;
|
||||
vt_ptr = nullptr;
|
||||
}
|
||||
|
||||
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
|
||||
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
|
||||
const int lda = N;
|
||||
// U of shape M x M. (N x N in lapack).
|
||||
const int ldu = N;
|
||||
// Vᵀ of shape N x N. (M x M in lapack).
|
||||
const int ldvt = M;
|
||||
|
||||
auto jobz = (u_ptr) ? "A" : "N";
|
||||
|
||||
T workspace_dimension = 0;
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
gesdd<T>(
|
||||
/* jobz = */ jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
/* vt = */ nullptr,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ &workspace_dimension,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_dimension;
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
|
||||
auto jobz = (u_ptr) ? 'A' : 'N';
|
||||
SVDWork<T> svd_work(N, M, K, jobz);
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
gesdd<T>(
|
||||
/* jobz = */ jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ in_ptr + M * N * i,
|
||||
/* lda = */ &lda,
|
||||
/* s = */ s_ptr + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
svd_work.run(
|
||||
in_ptr + M * N * i,
|
||||
s_ptr + K * i,
|
||||
vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||
u_ptr ? u_ptr + M * M * i : nullptr);
|
||||
}
|
||||
});
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void compute_svd(
|
||||
const array& a,
|
||||
bool compute_uv,
|
||||
std::vector<array>& outputs,
|
||||
Stream stream) {}
|
||||
|
||||
void SVD::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
|
||||
case float64:
|
||||
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[SVD::eval_cpu] only supports float32 or float64.");
|
||||
"[SVD::eval_cpu] only supports float32, float64, or complex64.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -250,7 +250,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
std::vector<array>
|
||||
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::svd]");
|
||||
check_float(a.dtype(), "[linalg::svd]");
|
||||
check_float_or_complex(a.dtype(), "[linalg::svd]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@@ -268,10 +268,12 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
||||
s_shape.pop_back();
|
||||
s_shape[rank - 2] = std::min(m, n);
|
||||
|
||||
auto s_dtype = a.dtype() == complex64 ? float32 : a.dtype();
|
||||
|
||||
if (!compute_uv) {
|
||||
return {array(
|
||||
std::move(s_shape),
|
||||
a.dtype(),
|
||||
s_dtype,
|
||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||
{a})};
|
||||
}
|
||||
@@ -286,7 +288,7 @@ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
||||
|
||||
return array::make_arrays(
|
||||
{u_shape, s_shape, vt_shape},
|
||||
{a.dtype(), a.dtype(), a.dtype()},
|
||||
{a.dtype(), s_dtype, a.dtype()},
|
||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||
{a});
|
||||
}
|
||||
|
||||
@@ -168,6 +168,42 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
# Test float64 - use CPU stream since float64 is not supported on GPU
|
||||
with mx.stream(mx.cpu):
|
||||
A_f64 = mx.array(
|
||||
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
|
||||
)
|
||||
U_f64, S_f64, Vt_f64 = mx.linalg.svd(A_f64, compute_uv=True)
|
||||
mx.eval(U_f64, S_f64, Vt_f64)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
U_f64[:, : len(S_f64)] @ mx.diag(S_f64) @ Vt_f64,
|
||||
A_f64,
|
||||
rtol=1e-5,
|
||||
atol=1e-7,
|
||||
)
|
||||
)
|
||||
self.assertEqual(S_f64.dtype, mx.float64)
|
||||
|
||||
# Test complex64 - use CPU stream since complex64 is not supported on GPU
|
||||
with mx.stream(mx.cpu):
|
||||
A_c64 = mx.array(
|
||||
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=mx.complex64
|
||||
)
|
||||
U_c64, S_c64, Vt_c64 = mx.linalg.svd(A_c64, compute_uv=True)
|
||||
mx.eval(U_c64, S_c64, Vt_c64)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
U_c64[:, : len(S_c64)] @ mx.diag(S_c64) @ Vt_c64,
|
||||
A_c64,
|
||||
rtol=1e-5,
|
||||
atol=1e-7,
|
||||
)
|
||||
)
|
||||
self.assertEqual(S_c64.dtype, mx.float32)
|
||||
self.assertEqual(U_c64.dtype, mx.complex64)
|
||||
self.assertEqual(Vt_c64.dtype, mx.complex64)
|
||||
|
||||
def test_inverse(self):
|
||||
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
|
||||
A_inv = mx.linalg.inv(A, stream=mx.cpu)
|
||||
@@ -342,6 +378,43 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
A_np = np.random.randn(3, n, n).astype(np.float32)
|
||||
check_eigs_and_vecs(A_np)
|
||||
|
||||
# Test float64 - use CPU stream since float64 is not supported on GPU
|
||||
with mx.stream(mx.cpu):
|
||||
A_np_f64 = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float64)
|
||||
A_f64 = mx.array(A_np_f64, dtype=mx.float64)
|
||||
eig_vals_f64, eig_vecs_f64 = mx.linalg.eig(A_f64)
|
||||
mx.eval(eig_vals_f64, eig_vecs_f64)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
A_f64 @ eig_vecs_f64,
|
||||
eig_vals_f64[..., None, :] * eig_vecs_f64,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
)
|
||||
)
|
||||
# Eigenvalues should be complex64 (output dtype)
|
||||
self.assertEqual(eig_vals_f64.dtype, mx.complex64)
|
||||
self.assertEqual(eig_vecs_f64.dtype, mx.complex64)
|
||||
|
||||
# Test complex64 input - use CPU stream since complex64 is not supported on GPU
|
||||
with mx.stream(mx.cpu):
|
||||
A_np_c64 = np.array(
|
||||
[[1.0 + 1j, 2.0 + 2j], [3.0 + 3j, 4.0 + 4j]], dtype=np.complex64
|
||||
)
|
||||
A_c64 = mx.array(A_np_c64, dtype=mx.complex64)
|
||||
eig_vals_c64, eig_vecs_c64 = mx.linalg.eig(A_c64)
|
||||
mx.eval(eig_vals_c64, eig_vecs_c64)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
A_c64 @ eig_vecs_c64,
|
||||
eig_vals_c64[..., None, :] * eig_vecs_c64,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
)
|
||||
)
|
||||
self.assertEqual(eig_vals_c64.dtype, mx.complex64)
|
||||
self.assertEqual(eig_vecs_c64.dtype, mx.complex64)
|
||||
|
||||
# Test error cases
|
||||
with self.assertRaises(ValueError):
|
||||
mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array
|
||||
|
||||
Reference in New Issue
Block a user