mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add complex eigh (#2191)
This commit is contained in:
parent
48ef3e74e2
commit
0654543dcc
@ -224,6 +224,10 @@ class array {
|
|||||||
// Not copyable
|
// Not copyable
|
||||||
Data(const Data& d) = delete;
|
Data(const Data& d) = delete;
|
||||||
Data& operator=(const Data& d) = delete;
|
Data& operator=(const Data& d) = delete;
|
||||||
|
Data(Data&& o) : buffer(o.buffer), d(o.d) {
|
||||||
|
o.buffer = allocator::Buffer(nullptr);
|
||||||
|
o.d = [](allocator::Buffer) {};
|
||||||
|
}
|
||||||
~Data() {
|
~Data() {
|
||||||
d(buffer);
|
d(buffer);
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,133 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
template <typename T, class Enable = void>
|
||||||
|
struct EighWork {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct EighWork<
|
||||||
|
T,
|
||||||
|
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||||
|
using R = T;
|
||||||
|
|
||||||
|
char jobz;
|
||||||
|
char uplo;
|
||||||
|
int N;
|
||||||
|
int lwork;
|
||||||
|
int liwork;
|
||||||
|
int info;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
|
||||||
|
EighWork(char jobz_, char uplo_, int N_)
|
||||||
|
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
|
||||||
|
T work;
|
||||||
|
int iwork;
|
||||||
|
syevd<T>(
|
||||||
|
&jobz,
|
||||||
|
&uplo,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&work,
|
||||||
|
&lwork,
|
||||||
|
&iwork,
|
||||||
|
&liwork,
|
||||||
|
&info);
|
||||||
|
lwork = static_cast<int>(work);
|
||||||
|
liwork = iwork;
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* vectors, T* values) {
|
||||||
|
syevd<T>(
|
||||||
|
&jobz,
|
||||||
|
&uplo,
|
||||||
|
&N,
|
||||||
|
vectors,
|
||||||
|
&N,
|
||||||
|
values,
|
||||||
|
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
&lwork,
|
||||||
|
static_cast<int*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
&liwork,
|
||||||
|
&info);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct EighWork<std::complex<float>> {
|
||||||
|
using T = std::complex<float>;
|
||||||
|
using R = float;
|
||||||
|
|
||||||
|
char jobz;
|
||||||
|
char uplo;
|
||||||
|
int N;
|
||||||
|
int lwork;
|
||||||
|
int lrwork;
|
||||||
|
int liwork;
|
||||||
|
int info;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
|
||||||
|
EighWork(char jobz_, char uplo_, int N_)
|
||||||
|
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
|
||||||
|
T work;
|
||||||
|
R rwork;
|
||||||
|
int iwork;
|
||||||
|
heevd<T>(
|
||||||
|
&jobz,
|
||||||
|
&uplo,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&work,
|
||||||
|
&lwork,
|
||||||
|
&rwork,
|
||||||
|
&lrwork,
|
||||||
|
&iwork,
|
||||||
|
&liwork,
|
||||||
|
&info);
|
||||||
|
lwork = static_cast<int>(work.real());
|
||||||
|
lrwork = static_cast<int>(rwork);
|
||||||
|
liwork = iwork;
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* vectors, R* values) {
|
||||||
|
heevd<T>(
|
||||||
|
&jobz,
|
||||||
|
&uplo,
|
||||||
|
&N,
|
||||||
|
vectors,
|
||||||
|
&N,
|
||||||
|
values,
|
||||||
|
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
&lwork,
|
||||||
|
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
&lrwork,
|
||||||
|
static_cast<int*>(buffers[2].buffer.raw_ptr()),
|
||||||
|
&liwork,
|
||||||
|
&info);
|
||||||
|
if (jobz == 'V') {
|
||||||
|
// We have pre-transposed the vectors but we also must conjugate them
|
||||||
|
// when they are complex.
|
||||||
|
//
|
||||||
|
// We could vectorize this but it is so fast in comparison to heevd that
|
||||||
|
// it doesn't really matter.
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
for (int j = 0; j < N; j++) {
|
||||||
|
*vectors = std::conj(*vectors);
|
||||||
|
vectors++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void eigh_impl(
|
void eigh_impl(
|
||||||
array& vectors,
|
array& vectors,
|
||||||
@ -19,8 +146,10 @@ void eigh_impl(
|
|||||||
const std::string& uplo,
|
const std::string& uplo,
|
||||||
bool compute_eigenvectors,
|
bool compute_eigenvectors,
|
||||||
Stream stream) {
|
Stream stream) {
|
||||||
|
using R = typename EighWork<T>::R;
|
||||||
|
|
||||||
auto vec_ptr = vectors.data<T>();
|
auto vec_ptr = vectors.data<T>();
|
||||||
auto eig_ptr = values.data<T>();
|
auto eig_ptr = values.data<R>();
|
||||||
char jobz = compute_eigenvectors ? 'V' : 'N';
|
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream);
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
@ -33,49 +162,17 @@ void eigh_impl(
|
|||||||
N = vectors.shape(-1),
|
N = vectors.shape(-1),
|
||||||
size = vectors.size()]() mutable {
|
size = vectors.size()]() mutable {
|
||||||
// Work query
|
// Work query
|
||||||
int lwork = -1;
|
EighWork<T> work(jobz, uplo, N);
|
||||||
int liwork = -1;
|
|
||||||
int info;
|
|
||||||
{
|
|
||||||
T work;
|
|
||||||
int iwork;
|
|
||||||
syevd<T>(
|
|
||||||
&jobz,
|
|
||||||
&uplo,
|
|
||||||
&N,
|
|
||||||
nullptr,
|
|
||||||
&N,
|
|
||||||
nullptr,
|
|
||||||
&work,
|
|
||||||
&lwork,
|
|
||||||
&iwork,
|
|
||||||
&liwork,
|
|
||||||
&info);
|
|
||||||
lwork = static_cast<int>(work);
|
|
||||||
liwork = iwork;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
// Work loop
|
||||||
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
|
|
||||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||||
syevd<T>(
|
work.run(vec_ptr, eig_ptr);
|
||||||
&jobz,
|
|
||||||
&uplo,
|
|
||||||
&N,
|
|
||||||
vec_ptr,
|
|
||||||
&N,
|
|
||||||
eig_ptr,
|
|
||||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
|
||||||
&lwork,
|
|
||||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
|
||||||
&liwork,
|
|
||||||
&info);
|
|
||||||
vec_ptr += N * N;
|
vec_ptr += N * N;
|
||||||
eig_ptr += N;
|
eig_ptr += N;
|
||||||
if (info != 0) {
|
if (work.info != 0) {
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||||
<< info;
|
<< work.info;
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -131,6 +228,10 @@ void Eigh::eval_cpu(
|
|||||||
eigh_impl<double>(
|
eigh_impl<double>(
|
||||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||||
break;
|
break;
|
||||||
|
case complex64:
|
||||||
|
eigh_impl<std::complex<float>>(
|
||||||
|
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[Eigh::eval_cpu] only supports float32 or float64.");
|
"[Eigh::eval_cpu] only supports float32 or float64.");
|
||||||
|
@ -2,14 +2,14 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
// Required for Visual Studio.
|
|
||||||
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
|
|
||||||
#ifdef _MSC_VER
|
|
||||||
#include <complex>
|
#include <complex>
|
||||||
#define LAPACK_COMPLEX_CUSTOM
|
#define LAPACK_COMPLEX_CUSTOM
|
||||||
#define lapack_complex_float std::complex<float>
|
#define lapack_complex_float std::complex<float>
|
||||||
#define lapack_complex_double std::complex<double>
|
#define lapack_complex_double std::complex<double>
|
||||||
#endif
|
#define lapack_complex_float_real(z) ((z).real())
|
||||||
|
#define lapack_complex_float_imag(z) ((z).imag())
|
||||||
|
#define lapack_complex_double_real(z) ((z).real())
|
||||||
|
#define lapack_complex_double_imag(z) ((z).imag())
|
||||||
|
|
||||||
#ifdef MLX_USE_ACCELERATE
|
#ifdef MLX_USE_ACCELERATE
|
||||||
#include <Accelerate/Accelerate.h>
|
#include <Accelerate/Accelerate.h>
|
||||||
@ -32,7 +32,7 @@
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
|
#define INSTANTIATE_LAPACK_REAL(FUNC) \
|
||||||
template <typename T, typename... Args> \
|
template <typename T, typename... Args> \
|
||||||
void FUNC(Args... args) { \
|
void FUNC(Args... args) { \
|
||||||
if constexpr (std::is_same_v<T, float>) { \
|
if constexpr (std::is_same_v<T, float>) { \
|
||||||
@ -42,12 +42,24 @@
|
|||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_LAPACK_TYPES(geqrf)
|
INSTANTIATE_LAPACK_REAL(geqrf)
|
||||||
INSTANTIATE_LAPACK_TYPES(orgqr)
|
INSTANTIATE_LAPACK_REAL(orgqr)
|
||||||
INSTANTIATE_LAPACK_TYPES(syevd)
|
INSTANTIATE_LAPACK_REAL(syevd)
|
||||||
INSTANTIATE_LAPACK_TYPES(geev)
|
INSTANTIATE_LAPACK_REAL(geev)
|
||||||
INSTANTIATE_LAPACK_TYPES(potrf)
|
INSTANTIATE_LAPACK_REAL(potrf)
|
||||||
INSTANTIATE_LAPACK_TYPES(gesvdx)
|
INSTANTIATE_LAPACK_REAL(gesvdx)
|
||||||
INSTANTIATE_LAPACK_TYPES(getrf)
|
INSTANTIATE_LAPACK_REAL(getrf)
|
||||||
INSTANTIATE_LAPACK_TYPES(getri)
|
INSTANTIATE_LAPACK_REAL(getri)
|
||||||
INSTANTIATE_LAPACK_TYPES(trtri)
|
INSTANTIATE_LAPACK_REAL(trtri)
|
||||||
|
|
||||||
|
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
|
||||||
|
template <typename T, typename... Args> \
|
||||||
|
void FUNC(Args... args) { \
|
||||||
|
if constexpr (std::is_same_v<T, std::complex<float>>) { \
|
||||||
|
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
|
||||||
|
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
||||||
|
@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void check_float_or_complex(Dtype dtype, const std::string& prefix) {
|
||||||
|
if (dtype != float32 && dtype != float64 && dtype != complex64) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << prefix << " Arrays must have type float32, float64 or complex64. "
|
||||||
|
<< "Received array with type " << dtype << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Dtype at_least_float(const Dtype& d) {
|
Dtype at_least_float(const Dtype& d) {
|
||||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||||
}
|
}
|
||||||
@ -493,7 +502,7 @@ void validate_eig(
|
|||||||
const StreamOrDevice& stream,
|
const StreamOrDevice& stream,
|
||||||
const std::string fname) {
|
const std::string fname) {
|
||||||
check_cpu_stream(stream, fname);
|
check_cpu_stream(stream, fname);
|
||||||
check_float(a.dtype(), fname);
|
check_float_or_complex(a.dtype(), fname);
|
||||||
|
|
||||||
if (a.ndim() < 2) {
|
if (a.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -513,9 +522,10 @@ array eigvalsh(
|
|||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
validate_eig(a, s, "[linalg::eigvalsh]");
|
validate_eig(a, s, "[linalg::eigvalsh]");
|
||||||
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
|
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
|
||||||
|
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
|
||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
a.dtype(),
|
eigval_type,
|
||||||
std::make_shared<Eigh>(to_stream(s), UPLO, false),
|
std::make_shared<Eigh>(to_stream(s), UPLO, false),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
@ -525,9 +535,10 @@ std::pair<array, array> eigh(
|
|||||||
std::string UPLO /* = "L" */,
|
std::string UPLO /* = "L" */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
validate_eig(a, s, "[linalg::eigh]");
|
validate_eig(a, s, "[linalg::eigh]");
|
||||||
|
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
|
||||||
auto out = array::make_arrays(
|
auto out = array::make_arrays(
|
||||||
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
||||||
{a.dtype(), a.dtype()},
|
{eigval_type, a.dtype()},
|
||||||
std::make_shared<Eigh>(to_stream(s), UPLO, true),
|
std::make_shared<Eigh>(to_stream(s), UPLO, true),
|
||||||
{a});
|
{a});
|
||||||
return std::make_pair(out[0], out[1]);
|
return std::make_pair(out[0], out[1]);
|
||||||
|
@ -423,6 +423,13 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2
|
A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2
|
||||||
check_eigs_and_vecs(A_np)
|
check_eigs_and_vecs(A_np)
|
||||||
|
|
||||||
|
# Test with complex inputs
|
||||||
|
A_np = (
|
||||||
|
np.random.randn(8, 8, 2).astype(np.float32).view(np.complex64).squeeze(-1)
|
||||||
|
)
|
||||||
|
A_np = A_np + A_np.T.conj()
|
||||||
|
check_eigs_and_vecs(A_np)
|
||||||
|
|
||||||
# Test error cases
|
# Test error cases
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array
|
mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array
|
||||||
|
Loading…
Reference in New Issue
Block a user