mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Async eval (#972)
This commit is contained in:
parent
fffe072028
commit
99abb9eff4
@ -93,7 +93,9 @@ void array::detach() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void array::eval() {
|
void array::eval() {
|
||||||
mlx::core::eval({*this});
|
if (!is_evaled()) {
|
||||||
|
mlx::core::eval({*this});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool array::is_tracer() const {
|
bool array::is_tracer() const {
|
||||||
|
@ -38,7 +38,13 @@ class Synchronizer : public Primitive {
|
|||||||
// are currently under a function transformation.
|
// are currently under a function transformation.
|
||||||
int detail::InTracing::tracing_counter{0};
|
int detail::InTracing::tracing_counter{0};
|
||||||
|
|
||||||
void eval(std::vector<array> outputs) {
|
std::shared_future<void> async_eval(std::vector<array> outputs) {
|
||||||
|
static std::shared_future<void> global_synchronizer;
|
||||||
|
// Catch up with previous async eval if needed
|
||||||
|
if (global_synchronizer.valid()) {
|
||||||
|
global_synchronizer.wait();
|
||||||
|
}
|
||||||
|
|
||||||
std::function<void(const array&)> recurse;
|
std::function<void(const array&)> recurse;
|
||||||
std::queue<array> tape;
|
std::queue<array> tape;
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
@ -152,8 +158,12 @@ void eval(std::vector<array> outputs) {
|
|||||||
scheduler::enqueue(stream, std::move(task));
|
scheduler::enqueue(stream, std::move(task));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
global_synchronizer = std::move(deps[synchronizer.id()]);
|
||||||
|
return global_synchronizer;
|
||||||
|
}
|
||||||
|
|
||||||
deps[synchronizer.id()].wait();
|
void eval(std::vector<array> outputs) {
|
||||||
|
async_eval(std::move(outputs)).wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
|
@ -2,10 +2,13 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <future>
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
std::shared_future<void> async_eval(std::vector<array> outputs);
|
||||||
|
|
||||||
void eval(std::vector<array> outputs);
|
void eval(std::vector<array> outputs);
|
||||||
|
|
||||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
|
@ -595,6 +595,14 @@ class PyCheckpointedFun {
|
|||||||
};
|
};
|
||||||
|
|
||||||
void init_transforms(nb::module_& m) {
|
void init_transforms(nb::module_& m) {
|
||||||
|
nb::class_<std::shared_future<void>>(
|
||||||
|
m,
|
||||||
|
"Synchronizer",
|
||||||
|
R"pbdoc(
|
||||||
|
A synchronization object returned by :func:`async_eval`.
|
||||||
|
)pbdoc")
|
||||||
|
.def("wait", [](const std::shared_future<void>& f) { f.wait(); });
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"eval",
|
"eval",
|
||||||
[](const nb::args& args) {
|
[](const nb::args& args) {
|
||||||
@ -615,6 +623,38 @@ void init_transforms(nb::module_& m) {
|
|||||||
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
|
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
|
||||||
arrays are ignored.
|
arrays are ignored.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"async_eval",
|
||||||
|
[](const nb::args& args) {
|
||||||
|
std::vector<array> arrays = tree_flatten(args, false);
|
||||||
|
{
|
||||||
|
nb::gil_scoped_release nogil;
|
||||||
|
return async_eval(arrays);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
nb::arg(),
|
||||||
|
nb::sig("def async_eval(*args) -> Synchronizer"),
|
||||||
|
R"pbdoc(
|
||||||
|
Asynchronously evaluate an :class:`array` or tree of :class:`array`.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
You must call ``wait`` on the returned synchronization object before
|
||||||
|
using any arrays that are asynchronously evaluated.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
This is an experimental API and may change in future versions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args (arrays or trees of arrays): Each argument can be a single array
|
||||||
|
or a tree of arrays. If a tree is given the nodes can be a Python
|
||||||
|
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
|
||||||
|
arrays are ignored.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Synchronizer: A synchronization object.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"jvp",
|
"jvp",
|
||||||
[](const nb::callable& fun,
|
[](const nb::callable& fun,
|
||||||
|
@ -32,6 +32,18 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
mx.eval(state)
|
mx.eval(state)
|
||||||
self.assertEqual(x.item(), 3)
|
self.assertEqual(x.item(), 3)
|
||||||
|
|
||||||
|
def test_async_eval(self):
|
||||||
|
x = mx.array(1) + mx.array(1) + mx.array(1)
|
||||||
|
sync = mx.async_eval(x)
|
||||||
|
sync.wait()
|
||||||
|
self.assertEqual(x.item(), 3)
|
||||||
|
|
||||||
|
# It should be safe to call eval on the array which has been async
|
||||||
|
# eval'ed
|
||||||
|
x = mx.array(1) + mx.array(1) + mx.array(1)
|
||||||
|
sync = mx.async_eval(x)
|
||||||
|
self.assertEqual(x.item(), 3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user