mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 06:24:35 +08:00
Shared events for synchronization + async eval (#998)
* more async eval * fix rebase * try correct async eval * fix async * more tests for async eval * use shared events for synchronization * comment + cleanup * with autorelease pool * fix no metal build * fix compile * fix patch * don't eval if asyn evale'd * don't use is_evaled * comments * more multi stream tests * try and cleanup use of is_evaled * use a status flag
This commit is contained in:
@@ -595,14 +595,6 @@ class PyCheckpointedFun {
|
||||
};
|
||||
|
||||
void init_transforms(nb::module_& m) {
|
||||
nb::class_<std::shared_future<void>>(
|
||||
m,
|
||||
"Synchronizer",
|
||||
R"pbdoc(
|
||||
A synchronization object returned by :func:`async_eval`.
|
||||
)pbdoc")
|
||||
.def("wait", [](const std::shared_future<void>& f) { f.wait(); });
|
||||
|
||||
m.def(
|
||||
"eval",
|
||||
[](const nb::args& args) {
|
||||
@@ -629,19 +621,14 @@ void init_transforms(nb::module_& m) {
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
return async_eval(arrays);
|
||||
async_eval(arrays);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
nb::sig("def async_eval(*args) -> Synchronizer"),
|
||||
nb::sig("def async_eval(*args)"),
|
||||
R"pbdoc(
|
||||
Asynchronously evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
.. warning::
|
||||
|
||||
You must call ``wait`` on the returned synchronization object before
|
||||
using any arrays that are asynchronously evaluated.
|
||||
|
||||
.. note::
|
||||
|
||||
This is an experimental API and may change in future versions.
|
||||
@@ -652,8 +639,17 @@ void init_transforms(nb::module_& m) {
|
||||
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
|
||||
arrays are ignored.
|
||||
|
||||
Returns:
|
||||
Synchronizer: A synchronization object.
|
||||
Example:
|
||||
>>> x = mx.array(1.0)
|
||||
>>> y = mx.exp(x)
|
||||
>>> mx.async_eval(y)
|
||||
>>> print(y)
|
||||
>>>
|
||||
>>> y = mx.exp(x)
|
||||
>>> mx.async_eval(y)
|
||||
>>> z = y + 3
|
||||
>>> mx.async_eval(z)
|
||||
>>> print(z)
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"jvp",
|
||||
|
Reference in New Issue
Block a user