Async eval (#972)

This commit is contained in:
Awni Hannun 2024-04-09 18:34:00 -07:00 committed by GitHub
parent fffe072028
commit 99abb9eff4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 70 additions and 3 deletions

View File

@ -93,7 +93,9 @@ void array::detach() {
}
// Evaluate this array. Nothing is scheduled when is_evaled() already reports
// true; otherwise the array is handed to mlx::core::eval.
void array::eval() {
  if (is_evaled()) {
    return;
  }
  mlx::core::eval({*this});
}
bool array::is_tracer() const {

View File

@ -38,7 +38,13 @@ class Synchronizer : public Primitive {
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
void eval(std::vector<array> outputs) {
std::shared_future<void> async_eval(std::vector<array> outputs) {
static std::shared_future<void> global_synchronizer;
// Catch up with previous async eval if needed
if (global_synchronizer.valid()) {
global_synchronizer.wait();
}
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
@ -152,8 +158,12 @@ void eval(std::vector<array> outputs) {
scheduler::enqueue(stream, std::move(task));
}
}
global_synchronizer = std::move(deps[synchronizer.id()]);
return global_synchronizer;
}
deps[synchronizer.id()].wait();
// Blocking evaluation: forwards the outputs to async_eval and waits on the
// future it returns before handing control back to the caller.
void eval(std::vector<array> outputs) {
  auto pending = async_eval(std::move(outputs));
  pending.wait();
}
std::pair<std::vector<array>, std::vector<array>> vjp(

View File

@ -2,10 +2,13 @@
#pragma once
#include <future>
#include "mlx/array.h"
namespace mlx::core {
// Schedule evaluation of `outputs` and return a shared future for that work.
// Callers must wait() on the returned future before using the arrays.
std::shared_future<void> async_eval(std::vector<array> outputs);
// Blocking counterpart of async_eval: waits on the returned future before
// returning.
void eval(std::vector<array> outputs);
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>

View File

@ -595,6 +595,14 @@ class PyCheckpointedFun {
};
void init_transforms(nb::module_& m) {
// Expose std::shared_future<void> to Python under the name "Synchronizer" and
// bind a wait() method that forwards to shared_future::wait().
// NOTE(review): the wait lambda does not release the GIL (no
// nb::gil_scoped_release), unlike the async_eval binding -- confirm that
// blocking with the GIL held is intended.
nb::class_<std::shared_future<void>>(
m,
"Synchronizer",
R"pbdoc(
A synchronization object returned by :func:`async_eval`.
)pbdoc")
.def("wait", [](const std::shared_future<void>& f) { f.wait(); });
m.def(
"eval",
[](const nb::args& args) {
@ -615,6 +623,38 @@ void init_transforms(nb::module_& m) {
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
arrays are ignored.
)pbdoc");
// Python binding for async_eval: takes arbitrary positional arguments and
// returns the std::shared_future<void> produced by the C++ async_eval.
m.def(
"async_eval",
[](const nb::args& args) {
// Collect the arrays from the Python arguments while the GIL is still held.
// NOTE(review): the `false` argument presumably selects non-strict
// flattening so that non-array leaves are skipped (as the docstring states)
// -- confirm against tree_flatten.
std::vector<array> arrays = tree_flatten(args, false);
{
// The GIL is released only for the scope of the async_eval call.
nb::gil_scoped_release nogil;
return async_eval(arrays);
}
},
nb::arg(),
nb::sig("def async_eval(*args) -> Synchronizer"),
R"pbdoc(
Asynchronously evaluate an :class:`array` or tree of :class:`array`.
.. warning::
You must call ``wait`` on the returned synchronization object before
using any arrays that are asynchronously evaluated.
.. note::
This is an experimental API and may change in future versions.
Args:
*args (arrays or trees of arrays): Each argument can be a single array
or a tree of arrays. If a tree is given the nodes can be a Python
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
arrays are ignored.
Returns:
Synchronizer: A synchronization object.
)pbdoc");
m.def(
"jvp",
[](const nb::callable& fun,

View File

@ -32,6 +32,18 @@ class TestEval(mlx_tests.MLXTestCase):
mx.eval(state)
self.assertEqual(x.item(), 3)
def test_async_eval(self):
    # Case 1: wait explicitly on the synchronizer, then read the value.
    total = mx.array(1) + mx.array(1) + mx.array(1)
    handle = mx.async_eval(total)
    handle.wait()
    self.assertEqual(total.item(), 3)

    # Case 2: read the value without calling wait() first; this is expected
    # to be safe for an array that has been async eval'ed.
    total = mx.array(1) + mx.array(1) + mx.array(1)
    handle = mx.async_eval(total)
    self.assertEqual(total.item(), 3)
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()