mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Async eval (#972)
This commit is contained in:
@@ -595,6 +595,14 @@ class PyCheckpointedFun {
|
||||
};
|
||||
|
||||
void init_transforms(nb::module_& m) {
|
||||
nb::class_<std::shared_future<void>>(
|
||||
m,
|
||||
"Synchronizer",
|
||||
R"pbdoc(
|
||||
A synchronization object returned by :func:`async_eval`.
|
||||
)pbdoc")
|
||||
.def("wait", [](const std::shared_future<void>& f) { f.wait(); });
|
||||
|
||||
m.def(
|
||||
"eval",
|
||||
[](const nb::args& args) {
|
||||
@@ -615,6 +623,38 @@ void init_transforms(nb::module_& m) {
|
||||
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
|
||||
arrays are ignored.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"async_eval",
|
||||
[](const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
return async_eval(arrays);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
nb::sig("def async_eval(*args) -> Synchronizer"),
|
||||
R"pbdoc(
|
||||
Asynchronously evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
.. warning::
|
||||
|
||||
You must call ``wait`` on the returned synchronization object before
|
||||
using any arrays that are asynchronously evaluated.
|
||||
|
||||
.. note::
|
||||
|
||||
This is an experimental API and may change in future versions.
|
||||
|
||||
Args:
|
||||
*args (arrays or trees of arrays): Each argument can be a single array
|
||||
or a tree of arrays. If a tree is given the nodes can be a Python
|
||||
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
|
||||
arrays are ignored.
|
||||
|
||||
Returns:
|
||||
Synchronizer: A synchronization object.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"jvp",
|
||||
[](const nb::callable& fun,
|
||||
|
@@ -32,6 +32,18 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
mx.eval(state)
|
||||
self.assertEqual(x.item(), 3)
|
||||
|
||||
def test_async_eval(self):
|
||||
x = mx.array(1) + mx.array(1) + mx.array(1)
|
||||
sync = mx.async_eval(x)
|
||||
sync.wait()
|
||||
self.assertEqual(x.item(), 3)
|
||||
|
||||
# It should be safe to call eval on the array which has been async
|
||||
# eval'ed
|
||||
x = mx.array(1) + mx.array(1) + mx.array(1)
|
||||
sync = mx.async_eval(x)
|
||||
self.assertEqual(x.item(), 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user