mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-17 17:28:10 +08:00)
non-symmetric eig and eigh (#2188)
@@ -46,6 +46,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp

mlx/backend/cpu/eig.cpp (new file, 174 lines)
@@ -0,0 +1,174 @@
// Copyright © 2025 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T>
void eig_impl(
    array& a,
    array& vectors,
    array& values,
    bool compute_eigenvectors,
    Stream stream) {
  using OT = std::complex<T>;
  auto a_ptr = a.data<T>();
  auto eig_ptr = values.data<OT>();

  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(values);
  OT* vec_ptr = nullptr;
  if (compute_eigenvectors) {
    encoder.set_output_array(vectors);
    vec_ptr = vectors.data<OT>();
  }
  encoder.dispatch([a_ptr,
                    vec_ptr,
                    eig_ptr,
                    compute_eigenvectors,
                    N = vectors.shape(-1),
                    size = vectors.size()]() mutable {
    // Work query
    char jobr = 'N';
    char jobl = compute_eigenvectors ? 'V' : 'N';
    int n_vecs_r = 1;
    int n_vecs_l = compute_eigenvectors ? N : 1;
    int lwork = -1;
    int info;
    {
      T work;
      int iwork;
      geev<T>(
          &jobl,
          &jobr,
          &N,
          nullptr,
          &N,
          nullptr,
          nullptr,
          nullptr,
          &n_vecs_l,
          nullptr,
          &n_vecs_r,
          &work,
          &lwork,
          &info);
      lwork = static_cast<int>(work);
    }

    auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
    auto vec_tmp_data =
        array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
    auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
    auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
    for (size_t i = 0; i < size / (N * N); ++i) {
      geev<T>(
          &jobl,
          &jobr,
          &N,
          a_ptr,
          &N,
          eig_tmp,
          eig_tmp + N,
          vec_tmp,
          &n_vecs_l,
          nullptr,
          &n_vecs_r,
          static_cast<T*>(work_buf.buffer.raw_ptr()),
          &lwork,
          &info);
      for (int i = 0; i < N; ++i) {
        eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
      }
      if (vec_ptr) {
        for (int i = 0; i < N; ++i) {
          if (eig_ptr[i].imag() != 0) {
            // This vector and the next are a pair
            for (int j = 0; j < N; ++j) {
              vec_ptr[i * N + j] = {
                  vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
              vec_ptr[(i + 1) * N + j] = {
                  vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
            }
            i += 1;
          } else {
            for (int j = 0; j < N; ++j) {
              vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
            }
          }
        }
        vec_ptr += N * N;
      }
      a_ptr += N * N;
      eig_ptr += N;
      if (info != 0) {
        std::stringstream msg;
        msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
            << info;
        throw std::runtime_error(msg.str());
      }
    }
  });
  encoder.add_temporary(a);
}

} // namespace

void Eig::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  const auto& a = inputs[0];
  auto& values = outputs[0];

  auto vectors = compute_eigenvectors_
      ? outputs[1]
      : array(a.shape(), complex64, nullptr, {});

  auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
  copy(
      a,
      a_copy,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
      stream());

  values.set_data(allocator::malloc(values.nbytes()));

  if (compute_eigenvectors_) {
    // Set the strides and flags so the eigenvectors
    // are in the columns of the output
    auto flags = vectors.flags();
    auto strides = vectors.strides();
    auto ndim = a.ndim();
    std::swap(strides[ndim - 1], strides[ndim - 2]);

    if (a.size() > 1) {
      flags.row_contiguous = false;
      if (ndim > 2) {
        flags.col_contiguous = false;
      } else {
        flags.col_contiguous = true;
      }
    }
    vectors.set_data(
        allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
  }
  switch (a.dtype()) {
    case float32:
      eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
      break;
    default:
      throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
  }
}

} // namespace mlx::core
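One way to read the unpacking loop in eig_impl above (this note is not part of the diff): the real-valued LAPACK geev routine reports eigenvalues as separate real and imaginary arrays, and it stores each complex-conjugate eigenvector pair as two adjacent real columns. Because MLX buffers are row-major while LAPACK is column-major, the kernel hands LAPACK what it sees as the transposed matrix and asks for left eigenvectors (jobl) rather than right ones (jobr = 'N'); conjugating those left eigenvectors recovers the right eigenvectors of the original matrix, which is why the first vector of each pair gets a negated imaginary part. The sketch below shows only the documented packing convention for right eigenvectors, in isolation; unpack_geev_vectors is a hypothetical helper name, not an MLX function.

```cpp
// Illustrative only: how LAPACK's real-valued ?geev packs complex results.
// Eigenvalues come back in separate arrays (wr, wi); when wi[k] != 0, the
// eigenvalues k and k+1 form a conjugate pair and the two adjacent
// eigenvector columns hold the shared real part and the imaginary part.
#include <complex>
#include <vector>

// Unpack right eigenvectors stored column-major in `vr` (leading dimension n)
// into std::complex form, column by column.
std::vector<std::complex<float>> unpack_geev_vectors(
    const std::vector<float>& wr,
    const std::vector<float>& wi,
    const std::vector<float>& vr,
    int n) {
  std::vector<std::complex<float>> out(n * n);
  for (int k = 0; k < n; ++k) {
    if (wi[k] != 0.0f) {
      // Columns k and k+1 hold Re(v) and Im(v) of the conjugate pair.
      for (int j = 0; j < n; ++j) {
        out[j + k * n] = {vr[j + k * n], vr[j + (k + 1) * n]};
        out[j + (k + 1) * n] = {vr[j + k * n], -vr[j + (k + 1) * n]};
      }
      ++k; // Skip the partner column; it was filled above.
    } else {
      for (int j = 0; j < n; ++j) {
        out[j + k * n] = {vr[j + k * n], 0.0f};
      }
    }
  }
  return out;
}
```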
@@ -45,6 +45,7 @@
 INSTANTIATE_LAPACK_TYPES(geqrf)
 INSTANTIATE_LAPACK_TYPES(orgqr)
 INSTANTIATE_LAPACK_TYPES(syevd)
+INSTANTIATE_LAPACK_TYPES(geev)
 INSTANTIATE_LAPACK_TYPES(potrf)
 INSTANTIATE_LAPACK_TYPES(gesvdx)
 INSTANTIATE_LAPACK_TYPES(getrf)
@@ -378,10 +378,16 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
       "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
 }
 
+void Eig::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI.");
+}
+
 void Eigh::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
+  throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
 }
 
 void LUF::eval_gpu(
@@ -55,6 +55,7 @@ NO_CPU(DynamicSlice)
 NO_CPU(DynamicSliceUpdate)
 NO_CPU(NumberOfElements)
 NO_CPU(Remainder)
+NO_CPU_MULTI(Eig)
 NO_CPU_MULTI(Eigh)
 NO_CPU(Equal)
 NO_CPU(Erf)
@@ -126,6 +126,7 @@ NO_GPU(Unflatten)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
 NO_GPU_MULTI(Eigh)
+NO_GPU_MULTI(Eig)
 NO_GPU(View)
 
 namespace fast {
@@ -331,6 +331,7 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(SVD),
       SERIALIZE_PRIMITIVE(Inverse),
       SERIALIZE_PRIMITIVE(Cholesky),
+      SERIALIZE_PRIMITIVE(Eig),
       SERIALIZE_PRIMITIVE(Eigh),
       SERIALIZE_PRIMITIVE(AffineQuantize),
       SERIALIZE_PRIMITIVE(RMSNorm),
@@ -488,7 +488,7 @@ array cross(
   return concatenate(outputs, axis, s);
 }
 
-void validate_eigh(
+void validate_eig(
     const array& a,
     const StreamOrDevice& stream,
     const std::string fname) {
@@ -511,7 +511,7 @@ array eigvalsh(
     const array& a,
     std::string UPLO /* = "L" */,
     StreamOrDevice s /* = {} */) {
-  validate_eigh(a, s, "[linalg::eigvalsh]");
+  validate_eig(a, s, "[linalg::eigvalsh]");
   Shape out_shape(a.shape().begin(), a.shape().end() - 1);
   return array(
       std::move(out_shape),
@@ -524,7 +524,7 @@ std::pair<array, array> eigh(
     const array& a,
     std::string UPLO /* = "L" */,
     StreamOrDevice s /* = {} */) {
-  validate_eigh(a, s, "[linalg::eigh]");
+  validate_eig(a, s, "[linalg::eigh]");
   auto out = array::make_arrays(
       {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
       {a.dtype(), a.dtype()},
@@ -533,6 +533,26 @@ std::pair<array, array> eigh(
   return std::make_pair(out[0], out[1]);
 }
 
+array eigvals(const array& a, StreamOrDevice s /* = {} */) {
+  validate_eig(a, s, "[linalg::eigvals]");
+  Shape out_shape(a.shape().begin(), a.shape().end() - 1);
+  return array(
+      std::move(out_shape),
+      complex64,
+      std::make_shared<Eig>(to_stream(s), false),
+      {a});
+}
+
+std::pair<array, array> eig(const array& a, StreamOrDevice s /* = {} */) {
+  validate_eig(a, s, "[linalg::eig]");
+  auto out = array::make_arrays(
+      {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
+      {complex64, complex64},
+      std::make_shared<Eig>(to_stream(s), true),
+      {a});
+  return std::make_pair(out[0], out[1]);
+}
+
 void validate_lu(
     const array& a,
     const StreamOrDevice& stream,
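For orientation, here is a minimal C++ usage sketch of the new linalg::eig and linalg::eigvals entry points added above. It is not part of the commit; it assumes the umbrella header "mlx/mlx.h" and the initializer-list array constructor, and it uses a 2x2 rotation matrix whose eigenvalues are the conjugate pair ±i. Both outputs are complex64, with the eigenvectors in the columns of the second result.

```cpp
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // 90-degree rotation: the characteristic polynomial is l^2 + 1, so the
  // eigenvalues are the complex pair i and -i even though the input is real.
  mx::array a({0.0f, -1.0f, 1.0f, 0.0f}, {2, 2});

  auto [values, vectors] = mx::linalg::eig(a); // both complex64
  auto only_values = mx::linalg::eigvals(a);   // eigenvalues only

  mx::eval({values, vectors, only_values});
  std::cout << values << std::endl;  // e.g. [0+1j, 0-1j]
  std::cout << vectors << std::endl; // eigenvectors in the columns
  return 0;
}
```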
@@ -99,6 +99,10 @@ array cross(
     int axis = -1,
     StreamOrDevice s = {});
 
+std::pair<array, array> eig(const array& a, StreamOrDevice s = {});
+
+array eigvals(const array& a, StreamOrDevice s = {});
+
 array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
 
 std::pair<array, array>
@@ -875,6 +875,43 @@ std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
   return {{linalg::cholesky(a, upper_, stream())}, {ax}};
 }
 
+std::pair<std::vector<array>, std::vector<int>> Eig::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+
+  bool needs_move = axes[0] >= (inputs[0].ndim() - 2);
+  auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  auto ax = needs_move ? 0 : axes[0];
+
+  std::vector<array> outputs;
+  if (compute_eigenvectors_) {
+    auto [values, vectors] = linalg::eig(a, stream());
+    outputs = {values, vectors};
+  } else {
+    outputs = {linalg::eigvals(a, stream())};
+  }
+
+  return {outputs, std::vector<int>(outputs.size(), ax)};
+}
+
+std::vector<Shape> Eig::output_shapes(const std::vector<array>& inputs) {
+  auto shape = inputs[0].shape();
+  shape.pop_back(); // Remove last dimension for eigenvalues
+  if (compute_eigenvectors_) {
+    return {
+        std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors
+  } else {
+    return {std::move(shape)}; // Only eigenvalues
+  }
+}
+
+bool Eig::is_equivalent(const Primitive& other) const {
+  auto& e_other = static_cast<const Eig&>(other);
+  return compute_eigenvectors_ == e_other.compute_eigenvectors_;
+}
+
 std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
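The Eig::vmap rule above is what lets the new decomposition compose with mx::vmap. Below is a hypothetical sketch, assuming the single-input/single-output vmap overload from mlx/transforms.h; the batch shape and lambda are illustrative only (a batched input also works without vmap, since eig_impl already loops over the leading dimensions).

```cpp
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // A stack of three 4x4 float32 matrices along axis 0.
  auto batch = mx::random::normal({3, 4, 4});

  // Map eigvals over axis 0; the result has shape (3, 4) and dtype complex64.
  auto batched_eigvals =
      mx::vmap([](const mx::array& m) { return mx::linalg::eigvals(m); });
  auto vals = batched_eigvals(batch);
  mx::eval(vals);
  return 0;
}
```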
@@ -2381,6 +2381,29 @@ class Cholesky : public UnaryPrimitive {
   bool upper_;
 };
 
+class Eig : public Primitive {
+ public:
+  explicit Eig(Stream stream, bool compute_eigenvectors)
+      : Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {}
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_VMAP()
+  DEFINE_PRINT(Eig)
+
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+
+  bool is_equivalent(const Primitive& other) const override;
+  auto state() const {
+    return compute_eigenvectors_;
+  }
+
+ private:
+  bool compute_eigenvectors_;
+};
+
 class Eigh : public Primitive {
  public:
   explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)