From 99abb9eff4779700741c3faa92d7fdcb259e2022 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 9 Apr 2024 18:34:00 -0700
Subject: [PATCH] Async eval (#972)

---
 mlx/array.cpp             |  4 +++-
 mlx/transforms.cpp        | 14 ++++++++++++--
 mlx/transforms.h          |  3 +++
 python/src/transforms.cpp | 40 +++++++++++++++++++++++++++++++++++++++
 python/tests/test_eval.py | 12 ++++++++++++
 5 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/mlx/array.cpp b/mlx/array.cpp
index 771e72676..edbaad395 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -93,7 +93,9 @@ void array::detach() {
 }
 
 void array::eval() {
-  mlx::core::eval({*this});
+  if (!is_evaled()) {
+    mlx::core::eval({*this});
+  }
 }
 
 bool array::is_tracer() const {
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 8a07bfd6a..bb39492a8 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -38,7 +38,13 @@ class Synchronizer : public Primitive {
 // are currently under a function transformation.
 int detail::InTracing::tracing_counter{0};
 
-void eval(std::vector<array> outputs) {
+std::shared_future<void> async_eval(std::vector<array> outputs) {
+  static std::shared_future<void> global_synchronizer;
+  // Catch up with previous async eval if needed
+  if (global_synchronizer.valid()) {
+    global_synchronizer.wait();
+  }
+
   std::function<void(const array&)> recurse;
   std::queue<array> tape;
   std::unordered_set<std::uintptr_t> cache;
@@ -152,8 +158,12 @@ void eval(std::vector<array> outputs) {
       scheduler::enqueue(stream, std::move(task));
     }
   }
+  global_synchronizer = std::move(deps[synchronizer.id()]);
+  return global_synchronizer;
+}
 
-  deps[synchronizer.id()].wait();
+void eval(std::vector<array> outputs) {
+  async_eval(std::move(outputs)).wait();
 }
 
 std::pair<std::vector<array>, std::vector<array>> vjp(
diff --git a/mlx/transforms.h b/mlx/transforms.h
index a83765100..eb2c26780 100644
--- a/mlx/transforms.h
+++ b/mlx/transforms.h
@@ -2,10 +2,13 @@
 
 #pragma once
 
+#include <future>
 #include "mlx/array.h"
 
 namespace mlx::core {
 
+std::shared_future<void> async_eval(std::vector<array> outputs);
+
 void eval(std::vector<array> outputs);
 
 template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index 01b67b6a3..32df46e1e 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -595,6 +595,14 @@ class PyCheckpointedFun {
 };
 
 void init_transforms(nb::module_& m) {
+  nb::class_<std::shared_future<void>>(
+      m,
+      "Synchronizer",
+      R"pbdoc(
+      A synchronization object returned by :func:`async_eval`.
+      )pbdoc")
+      .def("wait", [](const std::shared_future<void>& f) { f.wait(); });
+
   m.def(
       "eval",
       [](const nb::args& args) {
@@ -615,6 +623,38 @@ void init_transforms(nb::module_& m) {
               :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
               arrays are ignored.
       )pbdoc");
+  m.def(
+      "async_eval",
+      [](const nb::args& args) {
+        std::vector<array> arrays = tree_flatten(args, false);
+        {
+          nb::gil_scoped_release nogil;
+          return async_eval(arrays);
+        }
+      },
+      nb::arg(),
+      nb::sig("def async_eval(*args) -> Synchronizer"),
+      R"pbdoc(
+        Asynchronously evaluate an :class:`array` or tree of :class:`array`.
+
+        .. warning::
+
+          You must call ``wait`` on the returned synchronization object before
+          using any arrays that are asynchronously evaluated.
+
+        .. note::
+
+          This is an experimental API and may change in future versions.
+
+        Args:
+            *args (arrays or trees of arrays): Each argument can be a single array
+              or a tree of arrays. If a tree is given the nodes can be a Python
+              :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
+              arrays are ignored.
+
+        Returns:
+            Synchronizer: A synchronization object.
+      )pbdoc");
   m.def(
       "jvp",
       [](const nb::callable& fun,
diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py
index dc986a19a..d4930ae81 100644
--- a/python/tests/test_eval.py
+++ b/python/tests/test_eval.py
@@ -32,6 +32,18 @@ class TestEval(mlx_tests.MLXTestCase):
         mx.eval(state)
         self.assertEqual(x.item(), 3)
 
+    def test_async_eval(self):
+        x = mx.array(1) + mx.array(1) + mx.array(1)
+        sync = mx.async_eval(x)
+        sync.wait()
+        self.assertEqual(x.item(), 3)
+
+        # It should be safe to call eval on the array which has been async
+        # eval'ed
+        x = mx.array(1) + mx.array(1) + mx.array(1)
+        sync = mx.async_eval(x)
+        self.assertEqual(x.item(), 3)
+
 
 if __name__ == "__main__":
     unittest.main()