Add complex eigh (#2191)

This commit is contained in:
Angelos Katharopoulos 2025-05-18 00:18:43 -07:00 committed by GitHub
parent 48ef3e74e2
commit 0654543dcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 190 additions and 55 deletions

View File

@ -224,6 +224,10 @@ class array {
// Not copyable // Not copyable
Data(const Data& d) = delete; Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete; Data& operator=(const Data& d) = delete;
// Move constructor: takes over the source's buffer and deleter, then resets
// the source so that its destructor invokes a no-op deleter on a null buffer.
Data(Data&& o) : buffer(o.buffer), d(o.d) {
  // Neutralize the moved-from object; the two resets are independent.
  o.d = [](allocator::Buffer) {};
  o.buffer = allocator::Buffer(nullptr);
}
~Data() { ~Data() {
d(buffer); d(buffer);
} }

View File

@ -12,31 +12,25 @@ namespace mlx::core {
namespace { namespace {
template <typename T> template <typename T, class Enable = void>
void eigh_impl( struct EighWork {};
array& vectors,
array& values,
const std::string& uplo,
bool compute_eigenvectors,
Stream stream) {
auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<T>();
char jobz = compute_eigenvectors ? 'V' : 'N';
auto& encoder = cpu::get_command_encoder(stream); template <typename T>
encoder.set_output_array(vectors); struct EighWork<
encoder.set_output_array(values); T,
encoder.dispatch([vec_ptr, typename std::enable_if<std::is_floating_point<T>::value>::type> {
eig_ptr, using R = T;
jobz,
uplo = uplo[0], char jobz;
N = vectors.shape(-1), char uplo;
size = vectors.size()]() mutable { int N;
// Work query int lwork;
int lwork = -1; int liwork;
int liwork = -1;
int info; int info;
{ std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
T work; T work;
int iwork; int iwork;
syevd<T>( syevd<T>(
@ -53,29 +47,132 @@ void eigh_impl(
&info); &info);
lwork = static_cast<int>(work); lwork = static_cast<int>(work);
liwork = iwork; liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
} }
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; void run(T* vectors, T* values) {
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>( syevd<T>(
&jobz, &jobz,
&uplo, &uplo,
&N, &N,
vec_ptr, vectors,
&N, &N,
eig_ptr, values,
static_cast<T*>(work_buf.buffer.raw_ptr()), static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork, &lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()), static_cast<int*>(buffers[1].buffer.raw_ptr()),
&liwork, &liwork,
&info); &info);
}
};
// Workspace holder for the Hermitian eigen-decomposition of complex64
// matrices via LAPACK's heevd.
//
// The constructor issues a workspace-size query (all work lengths set to -1)
// and allocates the three scratch buffers once; run() can then be invoked
// repeatedly, one N x N matrix at a time, reusing those buffers. The LAPACK
// status of the most recent call is left in `info` for the caller to inspect.
template <>
struct EighWork<std::complex<float>> {
  using T = std::complex<float>; // matrix / eigenvector element type
  using R = float; // eigenvalue element type

  char jobz; // 'V' to compute eigenvectors, otherwise eigenvalues only
  char uplo; // which triangle of the input is referenced
  int N; // matrix dimension
  int lwork; // length of the complex work buffer (buffers[0])
  int lrwork; // length of the real work buffer (buffers[1])
  int liwork; // length of the integer work buffer (buffers[2])
  int info; // status written by the last heevd call
  std::vector<array::Data> buffers;

  EighWork(char jobz_, char uplo_, int N_)
      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
    // Size query: heevd reports the required lengths through the first
    // element of each work array, so single scalars are enough here.
    T work_query;
    R rwork_query;
    int iwork_query;
    heevd<T>(
        &jobz,
        &uplo,
        &N,
        nullptr,
        &N,
        nullptr,
        &work_query,
        &lwork,
        &rwork_query,
        &lrwork,
        &iwork_query,
        &liwork,
        &info);

    lwork = static_cast<int>(work_query.real());
    lrwork = static_cast<int>(rwork_query);
    liwork = iwork_query;

    // Order matters: run() indexes these as 0 = work, 1 = rwork, 2 = iwork.
    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
  }

  // Decompose one N x N matrix in place. On return `vectors` holds the
  // eigenvectors (when jobz == 'V') and `values` the N real eigenvalues.
  void run(T* vectors, R* values) {
    T* work_ptr = static_cast<T*>(buffers[0].buffer.raw_ptr());
    R* rwork_ptr = static_cast<R*>(buffers[1].buffer.raw_ptr());
    int* iwork_ptr = static_cast<int*>(buffers[2].buffer.raw_ptr());

    heevd<T>(
        &jobz,
        &uplo,
        &N,
        vectors,
        &N,
        values,
        work_ptr,
        &lwork,
        rwork_ptr,
        &lrwork,
        iwork_ptr,
        &liwork,
        &info);

    if (jobz != 'V') {
      return;
    }

    // The vectors were pre-transposed, but for complex inputs they must be
    // conjugated as well.
    //
    // This could be vectorized, but it is negligible next to heevd itself.
    T* elem = vectors;
    for (int row = 0; row < N; ++row) {
      for (int col = 0; col < N; ++col, ++elem) {
        *elem = std::conj(*elem);
      }
    }
  }
};
template <typename T>
void eigh_impl(
array& vectors,
array& values,
const std::string& uplo,
bool compute_eigenvectors,
Stream stream) {
using R = typename EighWork<T>::R;
auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<R>();
char jobz = compute_eigenvectors ? 'V' : 'N';
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(vectors);
encoder.set_output_array(values);
encoder.dispatch([vec_ptr,
eig_ptr,
jobz,
uplo = uplo[0],
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
EighWork<T> work(jobz, uplo, N);
// Work loop
for (size_t i = 0; i < size / (N * N); ++i) {
work.run(vec_ptr, eig_ptr);
vec_ptr += N * N; vec_ptr += N * N;
eig_ptr += N; eig_ptr += N;
if (info != 0) { if (work.info != 0) {
std::stringstream msg; std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code " msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info; << work.info;
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
} }
@ -131,6 +228,10 @@ void Eigh::eval_cpu(
eigh_impl<double>( eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream()); vectors, values, uplo_, compute_eigenvectors_, stream());
break; break;
case complex64:
eigh_impl<std::complex<float>>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"[Eigh::eval_cpu] only supports float32 or float64."); "[Eigh::eval_cpu] only supports float32 or float64.");

View File

@ -2,14 +2,14 @@
#pragma once #pragma once
// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex> #include <complex>
#define LAPACK_COMPLEX_CUSTOM #define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float> #define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double> #define lapack_complex_double std::complex<double>
#endif #define lapack_complex_float_real(z) ((z).real())
#define lapack_complex_float_imag(z) ((z).imag())
#define lapack_complex_double_real(z) ((z).real())
#define lapack_complex_double_imag(z) ((z).imag())
#ifdef MLX_USE_ACCELERATE #ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
@ -32,7 +32,7 @@
#endif #endif
#define INSTANTIATE_LAPACK_TYPES(FUNC) \ #define INSTANTIATE_LAPACK_REAL(FUNC) \
template <typename T, typename... Args> \ template <typename T, typename... Args> \
void FUNC(Args... args) { \ void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \ if constexpr (std::is_same_v<T, float>) { \
@ -42,12 +42,24 @@
} \ } \
} }
INSTANTIATE_LAPACK_TYPES(geqrf) INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_TYPES(orgqr) INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_TYPES(syevd) INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_TYPES(geev) INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_TYPES(potrf) INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_TYPES(gesvdx) INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_TYPES(getrf) INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_TYPES(getri) INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_TYPES(trtri) INSTANTIATE_LAPACK_REAL(trtri)
// Defines a function template `FUNC<T>(args...)` that dispatches on the
// explicitly supplied element type T:
//   - std::complex<float>  -> MLX_LAPACK_FUNC(c##FUNC)
//   - std::complex<double> -> MLX_LAPACK_FUNC(z##FUNC)
// The arguments are passed through to the selected routine unchanged. For
// any other T neither branch is instantiated, so the call does nothing.
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
// Generates heevd<T>, used for the complex eigh implementation.
INSTANTIATE_LAPACK_COMPLEX(heevd)

View File

@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) {
} }
} }
// Validates that `dtype` is one of float32, float64 or complex64.
//
// Throws std::invalid_argument otherwise; the message starts with `prefix`
// (typically the name of the calling operation) and names the offending type.
void check_float_or_complex(Dtype dtype, const std::string& prefix) {
  const bool unsupported =
      dtype != float32 && dtype != float64 && dtype != complex64;
  if (!unsupported) {
    return;
  }

  std::ostringstream msg;
  msg << prefix << " Arrays must have type float32, float64 or complex64. ";
  msg << "Received array with type " << dtype << ".";
  throw std::invalid_argument(msg.str());
}
Dtype at_least_float(const Dtype& d) { Dtype at_least_float(const Dtype& d) {
return issubdtype(d, inexact) ? d : promote_types(d, float32); return issubdtype(d, inexact) ? d : promote_types(d, float32);
} }
@ -493,7 +502,7 @@ void validate_eig(
const StreamOrDevice& stream, const StreamOrDevice& stream,
const std::string fname) { const std::string fname) {
check_cpu_stream(stream, fname); check_cpu_stream(stream, fname);
check_float(a.dtype(), fname); check_float_or_complex(a.dtype(), fname);
if (a.ndim() < 2) { if (a.ndim() < 2) {
std::ostringstream msg; std::ostringstream msg;
@ -513,9 +522,10 @@ array eigvalsh(
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
validate_eig(a, s, "[linalg::eigvalsh]"); validate_eig(a, s, "[linalg::eigvalsh]");
Shape out_shape(a.shape().begin(), a.shape().end() - 1); Shape out_shape(a.shape().begin(), a.shape().end() - 1);
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
return array( return array(
std::move(out_shape), std::move(out_shape),
a.dtype(), eigval_type,
std::make_shared<Eigh>(to_stream(s), UPLO, false), std::make_shared<Eigh>(to_stream(s), UPLO, false),
{a}); {a});
} }
@ -525,9 +535,10 @@ std::pair<array, array> eigh(
std::string UPLO /* = "L" */, std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
validate_eig(a, s, "[linalg::eigh]"); validate_eig(a, s, "[linalg::eigh]");
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
auto out = array::make_arrays( auto out = array::make_arrays(
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()}, {eigval_type, a.dtype()},
std::make_shared<Eigh>(to_stream(s), UPLO, true), std::make_shared<Eigh>(to_stream(s), UPLO, true),
{a}); {a});
return std::make_pair(out[0], out[1]); return std::make_pair(out[0], out[1]);

View File

@ -423,6 +423,13 @@ class TestLinalg(mlx_tests.MLXTestCase):
A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2 A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2
check_eigs_and_vecs(A_np) check_eigs_and_vecs(A_np)
# Test with complex inputs
A_np = (
np.random.randn(8, 8, 2).astype(np.float32).view(np.complex64).squeeze(-1)
)
A_np = A_np + A_np.T.conj()
check_eigs_and_vecs(A_np)
# Test error cases # Test error cases
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array