Eigenvalues and eigenvectors (#1334)

* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge Eigvalsh and Eigh into a single primitive

* use the same primitive with the flag

* fix primitives

* use MULTI

* fix eval_gpu

* fix declaration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Kashif Rasul 2024-10-22 21:18:48 +02:00 committed by GitHub
parent c26208f67d
commit 3ddc07e936
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 434 additions and 86 deletions

View File

@ -16,3 +16,5 @@ Linear Algebra
cross
qr
svd
eigvalsh
eigh

View File

@ -178,9 +178,10 @@ void array::move_shared_buffer(
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {

View File

@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

View File

@ -31,6 +31,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp

View File

@ -2,46 +2,12 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
namespace {
// Runs LAPACK's single-precision Cholesky factorization (spotrf_) in place on
// an N x N matrix and returns the `info` status code reported by LAPACK.
//
// Some LAPACK builds (LAPACK_FORTRAN_STRLEN_END) expect the length of every
// Fortran character argument as a trailing parameter; this wrapper hides that
// difference from callers.
int spotrf_wrapper(char uplo, float* matrix, int N) {
  int status;

#ifdef LAPACK_FORTRAN_STRLEN_END
  // Trailing argument is the length of the `uplo` string.
  spotrf_(&uplo, &N, matrix, /* lda = */ &N, &status, static_cast<size_t>(1));
#else
  spotrf_(&uplo, &N, matrix, /* lda = */ &N, &status);
#endif

  return status;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
@ -66,7 +32,14 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N);
int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how

View File

@ -3,13 +3,8 @@
#include <cassert>
#include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

View File

@ -1,14 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@ -114,6 +110,7 @@ DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
namespace {

117
mlx/backend/common/eigh.cpp Normal file
View File

@ -0,0 +1,117 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
// Thin wrapper around the LAPACK routine ssyevd that turns a non-zero `info`
// status into an exception.
//
// The matrix is N x N with leading dimension N. `jobz` and `uplo` are the
// LAPACK flag characters, `w` receives the eigenvalues, and `work`/`iwork`
// are scratch buffers of `lwork`/`liwork` elements. The caller in this file
// passes -1 for both sizes to perform a workspace-size query.
//
// Throws std::runtime_error when LAPACK reports a non-zero status.
void ssyevd(
    char jobz,
    char uplo,
    float* a,
    int N,
    float* w,
    float* work,
    int lwork,
    int* iwork,
    int liwork) {
  int status;
  MLX_LAPACK_FUNC(ssyevd)
  (&jobz,
   &uplo,
   &N,
   a,
   /* lda = */ &N,
   w,
   work,
   &lwork,
   iwork,
   &liwork,
   &status);

  if (status == 0) {
    return;
  }

  std::stringstream msg;
  msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
      << status;
  throw std::runtime_error(msg.str());
}
} // namespace
// Computes the eigenvalues (and, when compute_eigenvectors_ is set, the
// eigenvectors) of every N x N matrix stored in the last two dimensions of
// inputs[0].
//
// outputs[0] receives the eigenvalues. outputs[1], when requested, receives
// the eigenvectors. Errors reported by LAPACK surface as std::runtime_error
// from the ssyevd wrapper.
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
  const auto& a = inputs[0];
  auto& values = outputs[0];

  // LAPACK works in place on its matrix argument, so the input is always
  // copied first: into the eigenvector output when it is requested, or into a
  // temporary array otherwise.
  auto vectors = compute_eigenvectors_
      ? outputs[1]
      : array(a.shape(), a.dtype(), nullptr, {});

  values.set_data(allocator::malloc_or_wait(values.nbytes()));

  copy(
      a,
      vectors,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

  if (compute_eigenvectors_) {
    // Set the strides and flags so the eigenvectors
    // are in the columns of the output
    auto flags = vectors.flags();
    auto strides = vectors.strides();
    auto ndim = a.ndim();
    std::swap(strides[ndim - 1], strides[ndim - 2]);

    if (a.size() > 1) {
      flags.row_contiguous = false;
      if (ndim > 2) {
        flags.col_contiguous = false;
      } else {
        flags.col_contiguous = true;
      }
    }
    vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
  }

  const auto N = a.shape(-1);

  // Compute the per-matrix element count in size_t so that it cannot overflow
  // int, and guard against empty matrices (N == 0): the previous
  // `a.size() / (N * N)` divided by zero in that case.
  const size_t matrix_size = static_cast<size_t>(N) * static_cast<size_t>(N);
  const size_t num_matrices = matrix_size == 0 ? 0 : a.size() / matrix_size;
  if (num_matrices == 0) {
    // Nothing to decompose; the outputs have already been allocated above.
    return;
  }

  auto vec_ptr = vectors.data<float>();
  auto eig_ptr = values.data<float>();

  char jobz = compute_eigenvectors_ ? 'V' : 'N';

  // Work query: sizes of -1 ask LAPACK for the workspace sizes, which it
  // writes into the first element of the work / iwork arguments.
  int lwork;
  int liwork;
  {
    float work;
    int iwork;
    ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
    lwork = static_cast<int>(work);
    liwork = iwork;
  }

  // The same scratch buffers are reused for every matrix in the batch.
  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
  for (size_t i = 0; i < num_matrices; ++i) {
    ssyevd(
        jobz,
        uplo_[0],
        vec_ptr,
        N,
        eig_ptr,
        static_cast<float*>(work_buf.buffer.raw_ptr()),
        lwork,
        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
        liwork);
    vec_ptr += matrix_size;
    eig_ptr += N;
  }
}
} // namespace mlx::core

View File

@ -2,39 +2,19 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
// Wrapper to account for differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1),
/* diag_len = */ static_cast<size_t>(1));
#else
strtri_(
MLX_LAPACK_FUNC(strtri)
(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}

View File

@ -1,10 +1,11 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#include <lapack.h>
#endif

View File

@ -1,15 +1,10 @@
// Copyright © 2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

View File

@ -2,14 +2,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
template <typename T>

View File

@ -2,7 +2,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {

View File

@ -401,6 +401,12 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
}
// The Metal backend has no eigendecomposition implementation yet, so
// evaluating Eigh on the GPU always throws std::runtime_error.
void Eigh::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // The tag names the primitive that actually exists (Eigh); there is no
  // Eigvalsh primitive.
  throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());

View File

@ -48,6 +48,7 @@ NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
NO_CPU(NumberOfElements)
NO_CPU(Remainder)
NO_CPU_MULTI(Eigh)
NO_CPU(Equal)
NO_CPU(Erf)
NO_CPU(ErfInv)

View File

@ -112,6 +112,7 @@ NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)
NO_GPU(View)
namespace fast {

View File

@ -454,4 +454,50 @@ array cross(
return concatenate(outputs, axis, s);
}
// Checks that `a` is a valid input for eigh / eigvalsh: float32, at least two
// dimensions, and square in its last two dimensions. `fname` is the tag that
// prefixes the error message. Throws std::invalid_argument on failure.
void validate_eigh(const array& a, const std::string fname) {
  if (a.dtype() != float32) {
    std::ostringstream msg;
    msg << fname << " Arrays must have type float32. Received array "
        << "with type " << a.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }

  if (a.ndim() < 2) {
    std::ostringstream msg;
    msg << fname << " Arrays must have >= 2 dimensions. Received array with "
        << a.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  if (a.shape(-1) != a.shape(-2)) {
    throw std::invalid_argument(fname + " Only defined for square matrices.");
  }
}

namespace {

// Rejects UPLO strings that do not select a triangle. Only the first
// character is forwarded to LAPACK by the Eigh primitive, so only that
// character is inspected; both cases are accepted because LAPACK compares the
// flag case-insensitively. Validating here reports the problem when the op is
// built instead of as a LAPACK failure during evaluation.
void validate_eigh_uplo(const std::string& uplo, const std::string& fname) {
  const char flag = uplo.empty() ? '\0' : uplo[0];
  if (flag != 'L' && flag != 'l' && flag != 'U' && flag != 'u') {
    std::ostringstream msg;
    msg << fname << " UPLO must be \"L\" or \"U\". Received \"" << uplo
        << "\".";
    throw std::invalid_argument(msg.str());
  }
}

} // namespace

// Returns the eigenvalues of the symmetric matrices stored in the last two
// dimensions of `a`. The result has the shape of `a` without its last
// dimension. `UPLO` selects the triangle that is read.
//
// Throws std::invalid_argument if `a` is not a float32 array of square
// matrices or if `UPLO` does not start with 'L' or 'U'.
array eigvalsh(
    const array& a,
    std::string UPLO /* = "L" */,
    StreamOrDevice s /* = {} */) {
  validate_eigh(a, "[linalg::eigvalsh]");
  validate_eigh_uplo(UPLO, "[linalg::eigvalsh]");

  std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);
  return array(
      std::move(out_shape),
      a.dtype(),
      std::make_shared<Eigh>(to_stream(s), UPLO, false),
      {a});
}

// Returns the (eigenvalues, eigenvectors) pair of the symmetric matrices
// stored in the last two dimensions of `a`. The eigenvalues have the shape of
// `a` without its last dimension; the eigenvectors have the shape of `a`.
// `UPLO` selects the triangle that is read.
//
// Throws std::invalid_argument if `a` is not a float32 array of square
// matrices or if `UPLO` does not start with 'L' or 'U'.
std::pair<array, array> eigh(
    const array& a,
    std::string UPLO /* = "L" */,
    StreamOrDevice s /* = {} */) {
  validate_eigh(a, "[linalg::eigh]");
  validate_eigh_uplo(UPLO, "[linalg::eigh]");

  auto out = array::make_arrays(
      {std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
      {a.dtype(), a.dtype()},
      std::make_shared<Eigh>(to_stream(s), UPLO, true),
      {a});
  return std::make_pair(out[0], out[1]);
}
} // namespace mlx::core::linalg

View File

@ -83,4 +83,9 @@ array cross(
int axis = -1,
StreamOrDevice s = {});
array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
std::pair<array, array>
eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
} // namespace mlx::core::linalg

View File

@ -767,6 +767,27 @@ std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
return {{linalg::cholesky(a, upper_, stream())}, {ax}};
}
std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
bool needs_move = axes[0] >= (inputs[0].ndim() - 2);
auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
auto ax = needs_move ? 0 : axes[0];
std::vector<array> outputs;
if (compute_eigenvectors_) {
auto [values, vectors] = linalg::eigh(a, uplo_, stream());
outputs = {values, vectors};
} else {
outputs = {linalg::eigvalsh(a, uplo_, stream())};
}
return {outputs, std::vector<int>(outputs.size(), ax)};
}
std::vector<array> Concatenate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -2196,4 +2196,44 @@ class Cholesky : public UnaryPrimitive {
bool upper_;
};
class Eigh : public Primitive {
public:
explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
: Primitive(stream),
uplo_(std::move(uplo)),
compute_eigenvectors_(compute_eigenvectors) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_VMAP()
DEFINE_PRINT(Eigh)
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override {
auto shape = inputs[0].shape();
shape.pop_back(); // Remove last dimension for eigenvalues
if (compute_eigenvectors_) {
return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
} else {
return {shape}; // Only eigenvalues
}
}
bool is_equivalent(const Primitive& other) const override {
if (auto* p = dynamic_cast<const Eigh*>(&other)) {
return uplo_ == p->uplo_ &&
compute_eigenvectors_ == p->compute_eigenvectors_;
}
return false;
}
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
std::string uplo_;
bool compute_eigenvectors_;
};
} // namespace mlx::core

View File

@ -405,4 +405,85 @@ void init_linalg(nb::module_& parent_module) {
Returns:
array: The cross product of ``a`` and ``b`` along the specified axis.
)pbdoc");
m.def(
"eigvalsh",
&eigvalsh,
"a"_a,
"UPLO"_a = "L",
nb::kw_only(),
"stream"_a = nb::none(),
R"pbdoc(
Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
This function supports arrays with at least 2 dimensions. When the
input has more than two dimensions, the eigenvalues are computed for
each matrix in the last two dimensions.
Args:
a (array): Input array. Must be a real symmetric or complex
Hermitian matrix.
UPLO (str, optional): Whether to use the upper (``"U"``) or
lower (``"L"``) triangle of the matrix. Default: ``"L"``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The eigenvalues in ascending order.
Note:
The input matrix is assumed to be symmetric (or Hermitian). Only
the selected triangle is used. No checks for symmetry are performed.
Example:
>>> A = mx.array([[1., -2.], [-2., 1.]])
>>> eigenvalues = mx.linalg.eigvalsh(A, stream=mx.cpu)
>>> eigenvalues
array([-1., 3.], dtype=float32)
)pbdoc");
m.def(
"eigh",
[](const array& a, const std::string UPLO, StreamOrDevice s) {
// TODO avoid cast?
auto result = eigh(a, UPLO, s);
return nb::make_tuple(result.first, result.second);
},
"a"_a,
"UPLO"_a = "L",
nb::kw_only(),
"stream"_a = nb::none(),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a complex Hermitian or
real symmetric matrix.
This function supports arrays with at least 2 dimensions. When the input
has more than two dimensions, the eigenvalues and eigenvectors are
computed for each matrix in the last two dimensions.
Args:
a (array): Input array. Must be a real symmetric or complex
Hermitian matrix.
UPLO (str, optional): Whether to use the upper (``"U"``) or
lower (``"L"``) triangle of the matrix. Default: ``"L"``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
Tuple[array, array]:
A tuple containing the eigenvalues in ascending order and
the normalized eigenvectors. The column ``v[:, i]`` is the
eigenvector corresponding to the i-th eigenvalue.
Note:
The input matrix is assumed to be symmetric (or Hermitian). Only
the selected triangle is used. No checks for symmetry are performed.
Example:
>>> A = mx.array([[1., -2.], [-2., 1.]])
>>> w, v = mx.linalg.eigh(A, stream=mx.cpu)
>>> w
array([-1., 3.], dtype=float32)
>>> v
array([[ 0.707107, -0.707107],
[ 0.707107, 0.707107]], dtype=float32)
)pbdoc");
}

View File

@ -268,6 +268,57 @@ class TestLinalg(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.linalg.cross(a, b)
def test_eigh(self):
    """Check mx.linalg.eigh / eigvalsh on the CPU stream.

    Eigenvalues are compared against NumPy, eigenvectors are checked through
    the eigen-equation ``A @ v == w * v``, and invalid inputs must raise
    ``ValueError``.
    """
    tols = {"atol": 1e-5, "rtol": 1e-5}

    def check_eigs_and_vecs(A_np, kwargs=None):
        # None stands for "no extra keyword arguments"; this avoids sharing a
        # mutable default dict between calls.
        kwargs = {} if kwargs is None else kwargs
        A = mx.array(A_np)
        eig_vals, eig_vecs = mx.linalg.eigh(A, stream=mx.cpu, **kwargs)
        eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)
        self.assertTrue(np.allclose(eig_vals, eig_vals_np, **tols))
        # Eigenvectors are only defined up to sign, so verify the
        # eigen-equation instead of comparing them with NumPy directly.
        self.assertTrue(
            mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols)
        )

        eig_vals_only = mx.linalg.eigvalsh(A, stream=mx.cpu, **kwargs)
        self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols))

    # Test a simple 2x2 symmetric matrix
    A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float32)
    check_eigs_and_vecs(A_np)

    # Test a larger random symmetric matrix
    n = 5
    np.random.seed(1)
    A_np = np.random.randn(n, n).astype(np.float32)
    A_np = (A_np + A_np.T) / 2
    check_eigs_and_vecs(A_np)

    # Test with upper triangle
    check_eigs_and_vecs(A_np, {"UPLO": "U"})

    # Test with batched input
    A_np = np.random.randn(3, n, n).astype(np.float32)
    A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2
    check_eigs_and_vecs(A_np)

    # Test error cases
    with self.assertRaises(ValueError):
        mx.linalg.eigh(mx.array([1.0, 2.0]))  # 1D array

    with self.assertRaises(ValueError):
        mx.linalg.eigh(
            mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        )  # Non-square matrix

    with self.assertRaises(ValueError):
        mx.linalg.eigvalsh(mx.array([1.0, 2.0]))  # 1D array

    with self.assertRaises(ValueError):
        mx.linalg.eigvalsh(
            mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        )  # Non-square matrix
if __name__ == "__main__":
unittest.main()

View File

@ -435,3 +435,41 @@ TEST_CASE("test cross product") {
result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
}
// Exercises the input validation of linalg::eigh / linalg::eigvalsh and a
// small end-to-end decomposition on the CPU.
TEST_CASE("test matrix eigh") {
  // 0D and 1D throw
  CHECK_THROWS(linalg::eigh(array(0.0)));
  CHECK_THROWS(linalg::eigh(array({0.0, 1.0})));
  CHECK_THROWS(linalg::eigvalsh(array(0.0)));
  CHECK_THROWS(linalg::eigvalsh(array({0.0, 1.0})));

  // Unsupported types throw. The matrix is square so that only the dtype
  // check can be responsible for the exception.
  CHECK_THROWS(linalg::eigh(array({0, 1, 1, 0}, {2, 2})));

  // Non-square throws. The data is float32 so that the dtype check passes and
  // the square-matrix check is the one being exercised.
  CHECK_THROWS(
      linalg::eigh(array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {2, 3}, float32)));

  // Test a simple 2x2 symmetric matrix
  array A = array({1.0, 2.0, 2.0, 4.0}, {2, 2}, float32);
  auto [eigvals, eigvecs] = linalg::eigh(A, "L", Device::cpu);

  // Expected eigenvalues
  array expected_eigvals = array({0.0, 5.0});
  CHECK(allclose(
            eigvals,
            expected_eigvals,
            /* rtol = */ 1e-5,
            /* atol = */ 1e-5)
            .item<bool>());

  // Verify orthogonality of eigenvectors
  CHECK(allclose(
            matmul(eigvecs, transpose(eigvecs)),
            eye(2),
            /* rtol = */ 1e-5,
            /* atol = */ 1e-5)
            .item<bool>());

  // Verify eigendecomposition. One eigenvalue is exactly zero, so both sides
  // contain entries that are only zero up to float32 rounding; use the same
  // absolute tolerance as the checks above instead of the default.
  CHECK(allclose(
            matmul(A, eigvecs),
            eigvals * eigvecs,
            /* rtol = */ 1e-5,
            /* atol = */ 1e-5)
            .item<bool>());
}