diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst
index 227711c22..f6c51ed0b 100644
--- a/docs/src/python/linalg.rst
+++ b/docs/src/python/linalg.rst
@@ -16,3 +16,5 @@ Linear Algebra
    cross
    qr
    svd
+   eigvalsh
+   eigh
diff --git a/mlx/array.cpp b/mlx/array.cpp
index 374c2d36f..bb92989c3 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -178,9 +178,10 @@ void array::move_shared_buffer(
   array_desc_->flags = flags;
   array_desc_->data_size = data_size;
   auto char_offset = sizeof(char) * itemsize() * offset;
-  array_desc_->data_ptr = static_cast<void*>(
-      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
+  auto data_ptr = other.array_desc_->data_ptr;
   other.array_desc_->data_ptr = nullptr;
+  array_desc_->data_ptr =
+      static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
 }
 
 void array::move_shared_buffer(array other) {
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index eee93f2ab..1f80224ad 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
+DEFAULT_MULTI(Eigh)
 
 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 925f4731c..4fca2274e 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -31,6 +31,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp
index 5fd9c8065..62807e6dd 100644
--- a/mlx/backend/common/cholesky.cpp
+++ b/mlx/backend/common/cholesky.cpp
@@ -2,46 +2,12 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/linalg.h"
 #include "mlx/primitives.h"
 
-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <lapack.h>
-#endif
-
 namespace mlx::core {
 
-namespace {
-
-// Delegate to the Cholesky factorization taking into account differences in
-// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
-int spotrf_wrapper(char uplo, float* matrix, int N) {
-  int info;
-
-#ifdef LAPACK_FORTRAN_STRLEN_END
-  spotrf_(
-      /* uplo = */ &uplo,
-      /* n = */ &N,
-      /* a = */ matrix,
-      /* lda = */ &N,
-      /* info = */ &info,
-      /* uplo_len = */ static_cast<size_t>(1));
-#else
-  spotrf_(
-      /* uplo = */ &uplo,
-      /* n = */ &N,
-      /* a = */ matrix,
-      /* lda = */ &N,
-      /* info = */ &info);
-#endif
-
-  return info;
-}
-
-} // namespace
-
 void cholesky_impl(const array& a, array& factor, bool upper) {
   // Lapack uses the column-major convention. We take advantage of the fact that
   // the matrix should be symmetric:
@@ -66,7 +32,14 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
 
   for (int i = 0; i < num_matrices; i++) {
     // Compute Cholesky factorization.
-    int info = spotrf_wrapper(uplo, matrix, N);
+    int info;
+    MLX_LAPACK_FUNC(spotrf)
+    (
+        /* uplo = */ &uplo,
+        /* n = */ &N,
+        /* a = */ matrix,
+        /* lda = */ &N,
+        /* info = */ &info);
 
     // TODO: We do nothing when the matrix is not positive semi-definite
     // because throwing an error would result in a crash. If we figure out how
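The `spotrf_wrapper` removed above existed only to paper over a LAPACK ABI difference: when `LAPACK_FORTRAN_STRLEN_END` is defined, Fortran-style entry points such as `spotrf_` take one hidden `size_t` length argument per `char*` parameter. Call sites now route through `MLX_LAPACK_FUNC` from the shared `mlx/backend/common/lapack.h` header instead. The macro's definition sits outside the hunks shown in this diff; the sketch below illustrates the dispatch it plausibly performs (treat the exact spelling as an assumption):

```cpp
// Hypothetical sketch of MLX_LAPACK_FUNC; the real definition lives in
// mlx/backend/common/lapack.h, outside the hunks shown in this diff.
#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
// LAPACK >= 3.9.1 headers declare hidden string-length parameters, but also
// provide LAPACK_<name> call macros that fill those lengths in automatically.
#define MLX_LAPACK_FUNC(f) LAPACK_##f
#else
// Older headers (and Accelerate's new LAPACK) expose the plain Fortran symbol
// with a trailing underscore and no hidden arguments.
#define MLX_LAPACK_FUNC(f) f##_
#endif
```

Either way, a call site like `MLX_LAPACK_FUNC(spotrf)(&uplo, &N, matrix, &N, &info)` stays identical across LAPACK implementations, which is what the removed wrapper achieved with a per-function `#ifdef`.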
diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp
index 76edc9a27..67bdaeefb 100644
--- a/mlx/backend/common/conv.cpp
+++ b/mlx/backend/common/conv.cpp
@@ -3,13 +3,8 @@
 #include <cassert>
 #include <numeric>
 
-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <cblas.h>
-#endif
-
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index f8932c5f8..547d8e25d 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -1,14 +1,10 @@
 // Copyright © 2023-2024 Apple Inc.
 
-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <cblas.h>
-#endif
 #include <cstring>
 
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
@@ -114,6 +110,7 @@ DEFAULT(Tanh)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
+DEFAULT_MULTI(Eigh)
 
 namespace {
diff --git a/mlx/backend/common/eigh.cpp b/mlx/backend/common/eigh.cpp
new file mode 100644
index 000000000..8a4e499a3
--- /dev/null
+++ b/mlx/backend/common/eigh.cpp
@@ -0,0 +1,117 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
+#include "mlx/linalg.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+namespace {
+
+void ssyevd(
+    char jobz,
+    char uplo,
+    float* a,
+    int N,
+    float* w,
+    float* work,
+    int lwork,
+    int* iwork,
+    int liwork) {
+  int info;
+  MLX_LAPACK_FUNC(ssyevd)
+  (
+      /* jobz = */ &jobz,
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ a,
+      /* lda = */ &N,
+      /* w = */ w,
+      /* work = */ work,
+      /* lwork = */ &lwork,
+      /* iwork = */ iwork,
+      /* liwork = */ &liwork,
+      /* info = */ &info);
+  if (info != 0) {
+    std::stringstream msg;
+    msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
+        << info;
+    throw std::runtime_error(msg.str());
+  }
+}
+
+} // namespace
+
+void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
+  const auto& a = inputs[0];
+  auto& values = outputs[0];
+
+  auto vectors = compute_eigenvectors_
+      ? outputs[1]
+      : array(a.shape(), a.dtype(), nullptr, {});
+
+  values.set_data(allocator::malloc_or_wait(values.nbytes()));
+
+  copy(
+      a,
+      vectors,
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+
+  if (compute_eigenvectors_) {
+    // Set the strides and flags so the eigenvectors
+    // are in the columns of the output
+    auto flags = vectors.flags();
+    auto strides = vectors.strides();
+    auto ndim = a.ndim();
+    std::swap(strides[ndim - 1], strides[ndim - 2]);
+
+    if (a.size() > 1) {
+      flags.row_contiguous = false;
+      if (ndim > 2) {
+        flags.col_contiguous = false;
+      } else {
+        flags.col_contiguous = true;
+      }
+    }
+    vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
+  }
+
+  auto vec_ptr = vectors.data<float>();
+  auto eig_ptr = values.data<float>();
+
+  char jobz = compute_eigenvectors_ ? 'V' : 'N';
+  auto N = a.shape(-1);
+
+  // Work query
+  int lwork;
+  int liwork;
+  {
+    float work;
+    int iwork;
+    ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
+    lwork = static_cast<int>(work);
+    liwork = iwork;
+  }
+
+  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
+  for (size_t i = 0; i < a.size() / (N * N); ++i) {
+    ssyevd(
+        jobz,
+        uplo_[0],
+        vec_ptr,
+        N,
+        eig_ptr,
+        static_cast<float*>(work_buf.buffer.raw_ptr()),
+        lwork,
+        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
+        liwork);
+    vec_ptr += N * N;
+    eig_ptr += N;
+  }
+}
+
+} // namespace mlx::core
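`Eigh::eval` above uses LAPACK's standard two-phase workspace protocol: a first call with `lwork = liwork = -1` performs no decomposition and instead reports the optimal buffer sizes through `work` and `iwork`; the real factorization runs once the buffers are allocated. A minimal self-contained sketch of the same pattern against the raw Fortran symbol (assuming a LAPACK build without hidden string-length arguments):

```cpp
#include <vector>

// Raw Fortran entry point for the divide-and-conquer symmetric eigensolver.
// The exact prototype depends on the LAPACK headers in use; this form assumes
// no hidden string-length arguments.
extern "C" void ssyevd_(
    const char* jobz, const char* uplo, const int* n, float* a, const int* lda,
    float* w, float* work, const int* lwork, int* iwork, const int* liwork,
    int* info);

// Eigenvalues (jobz = 'N') or eigenvalues and eigenvectors (jobz = 'V') of a
// single n x n symmetric matrix stored in a (column-major).
void syevd_example(char jobz, char uplo, float* a, int n, float* w) {
  int info;

  // Phase 1: workspace query. lwork = liwork = -1 asks ssyevd to write the
  // optimal sizes into work_query / iwork_query without factorizing.
  float work_query;
  int iwork_query;
  int lwork = -1, liwork = -1;
  ssyevd_(&jobz, &uplo, &n, a, &n, w, &work_query, &lwork, &iwork_query,
          &liwork, &info);

  // Phase 2: allocate what was reported and run the actual decomposition.
  lwork = static_cast<int>(work_query);
  liwork = iwork_query;
  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);
  ssyevd_(&jobz, &uplo, &n, a, &n, w, work.data(), &lwork, iwork.data(),
          &liwork, &info);
}
```

In `eigh.cpp` the query runs once and the two buffers are reused across the whole batch, which is safe because every matrix in the batch has the same `N`.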
diff --git a/mlx/backend/common/inverse.cpp b/mlx/backend/common/inverse.cpp
index 57d885c73..96dbfc001 100644
--- a/mlx/backend/common/inverse.cpp
+++ b/mlx/backend/common/inverse.cpp
@@ -2,39 +2,19 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 
-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <lapack.h>
-#endif
-
-// Wrapper to account for differences in
-// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
 int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
   int info;
-
-#ifdef LAPACK_FORTRAN_STRLEN_END
-  strtri_(
-      /* uplo = */ &uplo,
-      /* diag = */ &diag,
-      /* N = */ &N,
-      /* a = */ matrix,
-      /* lda = */ &N,
-      /* info = */ &info,
-      /* uplo_len = */ static_cast<size_t>(1),
-      /* diag_len = */ static_cast<size_t>(1));
-#else
-  strtri_(
+  MLX_LAPACK_FUNC(strtri)
+  (
       /* uplo = */ &uplo,
       /* diag = */ &diag,
       /* N = */ &N,
       /* a = */ matrix,
       /* lda = */ &N,
       /* info = */ &info);
-#endif
-
   return info;
 }
diff --git a/mlx/backend/common/lapack_helper.h b/mlx/backend/common/lapack.h
similarity index 90%
rename from mlx/backend/common/lapack_helper.h
rename to mlx/backend/common/lapack.h
index bf0f76437..b3bb7ebf0 100644
--- a/mlx/backend/common/lapack_helper.h
+++ b/mlx/backend/common/lapack.h
@@ -1,10 +1,11 @@
-// Copyright © 2024 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
 #ifdef ACCELERATE_NEW_LAPACK
 #include <Accelerate/Accelerate.h>
 #else
+#include <cblas.h>
 #include <lapack.h>
 #endif
 
diff --git a/mlx/backend/common/masked_mm.cpp b/mlx/backend/common/masked_mm.cpp
index 44a471168..d0286f0fd 100644
--- a/mlx/backend/common/masked_mm.cpp
+++ b/mlx/backend/common/masked_mm.cpp
@@ -1,15 +1,10 @@
 // Copyright © 2024 Apple Inc.
 
-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <cblas.h>
-#endif
-
 #include <cstring>
 
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
diff --git a/mlx/backend/common/qrf.cpp b/mlx/backend/common/qrf.cpp
index 4171398fd..9383f6c88 100644
--- a/mlx/backend/common/qrf.cpp
+++ b/mlx/backend/common/qrf.cpp
@@ -2,14 +2,9 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 
-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <lapack.h>
-#endif
-
 namespace mlx::core {
 
 template <typename T>
diff --git a/mlx/backend/common/svd.cpp b/mlx/backend/common/svd.cpp
index 412f06297..1a6f1b1ad 100644
--- a/mlx/backend/common/svd.cpp
+++ b/mlx/backend/common/svd.cpp
@@ -2,7 +2,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack_helper.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 31f2248d7..e5a7d885b 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -401,6 +401,12 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
       "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
 }
 
+void Eigh::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
+}
+
 void View::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
   auto ibytes = size_of(in.dtype());
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index fd15c403b..c87fcc8bb 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -48,6 +48,7 @@ NO_CPU(Divide)
 NO_CPU_MULTI(DivMod)
 NO_CPU(NumberOfElements)
 NO_CPU(Remainder)
+NO_CPU_MULTI(Eigh)
 NO_CPU(Equal)
 NO_CPU(Erf)
 NO_CPU(ErfInv)
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 5270a6fdd..aaee51d83 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -112,6 +112,7 @@ NO_GPU(Tanh)
 NO_GPU(Transpose)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
+NO_GPU_MULTI(Eigh)
 NO_GPU(View)
 
 namespace fast {
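Because `Eigh` produces multiple outputs, it registers through the `_MULTI` variants of the backend stub macros (`DEFAULT_MULTI`, `NO_CPU_MULTI`, `NO_GPU_MULTI`) rather than the single-output forms used by `Cholesky` or `Inverse`. The real macros live in the respective `primitives.cpp` files; roughly, the distinction is which `eval_gpu` overload gets stubbed out (the spellings below are illustrative, not copied from the source):

```cpp
// Single-output primitives get a stub for eval_gpu(inputs, out)...
#define NO_GPU(func)                                                  \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
    throw std::runtime_error(#func " has no GPU implementation.");    \
  }

// ...while multi-output primitives such as Eigh stub the
// eval_gpu(inputs, outputs) overload instead.
#define NO_GPU_MULTI(func)                                             \
  void func::eval_gpu(                                                 \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    throw std::runtime_error(#func " has no GPU implementation.");     \
  }
```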
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index a64f98aa8..daf5573fc 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -454,4 +454,50 @@ array cross(
   return concatenate(outputs, axis, s);
 }
 
+void validate_eigh(const array& a, const std::string fname) {
+  if (a.dtype() != float32) {
+    std::ostringstream msg;
+    msg << fname << " Arrays must have type float32. Received array "
+        << "with type " << a.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << fname << " Arrays must have >= 2 dimensions. Received array with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.shape(-1) != a.shape(-2)) {
+    throw std::invalid_argument(fname + " Only defined for square matrices.");
+  }
+}
+
+array eigvalsh(
+    const array& a,
+    std::string UPLO /* = "L" */,
+    StreamOrDevice s /* = {} */) {
+  validate_eigh(a, "[linalg::eigvalsh]");
+  std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);
+  return array(
+      std::move(out_shape),
+      a.dtype(),
+      std::make_shared<Eigh>(to_stream(s), UPLO, false),
+      {a});
+}
+
+std::pair<array, array> eigh(
+    const array& a,
+    std::string UPLO /* = "L" */,
+    StreamOrDevice s /* = {} */) {
+  validate_eigh(a, "[linalg::eigh]");
+  auto out = array::make_arrays(
+      {std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
+      {a.dtype(), a.dtype()},
+      std::make_shared<Eigh>(to_stream(s), UPLO, true),
+      {a});
+  return std::make_pair(out[0], out[1]);
+}
+
 } // namespace mlx::core::linalg
diff --git a/mlx/linalg.h b/mlx/linalg.h
index acfcc1a41..4ea81bef0 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -83,4 +83,9 @@ array cross(
     int axis = -1,
     StreamOrDevice s = {});
 
+array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
+
+std::pair<array, array>
+eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
+
 } // namespace mlx::core::linalg
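With the declarations above in place, both entry points follow the batched-matrix convention used throughout `mlx::core::linalg`: for an input of shape `(..., N, N)`, `eigvalsh` returns eigenvalues of shape `(..., N)` in ascending order, and `eigh` additionally returns eigenvectors of shape `(..., N, N)` stored column-wise. A minimal usage sketch (the CPU stream is required, since the Metal kernel is not implemented):

```cpp
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Symmetric 2 x 2 input; with UPLO = "L" only the lower triangle is read.
  array A = array({1.0f, -2.0f, -2.0f, 1.0f}, {2, 2});

  // Shape (2,): eigenvalues in ascending order.
  array w = linalg::eigvalsh(A, "L", Device::cpu);

  // Shapes (2,) and (2, 2); the eigenvectors are the columns of vecs.
  auto [vals, vecs] = linalg::eigh(A, "L", Device::cpu);

  eval({w, vals, vecs});
  return 0;
}
```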
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 8aa0392b7..c9f839d4b 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -767,6 +767,27 @@ std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
   return {{linalg::cholesky(a, upper_, stream())}, {ax}};
 }
 
+std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+
+  bool needs_move = axes[0] >= (inputs[0].ndim() - 2);
+  auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  auto ax = needs_move ? 0 : axes[0];
+
+  std::vector<array> outputs;
+  if (compute_eigenvectors_) {
+    auto [values, vectors] = linalg::eigh(a, uplo_, stream());
+    outputs = {values, vectors};
+  } else {
+    outputs = {linalg::eigvalsh(a, uplo_, stream())};
+  }
+
+  return {outputs, std::vector<int>(outputs.size(), ax)};
+}
+
 std::vector<array> Concatenate::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 4bec71445..f2b5bab7c 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -2196,4 +2196,44 @@ class Cholesky : public UnaryPrimitive {
   bool upper_;
 };
 
+class Eigh : public Primitive {
+ public:
+  explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
+      : Primitive(stream),
+        uplo_(std::move(uplo)),
+        compute_eigenvectors_(compute_eigenvectors) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_VMAP()
+  DEFINE_PRINT(Eigh)
+
+  std::vector<std::vector<int>> output_shapes(
+      const std::vector<array>& inputs) override {
+    auto shape = inputs[0].shape();
+    shape.pop_back(); // Remove last dimension for eigenvalues
+    if (compute_eigenvectors_) {
+      return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
+    } else {
+      return {shape}; // Only eigenvalues
+    }
+  }
+
+  bool is_equivalent(const Primitive& other) const override {
+    if (auto* p = dynamic_cast<const Eigh*>(&other)) {
+      return uplo_ == p->uplo_ &&
+          compute_eigenvectors_ == p->compute_eigenvectors_;
+    }
+    return false;
+  }
+
+ private:
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
+  std::string uplo_;
+  bool compute_eigenvectors_;
+};
+
 } // namespace mlx::core
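The axis bookkeeping in `Eigh::vmap` above follows from the fact that the decomposition always consumes the last two axes: a mapped axis can stay where it is only if it lies in front of them. If it falls inside the trailing matrix dimensions (`axes[0] >= ndim - 2`), the input is first rearranged with `moveaxis` so the mapped axis becomes the leading batch axis. A small sketch with hypothetical shapes:

```cpp
#include "mlx/mlx.h"

using namespace mlx::core;

// Mirrors the shape logic of Eigh::vmap for a hypothetical (4, 4, 3) input
// whose mapped axis indexes 3 separate 4 x 4 matrices.
void vmap_axis_example() {
  array x = zeros({4, 4, 3});
  int ax = 2;
  // Axis 2 overlaps the trailing matrix dimensions, so it must be moved.
  bool needs_move = ax >= (x.ndim() - 2);
  array a = needs_move ? moveaxis(x, ax, 0) : x; // shape (3, 4, 4)
  // Each leading-axis slice is now a matrix; the result has shape (3, 4).
  array w = linalg::eigvalsh(a, "L", Device::cpu);
  // The primitive then reports 0 as the output batch axis.
}
```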
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 13d61e980..e2c3aea23 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -405,4 +405,85 @@ void init_linalg(nb::module_& parent_module) {
       Returns:
         array: The cross product of ``a`` and ``b`` along the specified axis.
       )pbdoc");
+  m.def(
+      "eigvalsh",
+      &eigvalsh,
+      "a"_a,
+      "UPLO"_a = "L",
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      R"pbdoc(
+        Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
+
+        This function supports arrays with at least 2 dimensions. When the
+        input has more than two dimensions, the eigenvalues are computed for
+        each matrix in the last two dimensions.
+
+        Args:
+            a (array): Input array. Must be a real symmetric or complex
+              Hermitian matrix.
+            UPLO (str, optional): Whether to use the upper (``"U"``) or
+              lower (``"L"``) triangle of the matrix. Default: ``"L"``.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The eigenvalues in ascending order.
+
+        Note:
+            The input matrix is assumed to be symmetric (or Hermitian). Only
+            the selected triangle is used. No checks for symmetry are performed.
+
+        Example:
+            >>> A = mx.array([[1., -2.], [-2., 1.]])
+            >>> eigenvalues = mx.linalg.eigvalsh(A, stream=mx.cpu)
+            >>> eigenvalues
+            array([-1., 3.], dtype=float32)
+      )pbdoc");
+  m.def(
+      "eigh",
+      [](const array& a, const std::string UPLO, StreamOrDevice s) {
+        // TODO avoid cast?
+        auto result = eigh(a, UPLO, s);
+        return nb::make_tuple(result.first, result.second);
+      },
+      "a"_a,
+      "UPLO"_a = "L",
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      R"pbdoc(
+        Compute the eigenvalues and eigenvectors of a complex Hermitian or
+        real symmetric matrix.
+
+        This function supports arrays with at least 2 dimensions. When the input
+        has more than two dimensions, the eigenvalues and eigenvectors are
+        computed for each matrix in the last two dimensions.
+
+        Args:
+            a (array): Input array. Must be a real symmetric or complex
+              Hermitian matrix.
+            UPLO (str, optional): Whether to use the upper (``"U"``) or
+              lower (``"L"``) triangle of the matrix. Default: ``"L"``.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            Tuple[array, array]:
+                A tuple containing the eigenvalues in ascending order and
+                the normalized eigenvectors. The column ``v[:, i]`` is the
+                eigenvector corresponding to the i-th eigenvalue.
+
+        Note:
+            The input matrix is assumed to be symmetric (or Hermitian). Only
+            the selected triangle is used. No checks for symmetry are performed.
+
+        Example:
+            >>> A = mx.array([[1., -2.], [-2., 1.]])
+            >>> w, v = mx.linalg.eigh(A, stream=mx.cpu)
+            >>> w
+            array([-1., 3.], dtype=float32)
+            >>> v
+            array([[ 0.707107, -0.707107],
+                   [ 0.707107, 0.707107]], dtype=float32)
+      )pbdoc");
 }
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index 6051beef7..695d7704f 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -268,6 +268,57 @@ class TestLinalg(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             mx.linalg.cross(a, b)
 
+    def test_eigh(self):
+        tols = {"atol": 1e-5, "rtol": 1e-5}
+
+        def check_eigs_and_vecs(A_np, kwargs={}):
+            A = mx.array(A_np)
+            eig_vals, eig_vecs = mx.linalg.eigh(A, stream=mx.cpu, **kwargs)
+            eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)
+            self.assertTrue(np.allclose(eig_vals, eig_vals_np, **tols))
+            self.assertTrue(
+                mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols)
+            )
+
+            eig_vals_only = mx.linalg.eigvalsh(A, stream=mx.cpu, **kwargs)
+            self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols))
+
+        # Test a simple 2x2 symmetric matrix
+        A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float32)
+        check_eigs_and_vecs(A_np)
+
+        # Test a larger random symmetric matrix
+        n = 5
+        np.random.seed(1)
+        A_np = np.random.randn(n, n).astype(np.float32)
+        A_np = (A_np + A_np.T) / 2
+        check_eigs_and_vecs(A_np)
+
+        # Test with upper triangle
+        check_eigs_and_vecs(A_np, {"UPLO": "U"})
+
+        # Test with batched input
+        A_np = np.random.randn(3, n, n).astype(np.float32)
+        A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2
+        check_eigs_and_vecs(A_np)
+
+        # Test error cases
+        with self.assertRaises(ValueError):
+            mx.linalg.eigh(mx.array([1.0, 2.0]))  # 1D array
+
+        with self.assertRaises(ValueError):
+            mx.linalg.eigh(
+                mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+            )  # Non-square matrix
+
+        with self.assertRaises(ValueError):
+            mx.linalg.eigvalsh(mx.array([1.0, 2.0]))  # 1D array
+
+        with self.assertRaises(ValueError):
+            mx.linalg.eigvalsh(
+                mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+            )  # Non-square matrix
+
 
 if __name__ == "__main__":
     unittest.main()
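Note how the Python test avoids materializing `diag(w)`: in `eig_vals[..., None, :] * eig_vecs`, broadcasting the eigenvalues against the last axis scales the i-th column of the eigenvector matrix by the i-th eigenvalue, which is exactly `V @ diag(w)`. The C++ test that follows relies on the same broadcast in `eigvals * eigvecs`. Spelled out as a C++ helper (names here are illustrative):

```cpp
#include "mlx/mlx.h"

using namespace mlx::core;

// Check A @ V == V @ diag(w) without building diag(w). For vals of shape
// (..., N) and vecs of shape (..., N, N), expand_dims(vals, -2) has shape
// (..., 1, N), so the product scales column i of vecs by vals[..., i].
bool check_eig_identity(const array& A, const array& vals, const array& vecs) {
  return allclose(matmul(A, vecs), expand_dims(vals, -2) * vecs).item<bool>();
}
```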
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index e9e196583..f0b34cc01 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -435,3 +435,41 @@ TEST_CASE("test cross product") {
   result = cross(a, b);
   CHECK(allclose(result, expected).item<bool>());
 }
+
+TEST_CASE("test matrix eigh") {
+  // 0D and 1D throw
+  CHECK_THROWS(linalg::eigh(array(0.0)));
+  CHECK_THROWS(linalg::eigh(array({0.0, 1.0})));
+  CHECK_THROWS(linalg::eigvalsh(array(0.0)));
+  CHECK_THROWS(linalg::eigvalsh(array({0.0, 1.0})));
+
+  // Unsupported types throw
+  CHECK_THROWS(linalg::eigh(array({0, 1}, {1, 2})));
+
+  // Non-square throws
+  CHECK_THROWS(linalg::eigh(array({1, 2, 3, 4, 5, 6}, {2, 3})));
+
+  // Test a simple 2x2 symmetric matrix
+  array A = array({1.0, 2.0, 2.0, 4.0}, {2, 2}, float32);
+  auto [eigvals, eigvecs] = linalg::eigh(A, "L", Device::cpu);
+
+  // Expected eigenvalues
+  array expected_eigvals = array({0.0, 5.0});
+  CHECK(allclose(
+            eigvals,
+            expected_eigvals,
+            /* rtol = */ 1e-5,
+            /* atol = */ 1e-5)
+            .item<bool>());
+
+  // Verify orthogonality of eigenvectors
+  CHECK(allclose(
+            matmul(eigvecs, transpose(eigvecs)),
+            eye(2),
+            /* rtol = */ 1e-5,
+            /* atol = */ 1e-5)
+            .item<bool>());
+
+  // Verify eigendecomposition
+  CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
+}
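The same column-scaling broadcast also round-trips the decomposition, since `A = V diag(w) V^T` for a symmetric `A`. As an illustration only (not part of this diff), the test case above could assert the reconstruction as well:

```cpp
// Reconstruct A from its eigendecomposition: (eigvecs * eigvals) scales
// column i of V by w[i], so matmul with V^T recovers A for symmetric A.
array A_rec = matmul(eigvecs * eigvals, transpose(eigvecs));
CHECK(allclose(A, A_rec, /* rtol = */ 1e-5, /* atol = */ 1e-5).item<bool>());
```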