From b3ec792380216d984ccb13ac4cf9439a42c593e5 Mon Sep 17 00:00:00 2001
From: Luca Arnaboldi
Date: Fri, 17 May 2024 21:31:59 +0200
Subject: [PATCH] Implemented Cholesky on CPU (#1119)

---
 mlx/backend/accelerate/primitives.cpp     |   1 +
 mlx/backend/common/CMakeLists.txt         |   1 +
 mlx/backend/common/cholesky.cpp           | 109 ++++++++++++++++++++++
 mlx/backend/common/default_primitives.cpp |   1 +
 mlx/backend/metal/primitives.cpp          |   5 +
 mlx/backend/no_metal/primitives.cpp       |   1 +
 mlx/linalg.cpp                            |  31 ++++++
 mlx/linalg.h                              |   2 +
 mlx/primitives.h                          |  16 ++++
 python/src/linalg.cpp                     |  29 ++++++
 python/tests/test_linalg.py               |  17 ++++
 tests/linalg_tests.cpp                    |  26 ++++++
 12 files changed, 239 insertions(+)
 create mode 100644 mlx/backend/common/cholesky.cpp

diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 5a2500c64..1187015cf 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -80,6 +80,7 @@ DEFAULT(StopGradient)
 DEFAULT_MULTI(SVD)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
+DEFAULT(Cholesky)
 
 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 3e9f87dfa..6f265b1ad 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -56,6 +56,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
 )
 
diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp
new file mode 100644
index 000000000..2af5d8ddf
--- /dev/null
+++ b/mlx/backend/common/cholesky.cpp
@@ -0,0 +1,109 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/linalg.h"
+#include "mlx/primitives.h"
+
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <lapack.h>
+#endif
+
+namespace mlx::core {
+
+namespace {
+
+// Delegate to the Cholesky factorization taking into account differences in
+// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
+int spotrf_wrapper(char uplo, float* matrix, int N) {
+  int info;
+
+#ifdef LAPACK_FORTRAN_STRLEN_END
+  spotrf_(
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ matrix,
+      /* lda = */ &N,
+      /* info = */ &info,
+      /* uplo_len = */ static_cast<size_t>(1));
+#else
+  spotrf_(
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ matrix,
+      /* lda = */ &N,
+      /* info = */ &info);
+#endif
+
+  return info;
+}
+
+} // namespace
+
+// Factorize each N x N float32 matrix stacked in `a` into `factor`, keeping
+// the triangle selected by `upper`. Throws if LAPACK reports a bad argument.
+void cholesky_impl(const array& a, array& factor, bool upper) {
+  // Lapack uses the column-major convention. We take advantage of the fact that
+  // the matrix should be symmetric:
+  //   (A)ᵀ = A
+  // and that a column-major lower triangular matrix is a row-major upper
+  // triangular matrix, so uplo is the opposite of what we would expect from
+  // upper.
+  const char uplo = upper ? 'L' : 'U';
+
+  // The decomposition is computed in place, so copy the input to the output.
+  copy(
+      a,
+      factor,
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+
+  // Guard N == 0 (division by zero) and avoid int overflow in N * N.
+  const int N = a.shape(-1);
+  const size_t num_matrices = N ? a.size() / (static_cast<size_t>(N) * N) : 0;
+  float* matrix = factor.data<float>();
+
+  for (size_t i = 0; i < num_matrices; i++) {
+    // Compute Cholesky factorization.
+    int info = spotrf_wrapper(uplo, matrix, N);
+
+    // TODO: info > 0 (not positive definite) is ignored because throwing would
+    // crash. Throw once errors from the implementation can be caught.
+    if (info < 0) {
+      std::stringstream msg;
+      msg << "[cholesky] Cholesky decomposition failed with error code "
+          << info;
+      throw std::runtime_error(msg.str());
+    }
+
+    // Zero out the other triangle while advancing to the next matrix.
+    for (int row = 0; row < N; row++) {
+      if (upper) {
+        std::fill(matrix, matrix + row, 0);
+      } else {
+        std::fill(matrix + row + 1, matrix + N, 0);
+      }
+      matrix += N;
+    }
+  }
+}
+
+// Check that the input is float32, then factorize it into `output`.
+void Cholesky::eval(const std::vector<array>& inputs, array& output) {
+  if (inputs[0].dtype() != float32) {
+    throw std::runtime_error("[Cholesky::eval] only supports float32.");
+  }
+  cholesky_impl(inputs[0], output, upper_);
+}
+
+// Move a batch axis > 0 to the front, then factorize the trailing matrices.
+std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0] >= 0 ? 0 : -1;
+  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  return {{linalg::cholesky(a, upper_, stream())}, {ax}};
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 4ebb27af2..0b502615c 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -113,6 +113,7 @@ DEFAULT(Tan)
 DEFAULT(Tanh)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
+DEFAULT(Cholesky)
 
 namespace {
 
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index d989b2197..a2c3df651 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -1012,4 +1012,9 @@ void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
   throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
 }
 
+void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
+  throw std::runtime_error(
+      "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 9934336cb..43ee0efad 100644
--- 
a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -107,6 +107,7 @@ NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)
 NO_GPU(Inverse)
+NO_GPU(Cholesky)
 
 namespace fast {
 NO_GPU_MULTI(LayerNorm)
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index d772c0e14..845d1981f 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -261,4 +261,35 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
       a.shape(), a.dtype(), std::make_shared<Inverse>(to_stream(s)), {a});
 }
 
+// Validate `a` and build its Cholesky factorization: the lower factor by
+// default, the upper one when `upper` is true. Throws std::invalid_argument
+// unless `a` is float32 with at least two dimensions and square matrices.
+array cholesky(
+    const array& a,
+    bool upper /* = false */,
+    StreamOrDevice s /* = {} */) {
+  if (a.dtype() != float32) {
+    std::ostringstream msg;
+    msg << "[linalg::cholesky] Arrays must type float32. Received array "
+        << "with type " << a.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received "
+        << "array with " << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.shape(-1) != a.shape(-2)) {
+    throw std::invalid_argument(
+        "[linalg::cholesky] Cholesky decomposition is only defined for square "
+        "matrices.");
+  }
+
+  auto primitive = std::make_shared<Cholesky>(to_stream(s), upper);
+  return array(a.shape(), a.dtype(), primitive, {a});
+}
+
 } // namespace mlx::core::linalg
diff --git a/mlx/linalg.h b/mlx/linalg.h
index aa46a7959..16a2bf25b 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -66,4 +66,6 @@ std::vector<array> svd(const array& a, StreamOrDevice s = {});
 
 array inv(const array& a, StreamOrDevice s = {});
 
+array cholesky(const array& a, bool upper = false, StreamOrDevice s = {});
+
 } // namespace mlx::core::linalg
diff --git a/mlx/primitives.h b/mlx/primitives.h
index dff21a072..7569eb834 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -2093,4 +2093,20 @@ class Inverse : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& output);
 };
 
+// Cholesky factorization primitive; `upper` picks the upper triangular factor.
+class Cholesky : public UnaryPrimitive {
+ public:
+  explicit Cholesky(Stream stream, bool upper)
+      : UnaryPrimitive(stream), upper_(upper) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+  DEFINE_VMAP()
+  DEFINE_PRINT(Cholesky)
+
+ private:
+  void eval(const std::vector<array>& inputs, array& output);
+  bool upper_;
+};
+
 } // namespace mlx::core
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index a6a86e414..eed8fe53f 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -260,4 +260,33 @@ void init_linalg(nb::module_& parent_module) {
         Returns:
             array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``
       )pbdoc");
+  m.def(
+      "cholesky",
+      &cholesky,
+      "a"_a,
+      "upper"_a = false,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def cholesky(a: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.
+
+        This function supports arrays with at least 2 dimensions. When the input
+        has more than two dimensions, the Cholesky decomposition is computed for each matrix
+        in the last two dimensions of ``a``.
+
+        If the input matrix is not symmetric positive semi-definite, behaviour is undefined.
+
+        Args:
+            a (array): Input array.
+            upper (bool, optional): If ``True``, return the upper triangular Cholesky factor.
+              If ``False``, return the lower triangular Cholesky factor. Default: ``False``.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: if ``upper = False``, it returns a lower triangular ``L`` matrix such that ``dot(L, L.T) = a``.
+              If ``upper = True``, it returns an upper triangular ``U`` matrix such that ``dot(U.T, U) = a``.
+      )pbdoc");
 }
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index a8dec0322..944df89b8 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -150,6 +150,23 @@ class TestLinalg(mlx_tests.MLXTestCase):
                 mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)
             )
 
+    def test_cholesky(self):
+        sqrtA = mx.array(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float32
+        )
+        A = sqrtA.T @ sqrtA / 81
+        L = mx.linalg.cholesky(A, stream=mx.cpu)
+        U = mx.linalg.cholesky(A, upper=True, stream=mx.cpu)
+        self.assertTrue(mx.allclose(L @ L.T, A, rtol=1e-5, atol=1e-7))
+        self.assertTrue(mx.allclose(U.T @ U, A, rtol=1e-5, atol=1e-7))
+
+        # Multiple matrices
+        B = A + 1 / 9
+        AB = mx.stack([A, B])
+        Ls = mx.linalg.cholesky(AB, stream=mx.cpu)
+        for M, L in zip(AB, Ls):
+            self.assertTrue(mx.allclose(L @ L.T, M, rtol=1e-5, atol=1e-7))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index 45ccb6134..2af868965 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -322,3 +322,29 @@ TEST_CASE("test matrix inversion") {
   CHECK(allclose(matmul(A_inv, A), identity, /* rtol = */ 0, /* atol = */ 1e-6)
             .item<bool>());
 }
+
+TEST_CASE("test matrix cholesky") {
+  // 0D and 1D throw
+  CHECK_THROWS(linalg::cholesky(array(0.0), /* upper = */ false, Device::cpu));
+  CHECK_THROWS(
+      linalg::cholesky(array({0.0, 1.0}), /* upper = */ false, Device::cpu));
+
+  // Unsupported types throw
+  CHECK_THROWS(linalg::cholesky(
+      array({0, 1}, {1, 2}), /* upper = */ false, Device::cpu));
+
+  // Non-square throws (float32 input so the dtype check is not what fires).
+  CHECK_THROWS(
+      linalg::cholesky(ones({2, 3}), /* upper = */ false, Device::cpu));
+
+  const auto prng_key = random::key(220398);
+  const auto sqrtA = random::normal({5, 5}, prng_key);
+  const auto A = matmul(sqrtA, transpose(sqrtA));
+  const auto L = linalg::cholesky(A, /* upper = */ false, Device::cpu);
+  const auto U = linalg::cholesky(A, /* upper = */ true, Device::cpu);
+
+  CHECK(allclose(matmul(L, transpose(L)), A, /* rtol = */ 0, /* atol = */ 1e-6)
+            .item<bool>());
+  CHECK(allclose(matmul(transpose(U), U), A, /* rtol = */ 0, /* atol = */ 1e-6)
+            .item<bool>());
+}