diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 1d4258f62..9697942bf 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -71,6 +71,7 @@ DEFAULT(Slice)
 DEFAULT_MULTI(Split)
 DEFAULT(Sort)
 DEFAULT(StopGradient)
+DEFAULT_MULTI(SVD)
 DEFAULT(Transpose)
 
 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 8e8b10016..90fc25d84 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -53,6 +53,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
 )
 
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 53b7a65f7..e3eb7e0dc 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -100,6 +100,7 @@ DEFAULT(Square)
 DEFAULT(Sqrt)
 DEFAULT(StopGradient)
 DEFAULT(Subtract)
+DEFAULT_MULTI(SVD)
 DEFAULT(Tan)
 DEFAULT(Tanh)
 DEFAULT(Transpose)
diff --git a/mlx/backend/common/lapack_helper.h b/mlx/backend/common/lapack_helper.h
new file mode 100644
index 000000000..bf0f76437
--- /dev/null
+++ b/mlx/backend/common/lapack_helper.h
@@ -0,0 +1,23 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <lapack.h>
+#endif
+
+#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
+
+// This is to work around a change in the function signatures of lapack >= 3.9.1
+// where functions taking char* also include a strlen argument, see a similar
+// change in OpenCV:
+// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
+#define MLX_LAPACK_FUNC(f) LAPACK_##f
+
+#else
+
+#define MLX_LAPACK_FUNC(f) f##_
+
+#endif
diff --git a/mlx/backend/common/svd.cpp b/mlx/backend/common/svd.cpp
new file mode 100644
index 000000000..38fc67a8c
--- /dev/null
+++ b/mlx/backend/common/svd.cpp
@@ -0,0 +1,148 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack_helper.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+void svd_impl(const array& a, array& u, array& s, array& vt) {
+  // Lapack uses the column-major convention. To avoid having to transpose
+  // the input and then transpose the outputs, we swap the indices/sizes of the
+  // matrices and take advantage of the following identity (see
+  // https://math.stackexchange.com/a/30077)
+  //   A = UΣVᵀ
+  //   Aᵀ = VΣUᵀ
+  // As a result some of the indices/sizes are swapped as noted below.
+
+  // Rows and cols of the original matrix in row-major order.
+  const int M = a.shape(-2);
+  const int N = a.shape(-1);
+  const int K = std::min(M, N);
+
+  // A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
+  const int lda = N;
+  // U of shape M x M. (N x N in lapack).
+  const int ldu = N;
+  // Vᵀ of shape N x N. (M x M in lapack).
+  const int ldvt = M;
+
+  size_t num_matrices = a.size() / (M * N);
+
+  // lapack clobbers the input, so we have to make a copy.
+  array in(a.shape(), float32, nullptr, {});
+  copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+
+  // Allocate outputs.
+  u.set_data(allocator::malloc_or_wait(u.nbytes()));
+  s.set_data(allocator::malloc_or_wait(s.nbytes()));
+  vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
+
+  static constexpr auto job_u = "V";
+  static constexpr auto job_vt = "V";
+  static constexpr auto range = "A";
+
+  // Will contain the number of singular values after the call has returned.
+  int ns = 0;
+  float workspace_dimension = 0;
+
+  // Will contain the indices of eigenvectors that failed to converge (not used
+  // here but required by lapack).
+  std::vector<int> iwork;
+  iwork.resize(12 * K);
+
+  static const int lwork_query = -1;
+
+  static const int ignored_int = 0;
+  static const float ignored_float = 0;
+
+  int info;
+
+  // Compute workspace size.
+  MLX_LAPACK_FUNC(sgesvdx)
+  (
+      /* jobu = */ job_u,
+      /* jobvt = */ job_vt,
+      /* range = */ range,
+      // M and N are swapped since lapack expects column-major.
+      /* m = */ &N,
+      /* n = */ &M,
+      /* a = */ nullptr,
+      /* lda = */ &lda,
+      /* vl = */ &ignored_float,
+      /* vu = */ &ignored_float,
+      /* il = */ &ignored_int,
+      /* iu = */ &ignored_int,
+      /* ns = */ &ns,
+      /* s = */ nullptr,
+      /* u = */ nullptr,
+      /* ldu = */ &ldu,
+      /* vt = */ nullptr,
+      /* ldvt = */ &ldvt,
+      /* work = */ &workspace_dimension,
+      /* lwork = */ &lwork_query,
+      /* iwork = */ iwork.data(),
+      /* info = */ &info);
+
+  if (info != 0) {
+    std::stringstream ss;
+    ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
+    throw std::runtime_error(ss.str());
+  }
+
+  const int lwork = workspace_dimension;
+  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+
+  // Loop over matrices.
+  for (int i = 0; i < num_matrices; i++) {
+    MLX_LAPACK_FUNC(sgesvdx)
+    (
+        /* jobu = */ job_u,
+        /* jobvt = */ job_vt,
+        /* range = */ range,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ in.data<float>() + M * N * i,
+        /* lda = */ &lda,
+        /* vl = */ &ignored_float,
+        /* vu = */ &ignored_float,
+        /* il = */ &ignored_int,
+        /* iu = */ &ignored_int,
+        /* ns = */ &ns,
+        /* s = */ s.data<float>() + K * i,
+        // According to the identity above, lapack will write Vᵀᵀ as U.
+        /* u = */ vt.data<float>() + N * N * i,
+        /* ldu = */ &ldu,
+        // According to the identity above, lapack will write Uᵀ as Vᵀ.
+        /* vt = */ u.data<float>() + M * M * i,
+        /* ldvt = */ &ldvt,
+        /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
+        /* lwork = */ &lwork,
+        /* iwork = */ iwork.data(),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "svd_impl: sgesvdx_ failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    if (ns != K) {
+      std::stringstream ss;
+      ss << "svd_impl: expected " << K << " singular values, but " << ns
+         << " were computed.";
+      throw std::runtime_error(ss.str());
+    }
+  }
+}
+
+void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
+  if (!(inputs[0].dtype() == float32)) {
+    throw std::runtime_error("[SVD::eval] only supports float32.");
+  }
+  svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 0f2716a1b..067011260 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -890,4 +890,10 @@ void QRF::eval_gpu(
   throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
 }
 
+void SVD::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index bfe569041..ec93317f7 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -93,6 +93,7 @@ NO_GPU(Square)
 NO_GPU(Sqrt)
 NO_GPU(StopGradient)
 NO_GPU(Subtract)
+NO_GPU_MULTI(SVD)
 NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 90304d96e..dca186143 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -200,4 +200,42 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
   return std::make_pair(out[0], out[1]);
 }
 
+std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
+  if (a.dtype() != float32) {
+    std::ostringstream msg;
+    msg << "[linalg::svd] Input array must have type float32. Received array "
+        << "with type " << a.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
+           "with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  const auto m = a.shape(-2);
+  const auto n = a.shape(-1);
+  const auto rank = a.ndim();
+
+  std::vector<int> u_shape = a.shape();
+  u_shape[rank - 2] = m;
+  u_shape[rank - 1] = m;
+
+  std::vector<int> s_shape = a.shape();
+  s_shape.pop_back();
+  s_shape[rank - 2] = std::min(m, n);
+
+  std::vector<int> vt_shape = a.shape();
+  vt_shape[rank - 2] = n;
+  vt_shape[rank - 1] = n;
+
+  return array::make_arrays(
+      {u_shape, s_shape, vt_shape},
+      {a.dtype(), a.dtype(), a.dtype()},
+      std::make_unique<SVD>(to_stream(s)),
+      {a});
+}
+
 } // namespace mlx::core::linalg
diff --git a/mlx/linalg.h b/mlx/linalg.h
index c78d99476..521fdf4b5 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -62,4 +62,6 @@ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
 
 std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
 
+std::vector<array> svd(const array& a, StreamOrDevice s = {});
+
 } // namespace mlx::core::linalg
diff --git a/mlx/primitives.h b/mlx/primitives.h
index f0485b0d4..d428fc3ab 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1840,4 +1840,20 @@ class QRF : public Primitive {
   void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
+/* SVD primitive. */
+class SVD : public Primitive {
+ public:
+  explicit SVD(Stream stream) : Primitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(SVD)
+
+ private:
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
+};
+
 } // namespace mlx::core
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 1345f4a32..b0ca1cadf 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -16,6 +16,13 @@ using namespace py::literals;
 using namespace mlx::core;
 using namespace mlx::core::linalg;
 
+namespace {
+py::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
+  const auto result = svd(a, s);
+  return py::make_tuple(result.at(0), result.at(1), result.at(2));
+}
+} // namespace
+
 void init_linalg(py::module_& parent_module) {
   py::options options;
   options.disable_function_signatures();
@@ -186,7 +193,7 @@ void init_linalg(py::module_& parent_module) {
      R"pbdoc(
        qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)
 
-        The QR factorizatoin of the input matrix.
+        The QR factorization of the input matrix.
 
        This function supports arrays with at least 2 dimensions. The matrices
        which are factorized are assumed to be in the last two dimensions of
@@ -210,4 +217,28 @@ void init_linalg(py::module_& parent_module) {
            array([[-2.23607, -3.57771],
                   [0, 0.447214]], dtype=float32)
      )pbdoc");
+  m.def(
+      "svd",
+      &svd_helper,
+      "a"_a,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)
+
+        The Singular Value Decomposition (SVD) of the input matrix.
+
+        This function supports arrays with at least 2 dimensions. When the input
+        has more than two dimensions, the function iterates over all indices of the first
+        a.ndim - 2 dimensions and for each combination SVD is applied to the last two indices.
+
+        Args:
+            a (array): Input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that
+            ``A = U @ diag(S) @ Vt``
+      )pbdoc");
 }
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index dffa97148..c5f31505c 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -120,6 +120,22 @@ class TestLinalg(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(out, mx.eye(2), rtol=1e-5, atol=1e-7))
         self.assertTrue(mx.allclose(mx.tril(r, -1), mx.zeros_like(r)))
 
+    def test_svd_decomposition(self):
+        A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
+        U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
+        self.assertTrue(
+            mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
+        )
+
+        # Multiple matrices
+        B = A + 10.0
+        AB = mx.stack([A, B])
+        Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
+        for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
+            self.assertTrue(
+                mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index fd3a25a35..5ab11ff8a 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -5,6 +5,7 @@
 #include <cmath>
 
 #include "mlx/mlx.h"
+#include "mlx/ops.h"
 
 using namespace mlx::core;
 using namespace mlx::core::linalg;
@@ -267,3 +268,35 @@ TEST_CASE("test QR factorization") {
   CHECK_EQ(Q.dtype(), float32);
   CHECK_EQ(R.dtype(), float32);
 }
+
+TEST_CASE("test SVD factorization") {
+  // 0D and 1D throw
+  CHECK_THROWS(linalg::svd(array(0.0)));
+  CHECK_THROWS(linalg::svd(array({0.0, 1.0})));
+
+  // Unsupported types throw
+  CHECK_THROWS(linalg::svd(array({0, 1}, {1, 2})));
+
+  const auto prng_key = random::key(42);
+  const auto A = mlx::core::random::normal({5, 4}, prng_key);
+  const auto outs = linalg::svd(A, Device::cpu);
+  CHECK_EQ(outs.size(), 3);
+
+  const auto& U = outs[0];
+  const auto& S = outs[1];
+  const auto& Vt = outs[2];
+
+  CHECK_EQ(U.shape(), std::vector<int>{5, 5});
+  CHECK_EQ(S.shape(), std::vector<int>{4});
+  CHECK_EQ(Vt.shape(), std::vector<int>{4, 4});
+
+  const auto U_slice = slice(U, {0, 0}, {U.shape(0), S.shape(0)});
+
+  const auto A_again = matmul(matmul(U_slice, diag(S)), Vt);
+
+  CHECK(
+      allclose(A_again, A, /* rtol = */ 1e-4, /* atol = */ 1e-4).item<bool>());
+  CHECK_EQ(U.dtype(), float32);
+  CHECK_EQ(S.dtype(), float32);
+  CHECK_EQ(Vt.dtype(), float32);
+}
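For reference, a minimal usage sketch of the Python API introduced by this patch (not part of the diff itself); it only uses names exercised above, namely mx.linalg.svd, the CPU stream, and the reconstruction check from test_svd_decomposition:

    import mlx.core as mx

    # The SVD primitive currently only has a CPU implementation, so request the CPU stream.
    A = mx.array([[2.0, 0.0], [1.0, 3.0]], dtype=mx.float32)
    U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)

    # U is m x m, so keep only its first len(S) columns when reconstructing A.
    A_again = U[:, : len(S)] @ mx.diag(S) @ Vt
    print(mx.allclose(A_again, A, rtol=1e-5, atol=1e-7))  # expected: True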