diff --git a/mlx/array.cpp b/mlx/array.cpp index 29a86aaa0..d5c141e70 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -175,10 +175,11 @@ array::~array() { return; } - // Ignore arrays that will be detached - if (status() != array::Status::unscheduled) { + // Ignore arrays that might be detached during eval + if (status() == array::Status::scheduled) { return; } + // Break circular reference for non-detached arrays with siblings if (auto n = siblings().size(); n > 0) { bool do_detach = true; diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 3afe47159..0bdb60ede 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -75,6 +75,9 @@ std::function make_task(array arr, bool signal) { if (!arr.is_tracer()) { arr.detach(); } + for (auto& out : outputs) { + out.set_status(array::Status::available); + } if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) { d.end_encoding(s.index); diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 0624358e2..c8345e2cc 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -141,10 +141,9 @@ array eval_impl(std::vector outputs, bool async) { tape.pop(); // Set the status of the array and siblings. - auto status = async ? array::Status::scheduled : array::Status::available; - arr.set_status(status); + arr.set_status(array::Status::scheduled); for (auto& s : arr.siblings()) { - s.set_status(status); + s.set_status(array::Status::scheduled); } auto stream = arr.primitive().stream(); @@ -170,6 +169,10 @@ array eval_impl(std::vector outputs, bool async) { if (!arr.is_tracer()) { arr.detach(); } + for (auto& out : outputs) { + out.set_status(array::Status::available); + } + if (signal) { arr.event().signal(); } diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 7b972ac1c..9f98c9600 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -103,6 +103,23 @@ class TestEval(mlx_tests.MLXTestCase): z = mx.add(y, x, stream=mx.cpu) self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0))) + def test_multi_output_eval_during_transform(self): + x = mx.random.uniform(shape=(1024,)) + y = mx.ones((1024,)) + mx.eval(x, y) + + def fn(x): + a, b = mx.divmod(x, x) + mx.eval(a) + return a + + out = mx.vjp(fn, (x,), (y,)) + out = mx.vjp(fn, (x,), (y,)) + if mx.metal.is_available(): + peak_mem = mx.metal.get_peak_memory() + out = mx.vjp(fn, (x,), (y,)) + self.assertEqual(peak_mem, mx.metal.get_peak_memory()) + if __name__ == "__main__": unittest.main()