mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add the option for complex eigh
This commit is contained in:
parent
602f43e3d1
commit
9eb53248ba
@ -224,6 +224,10 @@ class array {
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
Data(Data&& o) : buffer(o.buffer), d(o.d) {
|
||||
o.buffer = allocator::Buffer(nullptr);
|
||||
o.d = [](allocator::Buffer) {};
|
||||
}
|
||||
~Data() {
|
||||
d(buffer);
|
||||
}
|
||||
|
@ -12,6 +12,133 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, class Enable = void>
|
||||
struct EighWork {};
|
||||
|
||||
template <typename T>
|
||||
struct EighWork<
|
||||
T,
|
||||
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||
using R = T;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, T* values) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(buffers[1].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct EighWork<std::complex<float>> {
|
||||
using T = std::complex<float>;
|
||||
using R = float;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int lrwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
|
||||
T work;
|
||||
R rwork;
|
||||
int iwork;
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&rwork,
|
||||
&lrwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work.real());
|
||||
lrwork = static_cast<int>(rwork);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, R* values) {
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||
&lrwork,
|
||||
static_cast<int*>(buffers[2].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
if (jobz == 'V') {
|
||||
// We have pre-transposed the vectors but we also must conjugate them
|
||||
// when they are complex.
|
||||
//
|
||||
// We could vectorize this but it is so fast in comparison to heevd that
|
||||
// it doesn't really matter.
|
||||
for (int i = 0; i < N; i++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
*vectors = std::conj(*vectors);
|
||||
vectors++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void eigh_impl(
|
||||
array& vectors,
|
||||
@ -19,8 +146,10 @@ void eigh_impl(
|
||||
const std::string& uplo,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using R = typename EighWork<T>::R;
|
||||
|
||||
auto vec_ptr = vectors.data<T>();
|
||||
auto eig_ptr = values.data<T>();
|
||||
auto eig_ptr = values.data<R>();
|
||||
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
@ -33,49 +162,17 @@ void eigh_impl(
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
int lwork = -1;
|
||||
int liwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
EighWork<T> work(jobz, uplo, N);
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
|
||||
// Work loop
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vec_ptr,
|
||||
&N,
|
||||
eig_ptr,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
work.run(vec_ptr, eig_ptr);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
if (work.info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
<< work.info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
@ -131,6 +228,10 @@ void Eigh::eval_cpu(
|
||||
eigh_impl<double>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
eigh_impl<std::complex<float>>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Eigh::eval_cpu] only supports float32 or float64.");
|
||||
|
@ -32,7 +32,7 @@
|
||||
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
|
||||
#define INSTANTIATE_LAPACK_REAL(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, float>) { \
|
||||
@ -42,12 +42,24 @@
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_TYPES(geqrf)
|
||||
INSTANTIATE_LAPACK_TYPES(orgqr)
|
||||
INSTANTIATE_LAPACK_TYPES(syevd)
|
||||
INSTANTIATE_LAPACK_TYPES(geev)
|
||||
INSTANTIATE_LAPACK_TYPES(potrf)
|
||||
INSTANTIATE_LAPACK_TYPES(gesvdx)
|
||||
INSTANTIATE_LAPACK_TYPES(getrf)
|
||||
INSTANTIATE_LAPACK_TYPES(getri)
|
||||
INSTANTIATE_LAPACK_TYPES(trtri)
|
||||
INSTANTIATE_LAPACK_REAL(geqrf)
|
||||
INSTANTIATE_LAPACK_REAL(orgqr)
|
||||
INSTANTIATE_LAPACK_REAL(syevd)
|
||||
INSTANTIATE_LAPACK_REAL(geev)
|
||||
INSTANTIATE_LAPACK_REAL(potrf)
|
||||
INSTANTIATE_LAPACK_REAL(gesvdx)
|
||||
INSTANTIATE_LAPACK_REAL(getrf)
|
||||
INSTANTIATE_LAPACK_REAL(getri)
|
||||
INSTANTIATE_LAPACK_REAL(trtri)
|
||||
|
||||
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, std::complex<float>>) { \
|
||||
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
|
||||
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
|
||||
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
||||
|
@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) {
|
||||
}
|
||||
}
|
||||
|
||||
void check_float_or_complex(Dtype dtype, const std::string& prefix) {
|
||||
if (dtype != float32 && dtype != float64 && dtype != complex64) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Arrays must have type float32, float64 or complex64. "
|
||||
<< "Received array with type " << dtype << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
}
|
||||
@ -493,7 +502,7 @@ void validate_eig(
|
||||
const StreamOrDevice& stream,
|
||||
const std::string fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
check_float(a.dtype(), fname);
|
||||
check_float_or_complex(a.dtype(), fname);
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@ -513,9 +522,10 @@ array eigvalsh(
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_eig(a, s, "[linalg::eigvalsh]");
|
||||
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
|
||||
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
eigval_type,
|
||||
std::make_shared<Eigh>(to_stream(s), UPLO, false),
|
||||
{a});
|
||||
}
|
||||
@ -525,9 +535,10 @@ std::pair<array, array> eigh(
|
||||
std::string UPLO /* = "L" */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_eig(a, s, "[linalg::eigh]");
|
||||
Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype();
|
||||
auto out = array::make_arrays(
|
||||
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
||||
{a.dtype(), a.dtype()},
|
||||
{eigval_type, a.dtype()},
|
||||
std::make_shared<Eigh>(to_stream(s), UPLO, true),
|
||||
{a});
|
||||
return std::make_pair(out[0], out[1]);
|
||||
|
@ -423,6 +423,13 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2
|
||||
check_eigs_and_vecs(A_np)
|
||||
|
||||
# Test with complex inputs
|
||||
A_np = (
|
||||
np.random.randn(8, 8, 2).astype(np.float32).view(np.complex64).squeeze(-1)
|
||||
)
|
||||
A_np = A_np + A_np.T.conj()
|
||||
check_eigs_and_vecs(A_np)
|
||||
|
||||
# Test error cases
|
||||
with self.assertRaises(ValueError):
|
||||
mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array
|
||||
|
Loading…
Reference in New Issue
Block a user