From 077c1ee64a683d6b9289a1f733fe65bb9a8dea28 Mon Sep 17 00:00:00 2001
From: taher <8665427+nullhook@users.noreply.github.com>
Date: Fri, 26 Jan 2024 09:27:31 -0800
Subject: [PATCH] QR factorization (#310)

* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
---
 .circleci/config.yml                      |   2 +-
 CMakeLists.txt                            |  25 +++-
 docs/src/python/linalg.rst                |   1 +
 mlx/CMakeLists.txt                        |   2 +-
 mlx/backend/accelerate/primitives.cpp     |   1 +
 mlx/backend/common/CMakeLists.txt         |   1 +
 mlx/backend/common/default_primitives.cpp |   1 +
 mlx/backend/common/qrf.cpp                | 153 ++++++++++++++++++++++
 mlx/backend/metal/primitives.cpp          |   6 +
 mlx/backend/no_metal/primitives.cpp       |   2 +-
 mlx/linalg.cpp                            |  30 ++++-
 mlx/linalg.h                              |   2 +
 mlx/ops.cpp                               |   4 +-
 mlx/ops.h                                 |   4 +-
 mlx/primitives.h                          |  16 +++
 python/src/linalg.cpp                     |  33 +++++
 python/src/ops.cpp                        |   4 +-
 python/tests/test_linalg.py               |  31 +++++
 tests/CMakeLists.txt                      |   2 +-
 tests/linalg_tests.cpp                    |  21 ++-
 20 files changed, 322 insertions(+), 19 deletions(-)
 create mode 100644 mlx/backend/common/qrf.cpp

diff --git a/.circleci/config.yml b/.circleci/config.yml
index c97f5dd73..2d7c9f771 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -29,7 +29,7 @@ jobs:
             pip install pybind11-stubgen
             pip install numpy
             sudo apt-get update
-            sudo apt-get install libblas-dev
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
       - run:
           name: Install Python package
           command: |
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7223b6594..048ca83fa 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -31,13 +31,13 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
     message(FATAL_ERROR
-      "Building for x86_64 on macOS is not supported." 
+      "Building for x86_64 on macOS is not supported."
       " If you are on an Apple silicon system, check the build"
       " documentation for possible fixes: "
       "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
   elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
-    message(WARNING 
-      "Building for x86_64 on macOS is not supported." 
+    message(WARNING
+      "Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, " " make sure you are building for arm64.") elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64") @@ -75,7 +75,7 @@ elseif (MLX_BUILD_METAL) COMMAND_ERROR_IS_FATAL ANY) message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") - + if (${MACOS_VERSION} GREATER_EQUAL 14.2) set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip) elseif (${MACOS_VERSION} GREATER_EQUAL 14.0) @@ -123,16 +123,27 @@ else() /usr/include /usr/local/include $ENV{BLAS_HOME}/include) - message(STATUS ${BLAS_LIBRARIES}) - message(STATUS ${BLAS_INCLUDE_DIRS}) + message(STATUS "Blas lib" ${BLAS_LIBRARIES}) + message(STATUS "Blas incclude" ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) target_link_libraries(mlx ${BLAS_LIBRARIES}) + find_package(LAPACK REQUIRED) + if (NOT LAPACK_FOUND) + message(FATAL_ERROR "Must have LAPACK installed") + endif() + find_path(LAPACK_INCLUDE_DIRS lapacke.h + /usr/include + /usr/local/include) + message(STATUS "Lapack lib" ${LAPACK_LIBRARIES}) + message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) + target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) + target_link_libraries(mlx ${LAPACK_LIBRARIES}) endif() add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) target_include_directories( - mlx + mlx PUBLIC $ $ diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index 27746441e..0ac559f5e 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -9,3 +9,4 @@ Linear Algebra :toctree: _autosummary norm + qr diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 882bf93e0..bb546616c 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -19,7 +19,7 @@ target_sources( add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) -if (MLX_BUILD_ACCELERATE) +if (MLX_BUILD_ACCELERATE) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) else() target_sources( diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 8f2da02a2..777bb76f4 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -65,6 +65,7 @@ DEFAULT(Sort) DEFAULT(StopGradient) DEFAULT(Transpose) DEFAULT_MULTI(DivMod) +DEFAULT_MULTI(QRF) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 077a0353e..25563f66b 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -16,4 +16,5 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp ) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 1225fd1af..945451a2a 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -97,6 +97,7 @@ DEFAULT(Tan) DEFAULT(Tanh) DEFAULT(Transpose) DEFAULT_MULTI(DivMod) +DEFAULT_MULTI(QRF) namespace { diff --git a/mlx/backend/common/qrf.cpp b/mlx/backend/common/qrf.cpp new file mode 100644 index 000000000..41561cbad --- /dev/null +++ b/mlx/backend/common/qrf.cpp @@ -0,0 +1,153 @@ +// Copyright © 2023-2024 Apple Inc. 
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/primitives.h"
+
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <lapack.h>
+#endif
+
+namespace mlx::core {
+
+template <typename T>
+struct lpack;
+
+template <>
+struct lpack<float> {
+  static void xgeqrf(
+      const int* m,
+      const int* n,
+      float* a,
+      const int* lda,
+      float* tau,
+      float* work,
+      const int* lwork,
+      int* info) {
+    sgeqrf_(m, n, a, lda, tau, work, lwork, info);
+  }
+  static void xorgqr(
+      const int* m,
+      const int* n,
+      const int* k,
+      float* a,
+      const int* lda,
+      const float* tau,
+      float* work,
+      const int* lwork,
+      int* info) {
+    sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
+  }
+};
+
+template <typename T>
+void qrf_impl(const array& a, array& q, array& r) {
+  const int M = a.shape(-2);
+  const int N = a.shape(-1);
+  const int lda = std::max(M, N);
+  size_t num_matrices = a.size() / (M * N);
+  int num_reflectors = std::min(M, N);
+  auto tau =
+      allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
+
+  // Copy A to inplace input and make it col-contiguous
+  array in(a.shape(), float32, nullptr, {});
+  auto flags = in.flags();
+
+  // Copy the input to be column contiguous
+  flags.col_contiguous = num_matrices == 1;
+  flags.row_contiguous = false;
+  std::vector<size_t> strides = in.strides();
+  strides[in.ndim() - 2] = 1;
+  strides[in.ndim() - 1] = M;
+  in.set_data(
+      allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
+  copy_inplace(a, in, CopyType::GeneralGeneral);
+
+  T optimal_work;
+  int lwork = -1;
+  int info;
+
+  // Compute workspace size
+  lpack<T>::xgeqrf(
+      &M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
+
+  // Update workspace size
+  lwork = optimal_work;
+  auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
+
+  // Loop over matrices
+  for (int i = 0; i < num_matrices; ++i) {
+    // Solve
+    lpack<T>::xgeqrf(
+        &M,
+        &N,
+        in.data<float>() + M * N * i,
+        &lda,
+        static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
+        static_cast<T*>(work.raw_ptr()),
+        &lwork,
+        &info);
+  }
+  allocator::free(work);
+
+  r.set_data(allocator::malloc_or_wait(r.nbytes()));
+  copy_inplace(in, r, CopyType::General);
+
+  for (int i = 0; i < num_matrices; ++i) {
+    // Zero lower triangle
+    for (int j = 0; j < r.shape(-2); ++j) {
+      for (int k = 0; k < j; ++k) {
+        r.data<T>()[i * N * M + j * N + k] = 0;
+      }
+    }
+  }
+
+  // Get work size
+  lwork = -1;
+  lpack<T>::xorgqr(
+      &M,
+      &N,
+      &num_reflectors,
+      nullptr,
+      &lda,
+      nullptr,
+      &optimal_work,
+      &lwork,
+      &info);
+  lwork = optimal_work;
+  work = allocator::malloc_or_wait(sizeof(T) * lwork);
+
+  // Loop over matrices
+  for (int i = 0; i < num_matrices; ++i) {
+    // Compute Q
+    lpack<T>::xorgqr(
+        &M,
+        &N,
+        &num_reflectors,
+        in.data<float>() + M * N * i,
+        &lda,
+        static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
+        static_cast<T*>(work.raw_ptr()),
+        &lwork,
+        &info);
+  }
+
+  q.set_data(allocator::malloc_or_wait(q.nbytes()));
+  copy_inplace(in, q, CopyType::General);
+
+  // Cleanup
+  allocator::free(work);
+  allocator::free(tau);
+}
+
+void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
+  if (!(inputs[0].dtype() == float32)) {
+    throw std::runtime_error("[QRF::eval] only supports float32.");
+  }
+  qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index d9e0619cd..e41368f6e 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -769,4 +769,10 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
 
+void QRF::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 899f7caff..034fba760 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -90,5 +90,5 @@ NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)
 NO_GPU_MULTI(DivMod)
-
+NO_GPU_MULTI(QRF)
 } // namespace mlx::core
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 7e7264e3f..90304d96e 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -4,8 +4,9 @@
 #include <numeric>
 #include <sstream>
 
-#include "mlx/dtype.h"
 #include "mlx/linalg.h"
+#include "mlx/primitives.h"
+#include "mlx/utils.h"
 
 namespace mlx::core::linalg {
 
@@ -172,4 +173,31 @@
   return matrix_norm(a, ord, ax, keepdims, s);
 }
 
+std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
+  if (a.dtype() != float32) {
+    std::ostringstream msg;
+    msg << "[linalg::qr] Arrays must have type float32. Received array "
+        << "with type " << a.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
+           "with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  if (a.shape(-1) != a.shape(-2)) {
+    throw std::invalid_argument(
+        "[linalg::qr] Support for non-square matrices NYI.");
+  }
+
+  auto out = array::make_arrays(
+      {a.shape(), a.shape()},
+      {a.dtype(), a.dtype()},
+      std::make_unique<QRF>(to_stream(s)),
+      {astype(a, a.dtype(), s)});
+  return std::make_pair(out[0], out[1]);
+}
+
 } // namespace mlx::core::linalg
diff --git a/mlx/linalg.h b/mlx/linalg.h
index 80e484eb5..c78d99476 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -60,4 +60,6 @@ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
   return norm(a, std::vector<int>{axis}, keepdims, s);
 }
 
+std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
+
 } // namespace mlx::core::linalg
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 65d8d39a2..387f8e814 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -252,7 +252,7 @@ array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
   return astype(greater_equal(l, r, s), type, s);
 }
 
-array tril(array x, int k, StreamOrDevice s /* = {} */) {
+array tril(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
   if (x.ndim() < 2) {
     throw std::invalid_argument("[tril] array must be at least 2-D");
   }
@@ -260,7 +260,7 @@ array tril(array x, int k, StreamOrDevice s /* = {} */) {
   return where(mask, x, zeros_like(x, s), s);
 }
 
-array triu(array x, int k, StreamOrDevice s /* = {} */) {
+array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
   if (x.ndim() < 2) {
     throw std::invalid_argument("[triu] array must be at least 2-D");
   }
diff --git a/mlx/ops.h b/mlx/ops.h
index 19889165c..0d17f2d2c 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -123,8 +123,8 @@ inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
   return tri(n, n, 0, type, s);
 }
 
-array tril(array x, int k, StreamOrDevice s = {});
-array triu(array x, int k, StreamOrDevice s = {});
+array tril(array x, int k = 0, StreamOrDevice s = {});
+array triu(array x, int k = 0, StreamOrDevice s = {});
 
 /** array manipulation */
diff --git a/mlx/primitives.h b/mlx/primitives.h
index c0a176417..2f9f6d6b3 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1602,4 +1602,20 @@ class Transpose : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+/* QR Factorization primitive. */
+class QRF : public Primitive {
+ public:
+  explicit QRF(Stream stream) : Primitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(QRF)
+
+ private:
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
+};
+
 } // namespace mlx::core
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index ea5474a70..1345f4a32 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -177,4 +177,37 @@ void init_linalg(py::module_& parent_module) {
         >>> la.norm(m[0, :, :]), LA.norm(m[1, :, :])
         (array(3.74166, dtype=float32), array(11.225, dtype=float32))
       )pbdoc");
+  m.def(
+      "qr",
+      &qr,
+      "a"_a,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)
+
+        The QR factorization of the input matrix.
+
+        This function supports arrays with at least 2 dimensions. The matrices
+        which are factorized are assumed to be in the last two dimensions of
+        the input.
+
+        Args:
+            a (array): Input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            tuple(array, array): The ``Q`` and ``R`` matrices.
+
+        Example:
+            >>> A = mx.array([[2., 3.], [1., 2.]])
+            >>> Q, R = mx.linalg.qr(A, stream=mx.cpu)
+            >>> Q
+            array([[-0.894427, -0.447214],
+                   [-0.447214, 0.894427]], dtype=float32)
+            >>> R
+            array([[-2.23607, -3.57771],
+                   [0, 0.447214]], dtype=float32)
+      )pbdoc");
 }
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 1156391f7..432b98d76 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -55,7 +55,7 @@ void init_ops(py::module_& m) {
         Args:
             a (array): Input array.
             shape (tuple(int)): New shape.
-            stream (Stream, optional): Stream or device. Defaults to ```None```
+            stream (Stream, optional): Stream or device. Defaults to ``None``
               in which case the default stream of the default device is used.
 
         Returns:
@@ -112,7 +112,7 @@ void init_ops(py::module_& m) {
         Args:
             a (array): Input array.
             axis (int or tuple(int), optional): Axes to remove. Defaults
-              to ```None``` in which case all size one axes are removed.
+              to ``None`` in which case all size one axes are removed.
 
         Returns:
             array: The output array with size one axes removed.
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index ac86c1e11..dffa97148 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -89,6 +89,37 @@ class TestLinalg(mlx_tests.MLXTestCase):
             out_mx = mx.linalg.norm(x_mx, ord="fro")
             self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
 
+    def test_qr_factorization(self):
+        with self.assertRaises(ValueError):
+            mx.linalg.qr(mx.array(0.0))
+
+        with self.assertRaises(ValueError):
+            mx.linalg.qr(mx.array([0.0, 1.0]))
+
+        with self.assertRaises(ValueError):
+            mx.linalg.qr(mx.array([[0, 1], [1, 0]]))
+
+        A = mx.array([[2.0, 3.0], [1.0, 2.0]])
+        Q, R = mx.linalg.qr(A, stream=mx.cpu)
+        out = Q @ R
+        self.assertTrue(mx.allclose(out, A))
+        out = Q @ Q
+        self.assertTrue(mx.allclose(out, mx.eye(2), rtol=1e-5, atol=1e-7))
+        self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))
+        self.assertEqual(Q.dtype, mx.float32)
+        self.assertEqual(R.dtype, mx.float32)
+
+        # Multiple matrices
+        B = mx.array([[-1.0, 2.0], [-4.0, 1.0]])
+        A = mx.stack([A, B])
+        Q, R = mx.linalg.qr(A, stream=mx.cpu)
+        for a, q, r in zip(A, Q, R):
+            out = q @ r
+            self.assertTrue(mx.allclose(out, a))
+            out = q @ q
+            self.assertTrue(mx.allclose(out, mx.eye(2), rtol=1e-5, atol=1e-7))
+            self.assertTrue(mx.allclose(mx.tril(r, -1), mx.zeros_like(r)))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index dbc499205..838650c2a 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -14,7 +14,7 @@ if (MLX_BUILD_METAL)
   )
 endif()
 
-target_sources(tests PRIVATE 
+target_sources(tests PRIVATE
   allocator_tests.cpp
   array_tests.cpp
   arg_reduce_tests.cpp
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index 1bf02c243..fd3a25a35 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include "doctest/doctest.h"
 
@@ -248,3 +248,22 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
               array({14.28285686, 39.7617907}))
               .item<bool>());
 }
+
+TEST_CASE("test QR factorization") {
+  // 0D and 1D throw
+  CHECK_THROWS(linalg::qr(array(0.0)));
+  CHECK_THROWS(linalg::qr(array({0.0, 1.0})));
+
+  // Unsupported types throw
+  CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2})));
+
+  array A = array({2., 3., 1., 2.}, {2, 2});
+  auto [Q, R] = linalg::qr(A, Device::cpu);
+  auto out = matmul(Q, R);
+  CHECK(allclose(out, A).item<bool>());
+  out = matmul(Q, Q);
+  CHECK(allclose(out, eye(2), 1e-5, 1e-7).item<bool>());
+  CHECK(allclose(tril(R, -1), zeros_like(R)).item<bool>());
+  CHECK_EQ(Q.dtype(), float32);
+  CHECK_EQ(R.dtype(), float32);
+}
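
A note on the LAPACK calling convention used in qrf.cpp above: sgeqrf_ and
sorgqr_ both follow the standard two-phase pattern, where a first call with
lwork = -1 only writes the optimal workspace size into the work output, and a
second call does the actual computation. The standalone C++ sketch below
isolates that pattern outside of MLX. It is illustrative only; it assumes a
system lapack.h that declares the classic Fortran symbols with the signatures
used in the patch (link with -llapack).

    // Two-phase LAPACK QR of one 2x2 matrix, mirroring qrf.cpp.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    #include <lapack.h>

    int main() {
      int m = 2, n = 2, lda = 2, info = 0;
      // Column-major A = [[2, 3], [1, 2]], the matrix used in the tests.
      std::vector<float> a = {2.0f, 1.0f, 3.0f, 2.0f};
      std::vector<float> tau(std::min(m, n));

      // Phase 1: workspace query. With lwork = -1, geqrf only writes the
      // optimal workspace size into the first element of work.
      float work_query = 0.0f;
      int lwork = -1;
      sgeqrf_(&m, &n, a.data(), &lda, tau.data(), &work_query, &lwork, &info);

      // Phase 2: factorize. Afterwards, a holds R in its upper triangle and
      // the Householder reflectors below the diagonal.
      lwork = static_cast<int>(work_query);
      std::vector<float> work(lwork);
      sgeqrf_(&m, &n, a.data(), &lda, tau.data(), work.data(), &lwork, &info);

      // orgqr has its own optimal lwork, so query again before expanding the
      // reflectors into an explicit Q (qrf.cpp does the same).
      int k = std::min(m, n);
      lwork = -1;
      sorgqr_(&m, &n, &k, a.data(), &lda, tau.data(), &work_query, &lwork, &info);
      lwork = static_cast<int>(work_query);
      work.resize(lwork);
      sorgqr_(&m, &n, &k, a.data(), &lda, tau.data(), work.data(), &lwork, &info);

      std::printf("info = %d, Q (column-major): %f %f %f %f\n",
                  info, a[0], a[1], a[2], a[3]);
      return 0;
    }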
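
For completeness, here is a minimal end-to-end check of the new op through the
public C++ API, equivalent to what tests/linalg_tests.cpp verifies above. This
is a sketch that assumes an MLX build with the CPU backend available (the
Metal path deliberately throws NYI for now).

    // Factorize a 2x2 matrix and check the defining properties of Q and R.
    #include <cstdio>

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      array A = array({2.0f, 3.0f, 1.0f, 2.0f}, {2, 2});

      // QR currently runs on the CPU stream only.
      auto [Q, R] = linalg::qr(A, Device::cpu);

      // Q @ R reconstructs A, Q^T Q is the identity, and R is upper
      // triangular with zeros below the diagonal.
      std::printf("Q R == A: %d\n", allclose(matmul(Q, R), A).item<bool>());
      std::printf("Q^T Q == I: %d\n",
                  allclose(matmul(transpose(Q), Q), eye(2), 1e-5, 1e-7)
                      .item<bool>());
      std::printf("tril(R, -1) == 0: %d\n",
                  allclose(tril(R, -1), zeros_like(R)).item<bool>());
      return 0;
    }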