mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 21:04:41 +08:00
Shared events for synchronization + async eval (#998)
* more async eval * fix rebase * try correct async eval * fix async * more tests for async eval * use shared events for synchronization * comment + cleanup * with autorelease pool * fix no metal build * fix compile * fix patch * don't eval if asyn evale'd * don't use is_evaled * comments * more multi stream tests * try and cleanup use of is_evaled * use a status flag
This commit is contained in:
@@ -946,10 +946,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__repr__",
|
||||
[](array& a) {
|
||||
if (!a.is_evaled()) {
|
||||
nb::gil_scoped_release nogil;
|
||||
a.eval();
|
||||
}
|
||||
nb::gil_scoped_release nogil;
|
||||
std::ostringstream os;
|
||||
os << a;
|
||||
return os.str();
|
||||
|
@@ -86,7 +86,7 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
|
||||
std::memset(view, 0, sizeof(Py_buffer));
|
||||
auto a = nb::cast<array>(nb::handle(obj));
|
||||
|
||||
if (!a.is_evaled()) {
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
a.eval();
|
||||
}
|
||||
|
@@ -104,8 +104,7 @@ template <typename Lib, typename T>
|
||||
nb::ndarray<Lib> mlx_to_nd_array(
|
||||
array a,
|
||||
std::optional<nb::dlpack::dtype> t = {}) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
a.eval();
|
||||
}
|
||||
|
@@ -595,14 +595,6 @@ class PyCheckpointedFun {
|
||||
};
|
||||
|
||||
void init_transforms(nb::module_& m) {
|
||||
nb::class_<std::shared_future<void>>(
|
||||
m,
|
||||
"Synchronizer",
|
||||
R"pbdoc(
|
||||
A synchronization object returned by :func:`async_eval`.
|
||||
)pbdoc")
|
||||
.def("wait", [](const std::shared_future<void>& f) { f.wait(); });
|
||||
|
||||
m.def(
|
||||
"eval",
|
||||
[](const nb::args& args) {
|
||||
@@ -629,19 +621,14 @@ void init_transforms(nb::module_& m) {
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
return async_eval(arrays);
|
||||
async_eval(arrays);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
nb::sig("def async_eval(*args) -> Synchronizer"),
|
||||
nb::sig("def async_eval(*args)"),
|
||||
R"pbdoc(
|
||||
Asynchronously evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
.. warning::
|
||||
|
||||
You must call ``wait`` on the returned synchronization object before
|
||||
using any arrays that are asynchronously evaluated.
|
||||
|
||||
.. note::
|
||||
|
||||
This is an experimental API and may change in future versions.
|
||||
@@ -652,8 +639,17 @@ void init_transforms(nb::module_& m) {
|
||||
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
|
||||
arrays are ignored.
|
||||
|
||||
Returns:
|
||||
Synchronizer: A synchronization object.
|
||||
Example:
|
||||
>>> x = mx.array(1.0)
|
||||
>>> y = mx.exp(x)
|
||||
>>> mx.async_eval(y)
|
||||
>>> print(y)
|
||||
>>>
|
||||
>>> y = mx.exp(x)
|
||||
>>> mx.async_eval(y)
|
||||
>>> z = y + 3
|
||||
>>> mx.async_eval(z)
|
||||
>>> print(z)
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"jvp",
|
||||
|
@@ -34,16 +34,75 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_async_eval(self):
|
||||
x = mx.array(1) + mx.array(1) + mx.array(1)
|
||||
sync = mx.async_eval(x)
|
||||
sync.wait()
|
||||
mx.async_eval(x)
|
||||
self.assertEqual(x.item(), 3)
|
||||
|
||||
# It should be safe to call eval on the array which has been async
|
||||
# eval'ed
|
||||
x = mx.array(1) + mx.array(1) + mx.array(1)
|
||||
sync = mx.async_eval(x)
|
||||
self.assertEqual(x.item(), 3)
|
||||
|
||||
x = mx.array([1, 2, 3])
|
||||
y = 2 * x
|
||||
mx.async_eval(y)
|
||||
z = 2 * y
|
||||
mx.async_eval(z)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([2, 4, 6])))
|
||||
self.assertTrue(mx.array_equal(z, mx.array([4, 8, 12])))
|
||||
|
||||
def test_async_eval_twice(self):
|
||||
x = mx.array(1) + mx.array(1) + mx.array(1)
|
||||
mx.async_eval(x)
|
||||
y = x + 1
|
||||
mx.async_eval(y)
|
||||
self.assertEqual(x.item(), 3)
|
||||
|
||||
def test_async_eval_in_trace(self):
|
||||
def fun(x):
|
||||
y = x + 1.0
|
||||
mx.async_eval(y)
|
||||
return mx.exp(y)
|
||||
|
||||
# Raises
|
||||
with self.assertRaises(ValueError):
|
||||
mx.grad(fun)(mx.array(1.0))
|
||||
|
||||
# Also raises
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(fun)(mx.ones((2, 2)))
|
||||
|
||||
def test_async_eval_into_eval(self):
|
||||
x = mx.array(1)
|
||||
y = x + 1
|
||||
mx.async_eval(y)
|
||||
a = y - 10
|
||||
b = mx.abs(a)
|
||||
self.assertEqual(b.item(), 8)
|
||||
|
||||
def test_async_eval_into_eval_diff_stream(self):
|
||||
s = mx.new_stream(mx.cpu)
|
||||
x = mx.array(0)
|
||||
y = x - 5
|
||||
mx.async_eval(y)
|
||||
z = mx.abs(y, stream=s)
|
||||
self.assertEqual(z.item(), 5)
|
||||
|
||||
def test_eval_slow_fast_multi_stream(self):
|
||||
x = mx.ones((8000,))
|
||||
y = mx.abs(mx.array(-1.0))
|
||||
for _ in range(20):
|
||||
x = x + mx.array(1.0)
|
||||
z = mx.add(x, y, stream=mx.cpu)
|
||||
self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))
|
||||
|
||||
# Switch eval order
|
||||
x = mx.ones((8000,))
|
||||
y = mx.abs(mx.array(-1.0))
|
||||
for _ in range(20):
|
||||
x = x + mx.array(1.0)
|
||||
z = mx.add(y, x, stream=mx.cpu)
|
||||
self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -24,12 +24,12 @@ class TestMetal(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
|
||||
|
||||
# Query active and peak memory
|
||||
a = mx.zeros((4096,))
|
||||
a = mx.zeros((4096,), stream=mx.cpu)
|
||||
mx.eval(a)
|
||||
active_mem = mx.metal.get_active_memory()
|
||||
self.assertTrue(active_mem >= 4096 * 4)
|
||||
|
||||
b = mx.zeros((4096,))
|
||||
b = mx.zeros((4096,), stream=mx.cpu)
|
||||
mx.eval(b)
|
||||
del b
|
||||
|
||||
|
Reference in New Issue
Block a user