diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 25545a4c1..4ee0c0e89 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -74,6 +74,7 @@ DEFAULT(Sort)
 DEFAULT(StopGradient)
 DEFAULT_MULTI(SVD)
 DEFAULT(Transpose)
+DEFAULT(Inverse)
 
 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 90fc25d84..71d35f370 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -54,6 +54,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
 )
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index b63414408..4f43f1965 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -105,6 +105,7 @@ DEFAULT_MULTI(SVD)
 DEFAULT(Tan)
 DEFAULT(Tanh)
 DEFAULT(Transpose)
+DEFAULT(Inverse)
 
 namespace {
diff --git a/mlx/backend/common/inverse.cpp b/mlx/backend/common/inverse.cpp
new file mode 100644
index 000000000..2dfc78d21
--- /dev/null
+++ b/mlx/backend/common/inverse.cpp
@@ -0,0 +1,95 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/primitives.h"
+
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <lapack.h>
+#endif
+
+namespace mlx::core {
+
+void inverse_impl(const array& a, array& inv) {
+  // Lapack uses the column-major convention. We take advantage of the following
+  // identity to avoid transposing (see
+  // https://math.stackexchange.com/a/340234):
+  //   (A⁻¹)ᵀ = (Aᵀ)⁻¹
+
+  // The inverse is computed in place, so just copy the input to the output.
+  copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+
+  const int N = a.shape(-1);
+  const size_t num_matrices = a.size() / (N * N);
+
+  int info;
+  auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
+
+  for (int i = 0; i < num_matrices; i++) {
+    // Compute LU factorization.
+    sgetrf_(
+        /* m = */ &N,
+        /* n = */ &N,
+        /* a = */ inv.data<float>() + N * N * i,
+        /* lda = */ &N,
+        /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "inverse_impl: LU factorization failed with error code " << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    static const int lwork_query = -1;
+    float workspace_size = 0;
+
+    // Compute workspace size.
+    sgetri_(
+        /* n = */ &N,
+        /* a = */ nullptr,
+        /* lda = */ &N,
+        /* ipiv = */ nullptr,
+        /* work = */ &workspace_size,
+        /* lwork = */ &lwork_query,
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "inverse_impl: LU workspace calculation failed with error code "
+         << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    const int lwork = workspace_size;
+    auto scratch =
+        array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+
+    // Compute inverse.
+    sgetri_(
+        /* n = */ &N,
+        /* a = */ inv.data<float>() + N * N * i,
+        /* lda = */ &N,
+        /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
+        /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
+        /* lwork = */ &lwork,
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "inverse_impl: inversion failed with error code " << info;
+      throw std::runtime_error(ss.str());
+    }
+  }
+}
+
+void Inverse::eval(const std::vector<array>& inputs, array& output) {
+  if (inputs[0].dtype() != float32) {
+    throw std::runtime_error("[Inverse::eval] only supports float32.");
+  }
+  inverse_impl(inputs[0], output);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/svd.cpp b/mlx/backend/common/svd.cpp
index 38fc67a8c..412f06297 100644
--- a/mlx/backend/common/svd.cpp
+++ b/mlx/backend/common/svd.cpp
@@ -49,8 +49,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   // Will contain the indices of eigenvectors that failed to converge (not used
   // here but required by lapack).
-  std::vector<int> iwork;
-  iwork.resize(12 * K);
+  auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
 
   static const int lwork_query = -1;
@@ -82,7 +81,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
       /* ldvt = */ &ldvt,
       /* work = */ &workspace_dimension,
       /* lwork = */ &lwork_query,
-      /* iwork = */ iwork.data(),
+      /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
       /* info = */ &info);
 
   if (info != 0) {
@@ -120,7 +119,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
       /* ldvt = */ &ldvt,
       /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
       /* lwork = */ &lwork,
-      /* iwork = */ iwork.data(),
+      /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
       /* info = */ &info);
 
   if (info != 0) {
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index f0c1d7f3f..686c35d7f 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -900,4 +900,8 @@ void SVD::eval_gpu(
   throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
 }
 
+void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
+  throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 5e4dbfe0a..cbb971384 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -98,6 +98,7 @@ NO_GPU_MULTI(SVD)
 NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)
+NO_GPU(Inverse)
 
 namespace fast {
 NO_GPU_MULTI(RoPE)
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index dca186143..5d609e7f1 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -238,4 +238,27 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
       {a});
 }
 
+array inv(const array& a, StreamOrDevice s /* = {} */) {
+  if (a.dtype() != float32) {
+    std::ostringstream msg;
+    msg << "[linalg::inv] Arrays must have type float32. Received array "
+        << "with type " << a.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
+           "with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  if (a.shape(-1) != a.shape(-2)) {
+    throw std::invalid_argument(
+        "[linalg::inv] Inverses are only defined for square matrices.");
+  }
+
+  return array(
+      a.shape(), a.dtype(), std::make_unique<Inverse>(to_stream(s)), {a});
+}
+
 } // namespace mlx::core::linalg
diff --git a/mlx/linalg.h b/mlx/linalg.h
index 521fdf4b5..aa46a7959 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -64,4 +64,6 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
 
 std::vector<array> svd(const array& a, StreamOrDevice s = {});
 
+array inv(const array& a, StreamOrDevice s = {});
+
 } // namespace mlx::core::linalg
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 394d8ecb7..ebb11b04d 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1897,4 +1897,18 @@ class SVD : public Primitive {
   void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
+/* Matrix inversion primitive. */
+class Inverse : public UnaryPrimitive {
+ public:
+  explicit Inverse(Stream stream) : UnaryPrimitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& output) override;
+  void eval_gpu(const std::vector<array>& inputs, array& output) override;
+
+  DEFINE_PRINT(Inverse)
+
+ private:
+  void eval(const std::vector<array>& inputs, array& output);
+};
+
 } // namespace mlx::core
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index b0ca1cadf..92b80b9eb 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -241,4 +241,27 @@ void init_linalg(py::module_& parent_module) {
         tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices,
           such that ``A = U @ diag(S) @ Vt``
       )pbdoc");
+  m.def(
+      "inv",
+      &inv,
+      "a"_a,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Compute the inverse of a square matrix.
+
+        This function supports arrays with at least 2 dimensions. When the input
+        has more than two dimensions, the inverse is computed for each matrix
+        in the last two dimensions of ``a``.
+
+        Args:
+            a (array): Input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``
+      )pbdoc");
 }
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index c5f31505c..a8dec0322 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -136,6 +136,20 @@ class TestLinalg(mlx_tests.MLXTestCase):
             mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
         )
 
+    def test_inverse(self):
+        A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
+        A_inv = mx.linalg.inv(A, stream=mx.cpu)
+        self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0]), rtol=0, atol=1e-6))
+
+        # Multiple matrices
+        B = A - 100
+        AB = mx.stack([A, B])
+        invs = mx.linalg.inv(AB, stream=mx.cpu)
+        for M, M_inv in zip(AB, invs):
+            self.assertTrue(
+                mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index 5ab11ff8a..45ccb6134 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -300,3 +300,25 @@ TEST_CASE("test SVD factorization") {
   CHECK_EQ(S.dtype(), float32);
   CHECK_EQ(Vt.dtype(), float32);
 }
+
+TEST_CASE("test matrix inversion") {
+  // 0D and 1D throw
+  CHECK_THROWS(linalg::inv(array(0.0), Device::cpu));
+  CHECK_THROWS(linalg::inv(array({0.0, 1.0}), Device::cpu));
+
+  // Unsupported types throw
+  CHECK_THROWS(linalg::inv(array({0, 1}, {1, 2}), Device::cpu));
+
+  // Non-square throws
+  CHECK_THROWS(linalg::inv(array({1, 2, 3, 4, 5, 6}, {2, 3}), Device::cpu));
+
+  const auto prng_key = random::key(42);
+  const auto A = random::normal({5, 5}, prng_key);
+  const auto A_inv = linalg::inv(A, Device::cpu);
+  const auto identity = eye(A.shape(0));
+
+  CHECK(allclose(matmul(A, A_inv), identity, /* rtol = */ 0, /* atol = */ 1e-6)
+            .item<bool>());
+  CHECK(allclose(matmul(A_inv, A), identity, /* rtol = */ 0, /* atol = */ 1e-6)
+            .item<bool>());
+}
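
A quick usage sketch for the new Python binding (illustrative only, assuming a build that includes this patch; the stream is pinned to the CPU because the Metal kernel is NYI):

    import mlx.core as mx

    # inv() is applied over the last two axes, so batched inputs such as the
    # (2, 3, 3) stack exercised in the tests work the same way.
    A = mx.array([[4.0, 7.0], [2.0, 6.0]])
    A_inv = mx.linalg.inv(A, stream=mx.cpu)
    print(mx.allclose(A @ A_inv, mx.eye(2), atol=1e-6))  # True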