diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst
index b01f74117..495380c46 100644
--- a/docs/src/python/linalg.rst
+++ b/docs/src/python/linalg.rst
@@ -16,6 +16,8 @@ Linear Algebra
    cross
    qr
    svd
+   eigvals
+   eig
    eigvalsh
    eigh
    lu
diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt
index 96b3f1313..9d322c4c4 100644
--- a/mlx/backend/cpu/CMakeLists.txt
+++ b/mlx/backend/cpu/CMakeLists.txt
@@ -46,6 +46,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
diff --git a/mlx/backend/cpu/eig.cpp b/mlx/backend/cpu/eig.cpp
new file mode 100644
index 000000000..c89003fc0
--- /dev/null
+++ b/mlx/backend/cpu/eig.cpp
@@ -0,0 +1,174 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
+#include "mlx/backend/cpu/lapack.h"
+#include "mlx/linalg.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+namespace {
+
+template <typename T>
+void eig_impl(
+    array& a,
+    array& vectors,
+    array& values,
+    bool compute_eigenvectors,
+    Stream stream) {
+  using OT = std::complex<T>;
+  auto a_ptr = a.data<T>();
+  auto eig_ptr = values.data<OT>();
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_output_array(values);
+  OT* vec_ptr = nullptr;
+  if (compute_eigenvectors) {
+    encoder.set_output_array(vectors);
+    vec_ptr = vectors.data<OT>();
+  }
+  encoder.dispatch([a_ptr,
+                    vec_ptr,
+                    eig_ptr,
+                    compute_eigenvectors,
+                    N = vectors.shape(-1),
+                    size = vectors.size()]() mutable {
+    // Workspace query: with lwork == -1, geev only writes the optimal
+    // workspace size into `work`.
+    char jobr = 'N';
+    char jobl = compute_eigenvectors ? 'V' : 'N';
+    int n_vecs_r = 1;
+    int n_vecs_l = compute_eigenvectors ? N : 1;
+    int lwork = -1;
+    int info;
+    {
+      T work;
+      int iwork;
+      geev<T>(
+          &jobl,
+          &jobr,
+          &N,
+          nullptr,
+          &N,
+          nullptr,
+          nullptr,
+          nullptr,
+          &n_vecs_l,
+          nullptr,
+          &n_vecs_r,
+          &work,
+          &lwork,
+          &info);
+      lwork = static_cast<int>(work);
+    }
+
+    auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
+    auto vec_tmp_data =
+        array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
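+    // geev returns results in LAPACK's packed real format: eig_tmp holds the
+    // real parts of the eigenvalues (first N entries) followed by the
+    // imaginary parts (next N); vec_tmp holds the eigenvector matrix, with
+    // each complex-conjugate pair stored as two consecutive real vectors.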
+    auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
+    auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
+    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
+    for (size_t i = 0; i < size / (N * N); ++i) {
+      geev<T>(
+          &jobl,
+          &jobr,
+          &N,
+          a_ptr,
+          &N,
+          eig_tmp,
+          eig_tmp + N,
+          vec_tmp,
+          &n_vecs_l,
+          nullptr,
+          &n_vecs_r,
+          static_cast<T*>(work_buf.buffer.raw_ptr()),
+          &lwork,
+          &info);
+      for (int i = 0; i < N; ++i) {
+        eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
+      }
+      if (vec_ptr) {
+        for (int i = 0; i < N; ++i) {
+          if (eig_ptr[i].imag() != 0) {
+            // This vector and the next form a conjugate pair packed as two
+            // consecutive real vectors
+            for (int j = 0; j < N; ++j) {
+              vec_ptr[i * N + j] = {
+                  vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
+              vec_ptr[(i + 1) * N + j] = {
+                  vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
+            }
+            i += 1;
+          } else {
+            for (int j = 0; j < N; ++j) {
+              vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
+            }
+          }
+        }
+        vec_ptr += N * N;
+      }
+      a_ptr += N * N;
+      eig_ptr += N;
+      if (info != 0) {
+        std::stringstream msg;
+        msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
+            << info;
+        throw std::runtime_error(msg.str());
+      }
+    }
+  });
+  encoder.add_temporary(a);
+}
+
+} // namespace
+
+void Eig::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  const auto& a = inputs[0];
+  auto& values = outputs[0];
+
+  auto vectors = compute_eigenvectors_
+      ? outputs[1]
+      : array(a.shape(), complex64, nullptr, {});
+
+  auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
+  copy(
+      a,
+      a_copy,
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+      stream());
+
+  values.set_data(allocator::malloc(values.nbytes()));
+
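+  // geev lays out each eigenvector contiguously; swapping the two trailing
+  // strides below exposes eigenvector i as the column v[..., :, i] that the
+  // Python API documents, without an extra copy.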
+  if (compute_eigenvectors_) {
+    // Set the strides and flags so the eigenvectors
+    // are in the columns of the output
+    auto flags = vectors.flags();
+    auto strides = vectors.strides();
+    auto ndim = a.ndim();
+    std::swap(strides[ndim - 1], strides[ndim - 2]);
+
+    if (a.size() > 1) {
+      flags.row_contiguous = false;
+      if (ndim > 2) {
+        flags.col_contiguous = false;
+      } else {
+        flags.col_contiguous = true;
+      }
+    }
+    vectors.set_data(
+        allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
+  }
+  switch (a.dtype()) {
+    case float32:
+      eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
+      break;
+    default:
+      throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h
index 2911c63f8..411742d56 100644
--- a/mlx/backend/cpu/lapack.h
+++ b/mlx/backend/cpu/lapack.h
@@ -45,6 +45,7 @@
 INSTANTIATE_LAPACK_TYPES(geqrf)
 INSTANTIATE_LAPACK_TYPES(orgqr)
 INSTANTIATE_LAPACK_TYPES(syevd)
+INSTANTIATE_LAPACK_TYPES(geev)
 INSTANTIATE_LAPACK_TYPES(potrf)
 INSTANTIATE_LAPACK_TYPES(gesvdx)
 INSTANTIATE_LAPACK_TYPES(getrf)
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 860e9ddd7..6e42b29c9 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -378,10 +378,16 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
       "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
 }
 
+void Eig::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI.");
+}
+
 void Eigh::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
+  throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
 }
 
 void LUF::eval_gpu(
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index 84372b096..1a180bfe0 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -55,6 +55,7 @@ NO_CPU(DynamicSlice)
 NO_CPU(DynamicSliceUpdate)
 NO_CPU(NumberOfElements)
 NO_CPU(Remainder)
+NO_CPU_MULTI(Eig)
 NO_CPU_MULTI(Eigh)
 NO_CPU(Equal)
 NO_CPU(Erf)
diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp
index 6826c97f6..676a6e550 100644
--- a/mlx/backend/no_gpu/primitives.cpp
+++ b/mlx/backend/no_gpu/primitives.cpp
@@ -126,6 +126,7 @@ NO_GPU(Unflatten)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
 NO_GPU_MULTI(Eigh)
+NO_GPU_MULTI(Eig)
 NO_GPU(View)
 
 namespace fast {
diff --git a/mlx/export.cpp b/mlx/export.cpp
index c9139e156..bd2f24ba2 100644
--- a/mlx/export.cpp
+++ b/mlx/export.cpp
@@ -331,6 +331,7 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(SVD),
       SERIALIZE_PRIMITIVE(Inverse),
       SERIALIZE_PRIMITIVE(Cholesky),
+      SERIALIZE_PRIMITIVE(Eig),
       SERIALIZE_PRIMITIVE(Eigh),
       SERIALIZE_PRIMITIVE(AffineQuantize),
       SERIALIZE_PRIMITIVE(RMSNorm),
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 53f13486a..e0f4ec2e6 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -488,7 +488,7 @@ array cross(
   return concatenate(outputs, axis, s);
 }
 
-void validate_eigh(
+void validate_eig(
     const array& a,
     const StreamOrDevice& stream,
     const std::string fname) {
@@ -511,7 +511,7 @@ array eigvalsh(
     const array& a,
     std::string UPLO /* = "L" */,
     StreamOrDevice s /* = {} */) {
-  validate_eigh(a, s, "[linalg::eigvalsh]");
+  validate_eig(a, s, "[linalg::eigvalsh]");
   Shape out_shape(a.shape().begin(), a.shape().end() - 1);
   return array(
       std::move(out_shape),
@@ -524,7 +524,7 @@ std::pair<array, array> eigh(
     const array& a,
     std::string UPLO /* = "L" */,
     StreamOrDevice s /* = {} */) {
-  validate_eigh(a, s, "[linalg::eigh]");
+  validate_eig(a, s, "[linalg::eigh]");
   auto out = array::make_arrays(
       {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
       {a.dtype(), a.dtype()},
@@ -533,6 +533,26 @@ std::pair<array, array> eigh(
   return std::make_pair(out[0], out[1]);
 }
 
+array eigvals(const array& a, StreamOrDevice s /* = {} */) {
+  validate_eig(a, s, "[linalg::eigvals]");
+  Shape out_shape(a.shape().begin(), a.shape().end() - 1);
+  return array(
+      std::move(out_shape),
+      complex64,
+      std::make_shared<Eig>(to_stream(s), false),
+      {a});
+}
+
+std::pair<array, array> eig(const array& a, StreamOrDevice s /* = {} */) {
+  validate_eig(a, s, "[linalg::eig]");
+  auto out = array::make_arrays(
+      {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
+      {complex64, complex64},
+      std::make_shared<Eig>(to_stream(s), true),
+      {a});
+  return std::make_pair(out[0], out[1]);
+}
+
 void validate_lu(
     const array& a,
     const StreamOrDevice& stream,
diff --git a/mlx/linalg.h b/mlx/linalg.h
index 8c3a2070a..0690fba95 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -99,6 +99,10 @@ array cross(
     int axis = -1,
     StreamOrDevice s = {});
 
+std::pair<array, array> eig(const array& a, StreamOrDevice s = {});
+
+array eigvals(const array& a, StreamOrDevice s = {});
+
 array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
 
 std::pair<array, array>
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index e1924e66c..87b2bc924 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -875,6 +875,43 @@ std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
   return {{linalg::cholesky(a, upper_, stream())}, {ax}};
 }
 
+std::pair<std::vector<array>, std::vector<int>> Eig::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+
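+  // Eig operates on the trailing two (matrix) axes, so a vmapped batch axis
+  // that falls inside that block has to be moved to the front first.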
+  bool needs_move = axes[0] >= (inputs[0].ndim() - 2);
+  auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  auto ax = needs_move ? 0 : axes[0];
+
+  std::vector<array> outputs;
+  if (compute_eigenvectors_) {
+    auto [values, vectors] = linalg::eig(a, stream());
+    outputs = {values, vectors};
+  } else {
+    outputs = {linalg::eigvals(a, stream())};
+  }
+
+  return {outputs, std::vector<int>(outputs.size(), ax)};
+}
+
+std::vector<Shape> Eig::output_shapes(const std::vector<array>& inputs) {
+  auto shape = inputs[0].shape();
+  shape.pop_back(); // Remove last dimension for eigenvalues
+  if (compute_eigenvectors_) {
+    return {
+        std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors
+  } else {
+    return {std::move(shape)}; // Only eigenvalues
+  }
+}
+
+bool Eig::is_equivalent(const Primitive& other) const {
+  auto& e_other = static_cast<const Eig&>(other);
+  return compute_eigenvectors_ == e_other.compute_eigenvectors_;
+}
+
 std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 2caed8477..c0fbfc84d 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -2381,6 +2381,29 @@ class Cholesky : public UnaryPrimitive {
   bool upper_;
 };
 
+class Eig : public Primitive {
+ public:
+  explicit Eig(Stream stream, bool compute_eigenvectors)
+      : Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_VMAP()
+  DEFINE_PRINT(Eig)
+
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+
+  bool is_equivalent(const Primitive& other) const override;
+  auto state() const {
+    return compute_eigenvectors_;
+  }
+
+ private:
+  bool compute_eigenvectors_;
+};
+
 class Eigh : public Primitive {
  public:
   explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 3bc0e5b1b..cc8e79db6 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -236,7 +236,7 @@ void init_linalg(nb::module_& parent_module) {
       Returns:
           Union[tuple(array, ...), array]:
-            If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that 
+            If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that
             ``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``.
       )pbdoc");
   m.def(
@@ -407,6 +407,76 @@ void init_linalg(nb::module_& parent_module) {
       Returns:
           array: The cross product of ``a`` and ``b`` along the specified axis.
       )pbdoc");
+  m.def(
+      "eigvals",
+      &mx::linalg::eigvals,
+      "a"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      R"pbdoc(
+        Compute the eigenvalues of a square matrix.
+
+        This function differs from :func:`numpy.linalg.eigvals` in that the
+        return type is always complex even if the eigenvalues are all real.
+
+        This function supports arrays with at least 2 dimensions. When the
+        input has more than two dimensions, the eigenvalues are computed for
+        each matrix in the last two dimensions.
+
+        Args:
+            a (array): The input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``,
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The eigenvalues (not necessarily in order).
+
+        Example:
+            >>> A = mx.array([[1., -2.], [-2., 1.]])
+            >>> eigenvalues = mx.linalg.eigvals(A, stream=mx.cpu)
+            >>> eigenvalues
+            array([3+0j, -1+0j], dtype=complex64)
+      )pbdoc");
+  m.def(
+      "eig",
+      [](const mx::array& a, mx::StreamOrDevice s) {
+        auto result = mx::linalg::eig(a, s);
+        return nb::make_tuple(result.first, result.second);
+      },
+      "a"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      R"pbdoc(
+        Compute the eigenvalues and eigenvectors of a square matrix.
+
+        This function differs from :func:`numpy.linalg.eig` in that the
+        return type is always complex even if the eigenvalues are all real.
+
+        This function supports arrays with at least 2 dimensions. When the
+        input has more than two dimensions, the eigenvalues and eigenvectors
+        are computed for each matrix in the last two dimensions.
+
+        Args:
+            a (array): The input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``,
+              in which case the default stream of the default device is used.
+
+        Returns:
+            Tuple[array, array]:
+              A tuple containing the eigenvalues and the normalized right
+              eigenvectors. The column ``v[:, i]`` is the eigenvector
+              corresponding to the ``i``-th eigenvalue.
+
+        Example:
+            >>> A = mx.array([[1., -2.], [-2., 1.]])
+            >>> w, v = mx.linalg.eig(A, stream=mx.cpu)
+            >>> w
+            array([3+0j, -1+0j], dtype=complex64)
+            >>> v
+            array([[0.707107+0j, 0.707107+0j],
+                   [-0.707107+0j, 0.707107+0j]], dtype=complex64)
+      )pbdoc");
   m.def(
       "eigvalsh",
       &mx::linalg::eigvalsh,
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index a9fe572af..f65da1ff7 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -312,6 +312,83 @@ class TestLinalg(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             mx.linalg.cross(a, b)
 
+    def test_eig(self):
+        tols = {"atol": 1e-5, "rtol": 1e-5}
+
+        def check_eigs_and_vecs(A_np, kwargs={}):
+            A = mx.array(A_np)
+            eig_vals, eig_vecs = mx.linalg.eig(A, stream=mx.cpu, **kwargs)
+            self.assertTrue(
+                mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols)
+            )
+            eig_vals_only = mx.linalg.eigvals(A, stream=mx.cpu, **kwargs)
+            self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols))
+
+        # Test a simple 2x2 matrix
+        A_np = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float32)
+        check_eigs_and_vecs(A_np)
+
+        # Test complex eigenvalues
+        A_np = np.array([[1.0, -1.0], [1.0, 1.0]], dtype=np.float32)
+        check_eigs_and_vecs(A_np)
+
+        # Test a larger random matrix
+        n = 5
+        np.random.seed(1)
+        A_np = np.random.randn(n, n).astype(np.float32)
+        check_eigs_and_vecs(A_np)
+
+        # Test with batched input
+        A_np = np.random.randn(3, n, n).astype(np.float32)
+        check_eigs_and_vecs(A_np)
+
+        # Test error cases
+        with self.assertRaises(ValueError):
+            mx.linalg.eig(mx.array([1.0, 2.0]))  # 1D array
+
+        with self.assertRaises(ValueError):
+            mx.linalg.eig(
+                mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+            )  # Non-square matrix
+
+        with self.assertRaises(ValueError):
+            mx.linalg.eigvals(mx.array([1.0, 2.0]))  # 1D array
+
+        with self.assertRaises(ValueError):
+            mx.linalg.eigvals(
+                mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+            )  # Non-square matrix
+
+    def test_lu(self):
+        with self.assertRaises(ValueError):
+            mx.linalg.lu(mx.array(0.0), stream=mx.cpu)
+
+        with self.assertRaises(ValueError):
+            mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu)
+
+        with self.assertRaises(ValueError):
+            mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu)
+
+        # Test 3x3 matrix
+        a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
+        P, L, U = mx.linalg.lu(a, stream=mx.cpu)
+        self.assertTrue(mx.allclose(L[P, :] @ U, a))
+
+        # Test batch dimension
+        a = mx.broadcast_to(a, (5, 5, 3, 3))
+        P, L, U = mx.linalg.lu(a, stream=mx.cpu)
+        L = mx.take_along_axis(L, P[..., None], axis=-2)
+        self.assertTrue(mx.allclose(L @ U, a))
+
+        # Test non-square matrix
+        a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0]])
+        P, L, U = mx.linalg.lu(a, stream=mx.cpu)
+        self.assertTrue(mx.allclose(L[P, :] @ U, a))
+
+        a = mx.array([[3.0, 1.0], [1.0, 8.0], [9.0, 2.0]])
+        P, L, U = mx.linalg.lu(a, stream=mx.cpu)
+        self.assertTrue(mx.allclose(L[P, :] @ U, a))
+
     def test_eigh(self):
         tols = {"atol": 1e-5, "rtol": 1e-5}
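
For reference, a minimal usage sketch of the API this patch adds (not part of
the patch itself; it mirrors the residual check in test_eig and assumes a
build that includes this change, where eig and eigvals run on the CPU stream
only):

    import mlx.core as mx

    A = mx.array([[1.0, -1.0], [1.0, 1.0]])  # eigenvalues are 1 ± 1j
    w, v = mx.linalg.eig(A, stream=mx.cpu)

    # Each column v[:, i] satisfies A @ v[:, i] == w[i] * v[:, i].
    assert mx.allclose(A @ v, w[None, :] * v, atol=1e-5, rtol=1e-5)

    # eigvals returns the same eigenvalues without computing the vectors.
    assert mx.allclose(mx.linalg.eigvals(A, stream=mx.cpu), w, atol=1e-5, rtol=1e-5)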