Shared events for synchronization + async eval (#998)

* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if asyn evale'd

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
This commit is contained in:
Awni Hannun
2024-04-17 06:16:02 -07:00
committed by GitHub
parent b18468bf81
commit 8a0677d56d
28 changed files with 424 additions and 125 deletions

View File

@@ -946,10 +946,7 @@ void init_array(nb::module_& m) {
.def(
"__repr__",
[](array& a) {
if (!a.is_evaled()) {
nb::gil_scoped_release nogil;
a.eval();
}
nb::gil_scoped_release nogil;
std::ostringstream os;
os << a;
return os.str();

View File

@@ -86,7 +86,7 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
std::memset(view, 0, sizeof(Py_buffer));
auto a = nb::cast<array>(nb::handle(obj));
if (!a.is_evaled()) {
{
nb::gil_scoped_release nogil;
a.eval();
}

View File

@@ -104,8 +104,7 @@ template <typename Lib, typename T>
nb::ndarray<Lib> mlx_to_nd_array(
array a,
std::optional<nb::dlpack::dtype> t = {}) {
// Eval if not already evaled
if (!a.is_evaled()) {
{
nb::gil_scoped_release nogil;
a.eval();
}

View File

@@ -595,14 +595,6 @@ class PyCheckpointedFun {
};
void init_transforms(nb::module_& m) {
nb::class_<std::shared_future<void>>(
m,
"Synchronizer",
R"pbdoc(
A synchronization object returned by :func:`async_eval`.
)pbdoc")
.def("wait", [](const std::shared_future<void>& f) { f.wait(); });
m.def(
"eval",
[](const nb::args& args) {
@@ -629,19 +621,14 @@ void init_transforms(nb::module_& m) {
std::vector<array> arrays = tree_flatten(args, false);
{
nb::gil_scoped_release nogil;
return async_eval(arrays);
async_eval(arrays);
}
},
nb::arg(),
nb::sig("def async_eval(*args) -> Synchronizer"),
nb::sig("def async_eval(*args)"),
R"pbdoc(
Asynchronously evaluate an :class:`array` or tree of :class:`array`.
.. warning::
You must call ``wait`` on the returned synchronization object before
using any arrays that are asynchronously evaluated.
.. note::
This is an experimental API and may change in future versions.
@@ -652,8 +639,17 @@ void init_transforms(nb::module_& m) {
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
arrays are ignored.
Returns:
Synchronizer: A synchronization object.
Example:
>>> x = mx.array(1.0)
>>> y = mx.exp(x)
>>> mx.async_eval(y)
>>> print(y)
>>>
>>> y = mx.exp(x)
>>> mx.async_eval(y)
>>> z = y + 3
>>> mx.async_eval(z)
>>> print(z)
)pbdoc");
m.def(
"jvp",