diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index 3c34cb3f7..e7fd5ecee 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -9,7 +9,9 @@ Linear Algebra :toctree: _autosummary inv + tri_inv norm cholesky + cholesky_inv qr svd diff --git a/mlx/backend/common/inverse.cpp b/mlx/backend/common/inverse.cpp index 2dfc78d21..57d885c73 100644 --- a/mlx/backend/common/inverse.cpp +++ b/mlx/backend/common/inverse.cpp @@ -10,9 +10,106 @@ #include #endif +// Wrapper to account for differences in +// LAPACK implementations (basically how to pass the 'uplo' string to fortran). Returns the `info` status written by strtri_ (callers treat nonzero as failure). +int strtri_wrapper(char uplo, char diag, float* matrix, int N) { + int info; + +#ifdef LAPACK_FORTRAN_STRLEN_END + strtri_( + /* uplo = */ &uplo, + /* diag = */ &diag, + /* N = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* info = */ &info, + /* uplo_len = */ static_cast(1), + /* diag_len = */ static_cast(1)); +#else + strtri_( + /* uplo = */ &uplo, + /* diag = */ &diag, + /* N = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* info = */ &info); +#endif + + return info; +} + namespace mlx::core { -void inverse_impl(const array& a, array& inv) { +void general_inv(array& inv, int N, int i) { /* Inverts the i-th N x N matrix stored in inv via sgetrf_ followed by sgetri_; throws std::runtime_error when LAPACK reports a nonzero info. */ + int info; + auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)}; + // Compute LU factorization. + sgetrf_( + /* m = */ &N, + /* n = */ &N, + /* a = */ inv.data() + N * N * i, + /* lda = */ &N, + /* ipiv = */ static_cast(ipiv.buffer.raw_ptr()), + /* info = */ &info); + + if (info != 0) { + std::stringstream ss; + ss << "inverse_impl: LU factorization failed with error code " << info; + throw std::runtime_error(ss.str()); + } + + static const int lwork_query = -1; /* lwork = -1 requests a workspace-size query instead of an inversion */ + float workspace_size = 0; + + // Compute workspace size. 
+ sgetri_( + /* m = */ &N, + /* a = */ nullptr, + /* lda = */ &N, + /* ipiv = */ nullptr, + /* work = */ &workspace_size, + /* lwork = */ &lwork_query, + /* info = */ &info); + + if (info != 0) { + std::stringstream ss; + ss << "inverse_impl: LU workspace calculation failed with error code " + << info; + throw std::runtime_error(ss.str()); + } + + const int lwork = workspace_size; + auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)}; + + // Compute inverse. + sgetri_( + /* m = */ &N, + /* a = */ inv.data() + N * N * i, + /* lda = */ &N, + /* ipiv = */ static_cast(ipiv.buffer.raw_ptr()), + /* work = */ static_cast(scratch.buffer.raw_ptr()), + /* lwork = */ &lwork, + /* info = */ &info); + + if (info != 0) { + std::stringstream ss; + ss << "inverse_impl: inversion failed with error code " << info; + throw std::runtime_error(ss.str()); + } +} + +void tri_inv(array& inv, int N, int i, bool upper) { /* Inverts the i-th N x N triangular matrix stored in inv via strtri_; throws std::runtime_error on a nonzero info. */ + const char uplo = upper ? 'L' : 'U'; /* Not a typo: LAPACK reads the row-major buffer as its column-major transpose, so an upper-triangular input is lower-triangular ('L') from strtri_'s point of view, and vice versa. */ + const char diag = 'N'; /* 'N' = non-unit diagonal: the diagonal entries are read from the matrix */ + int info = strtri_wrapper(uplo, diag, inv.data() + N * N * i, N); + if (info != 0) { + std::stringstream ss; + ss << "inverse_impl: triangular inversion failed with error code " << info; + throw std::runtime_error(ss.str()); + } +} + +void inverse_impl(const array& a, array& inv, bool tri, bool upper) { // Lapack uses the column-major convention. We take advantage of the following // identity to avoid transposing (see // https://math.stackexchange.com/a/340234): @@ -24,63 +121,11 @@ void inverse_impl(const array& a, array& inv) { const int N = a.shape(-1); const size_t num_matrices = a.size() / (N * N); - int info; - auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)}; - for (int i = 0; i < num_matrices; i++) { - // Compute LU factorization. 
- sgetrf_( - /* m = */ &N, - /* n = */ &N, - /* a = */ inv.data() + N * N * i, - /* lda = */ &N, - /* ipiv = */ static_cast(ipiv.buffer.raw_ptr()), - /* info = */ &info); - - if (info != 0) { - std::stringstream ss; - ss << "inverse_impl: LU factorization failed with error code " << info; - throw std::runtime_error(ss.str()); - } - - static const int lwork_query = -1; - float workspace_size = 0; - - // Compute workspace size. - sgetri_( - /* m = */ &N, - /* a = */ nullptr, - /* lda = */ &N, - /* ipiv = */ nullptr, - /* work = */ &workspace_size, - /* lwork = */ &lwork_query, - /* info = */ &info); - - if (info != 0) { - std::stringstream ss; - ss << "inverse_impl: LU workspace calculation failed with error code " - << info; - throw std::runtime_error(ss.str()); - } - - const int lwork = workspace_size; - auto scratch = - array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)}; - - // Compute inverse. - sgetri_( - /* m = */ &N, - /* a = */ inv.data() + N * N * i, - /* lda = */ &N, - /* ipiv = */ static_cast(ipiv.buffer.raw_ptr()), - /* work = */ static_cast(scratch.buffer.raw_ptr()), - /* lwork = */ &lwork, - /* info = */ &info); - - if (info != 0) { - std::stringstream ss; - ss << "inverse_impl: inversion failed with error code " << info; - throw std::runtime_error(ss.str()); + if (tri) { /* Triangular inputs go through strtri_; everything else through LU (sgetrf_ + sgetri_). */ + tri_inv(inv, N, i, upper); + } else { + general_inv(inv, N, i); } } } @@ -89,7 +134,7 @@ void Inverse::eval(const std::vector& inputs, array& output) { if (inputs[0].dtype() != float32) { throw std::runtime_error("[Inverse::eval] only supports float32."); } - inverse_impl(inputs[0], output); + inverse_impl(inputs[0], output, tri_, upper_); /* forward the flags stored on the primitive */ } } // namespace mlx::core diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 845d1981f..f59efd5ad 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -238,7 +238,7 @@ std::vector svd(const array& a, StreamOrDevice s /* = {} */) { {a}); } -array inv(const array& a, StreamOrDevice s /* = {} */) { +array inv_impl(const array& a, bool tri, 
bool upper, StreamOrDevice s) { if (a.dtype() != float32) { std::ostringstream msg; msg << "[linalg::inv] Arrays must type float32. Received array " @@ -258,7 +258,21 @@ array inv(const array& a, StreamOrDevice s /* = {} */) { } return array( - a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); + a.shape(), + a.dtype(), + std::make_shared(to_stream(s), tri, upper), + {a}); +} + +array inv(const array& a, StreamOrDevice s /* = {} */) { + return inv_impl(a, /*tri=*/false, /*upper=*/true, s); /* upper is ignored when tri is false */ +} + +array tri_inv( + const array& a, + bool upper /* = false */, + StreamOrDevice s /* = {} */) { + return inv_impl(a, /*tri=*/true, upper, s); } array cholesky( @@ -292,4 +306,37 @@ array cholesky( {a}); } +array cholesky_inv( + const array& L, + bool upper /* = false */, + StreamOrDevice s /* = {} */) { + if (L.dtype() != float32) { + std::ostringstream msg; + msg << "[linalg::cholesky_inv] Arrays must type float32. Received array " + << "with type " << L.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + + if (L.ndim() < 2) { + std::ostringstream msg; + msg << "[linalg::cholesky_inv] Arrays must have >= 2 dimensions. 
Received array " + "with " + << L.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (L.shape(-1) != L.shape(-2)) { + throw std::invalid_argument( + "[linalg::cholesky_inv] Cholesky inverse is only defined for square " + "matrices."); + } + + array L_inv = tri_inv(L, upper, s); + if (upper) { /* A = U^T U, so A^-1 = U^-1 U^-T */ + return matmul(L_inv, swapaxes(L_inv, -1, -2, s), s); + } else { /* A = L L^T, so A^-1 = L^-T L^-1 */ + return matmul(swapaxes(L_inv, -1, -2, s), L_inv, s); + } +} + } // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h index 16a2bf25b..6f3eb33fc 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -66,6 +66,10 @@ std::vector svd(const array& a, StreamOrDevice s = {}); array inv(const array& a, StreamOrDevice s = {}); +array tri_inv(const array& a, bool upper = false, StreamOrDevice s = {}); + array cholesky(const array& a, bool upper = false, StreamOrDevice s = {}); +array cholesky_inv(const array& L, bool upper = false, StreamOrDevice s = {}); + } // namespace mlx::core::linalg diff --git a/mlx/primitives.h b/mlx/primitives.h index 342afdc7b..065666a34 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2127,7 +2127,8 @@ class SVD : public Primitive { /* Matrix inversion primitive. 
*/ class Inverse : public UnaryPrimitive { public: - explicit Inverse(Stream stream) : UnaryPrimitive(stream) {} + explicit Inverse(Stream stream, bool tri, bool upper) + : UnaryPrimitive(stream), tri_(tri), upper_(upper) {} void eval_cpu(const std::vector& inputs, array& output) override; void eval_gpu(const std::vector& inputs, array& output) override; @@ -2137,6 +2138,8 @@ class Inverse : public UnaryPrimitive { private: void eval(const std::vector& inputs, array& output); + bool tri_; /* true: the input is triangular, use the triangular inversion path */ + bool upper_; /* with tri_: true for upper triangular input, false for lower */ }; class Cholesky : public UnaryPrimitive { diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 3ba94e9f8..dd79e44d6 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -261,6 +261,31 @@ void init_linalg(nb::module_& parent_module) { array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])`` )pbdoc"); m.def( + "tri_inv", + &tri_inv, + "a"_a, + "upper"_a = false, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def tri_inv(a: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute the inverse of a triangular square matrix. + + This function supports arrays with at least 2 dimensions. When the input + has more than two dimensions, the inverse is computed for each matrix + in the last two dimensions of ``a``. + + Args: + a (array): Input array. + upper (bool, optional): ``True`` if ``a`` is upper triangular, ``False`` if it is lower triangular. Defaults to ``False``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])`` + )pbdoc"); + m.def( "cholesky", &cholesky, "a"_a, @@ -286,8 +311,46 @@ void init_linalg(nb::module_& parent_module) { in which case the default stream of the default device is used. 
Returns: - array: If ``upper = False``, it returns a lower trinagular ``L`` matrix such + array: If ``upper = False``, it returns a lower triangular ``L`` matrix such that ``dot(L, L.T) = a``. If ``upper = True``, it returns an upper triangular ``U`` matrix such that ``dot(U.T, U) = a``. )pbdoc"); + m.def( + "cholesky_inv", + &cholesky_inv, + "L"_a, + "upper"_a = false, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def cholesky_inv(L: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute the inverse of a real symmetric positive definite matrix using its Cholesky decomposition L. + + Let A be a real symmetric positive definite matrix and L its Cholesky decomposition such that: + + .. math:: + + \begin{aligned} + \mathbf{A} = \mathbf{L}\mathbf{L}^T + \end{aligned} + + This function computes :math:`\mathbf{A}^{-1}`. + + This function supports arrays with at least 2 dimensions. When the input + has more than two dimensions, the Cholesky inverse is computed for each matrix + in the last two dimensions of ``L``. + + If the input matrix is not a triangular matrix behaviour is undefined. + + Args: + L (array): Input array. + upper (bool, optional): If ``True``, ``L`` is the upper triangular Cholesky factor, i.e. :math:`\mathbf{A} = \mathbf{L}^T\mathbf{L}`. + If ``False``, ``L`` is the lower triangular Cholesky factor, i.e. :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`. Default: ``False``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: :math:`A^{-1}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`. 
+ )pbdoc"); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 944df89b8..0a6fe9a53 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -150,6 +150,20 @@ class TestLinalg(mlx_tests.MLXTestCase): mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5) ) + def test_tri_inverse(self): + for upper in (False, True): + A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float32) + B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float32) + if upper: + A = A.T + B = B.T + AB = mx.stack([A, B]) + invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu) + for M, M_inv in zip(AB, invs): + self.assertTrue( + mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5) + ) + def test_cholesky(self): sqrtA = mx.array( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float32 @@ -167,6 +181,28 @@ class TestLinalg(mlx_tests.MLXTestCase): for M, L in zip(AB, Ls): self.assertTrue(mx.allclose(L @ L.T, M, rtol=1e-5, atol=1e-7)) + def test_cholesky_inv(self): + mx.random.seed(7) + + N = 3 + # A @ A.T is symmetric by construction, as cholesky / cholesky_inv require. + A = mx.random.uniform(shape=(N, N)) + A = A @ A.T + + for upper in (False, True): + L = mx.linalg.cholesky(A, upper=upper, stream=mx.cpu) + A_inv = mx.linalg.cholesky_inv(L, upper=upper, stream=mx.cpu) + self.assertTrue(mx.allclose(A @ A_inv, mx.eye(N), atol=1e-4)) + + # Multiple matrices + B = A + 1 / 9 + AB = mx.stack([A, B]) + for upper in (False, True): + Ls = mx.linalg.cholesky(AB, upper=upper, stream=mx.cpu) + AB_inv = mx.linalg.cholesky_inv(Ls, upper=upper, stream=mx.cpu) + for M, M_inv in zip(AB, AB_inv): + self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4)) + if __name__ == "__main__": unittest.main()