Fix leak with multi-output primitives (#1274)

* fix leak with multi-output primitives

* hopefully an actual fix
This commit is contained in:
Awni Hannun 2024-07-23 06:34:18 -07:00 committed by GitHub
parent df124e018a
commit 1fba87b0df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 29 additions and 5 deletions

View File

@ -175,10 +175,11 @@ array::~array() {
return;
}
// Ignore arrays that will be detached
if (status() != array::Status::unscheduled) {
// Ignore arrays that might be detached during eval
if (status() == array::Status::scheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;

View File

@ -75,6 +75,9 @@ std::function<void()> make_task(array arr, bool signal) {
if (!arr.is_tracer()) {
arr.detach();
}
for (auto& out : outputs) {
out.set_status(array::Status::available);
}
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
d.end_encoding(s.index);

View File

@ -141,10 +141,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
tape.pop();
// Set the status of the array and siblings.
auto status = async ? array::Status::scheduled : array::Status::available;
arr.set_status(status);
arr.set_status(array::Status::scheduled);
for (auto& s : arr.siblings()) {
s.set_status(status);
s.set_status(array::Status::scheduled);
}
auto stream = arr.primitive().stream();
@ -170,6 +169,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
if (!arr.is_tracer()) {
arr.detach();
}
for (auto& out : outputs) {
out.set_status(array::Status::available);
}
if (signal) {
arr.event().signal();
}

View File

@ -103,6 +103,23 @@ class TestEval(mlx_tests.MLXTestCase):
z = mx.add(y, x, stream=mx.cpu)
self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))
# Runs a two-output op (mx.divmod) under mx.vjp while evaluating only one of
# its outputs inside the transformed function, then checks that a repeated
# call does not raise the reported peak memory.
def test_multi_output_eval_during_transform(self):
x = mx.random.uniform(shape=(1024,))
y = mx.ones((1024,))
# Evaluate the inputs before any transform is applied.
mx.eval(x, y)
def fn(x):
# The result is unpacked into two outputs; only `a` is evaluated and
# returned, `b` is never used.
a, b = mx.divmod(x, x)
mx.eval(a)
return a
# NOTE(review): the first two calls presumably let allocations settle so the
# peak reading taken below is stable -- confirm.
out = mx.vjp(fn, (x,), (y,))
out = mx.vjp(fn, (x,), (y,))
# The peak-memory comparison only runs when mx.metal.is_available() is true.
if mx.metal.is_available():
peak_mem = mx.metal.get_peak_memory()
out = mx.vjp(fn, (x,), (y,))
# One more identical call must leave the peak memory unchanged.
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
# Run the tests through unittest.main() when this file is executed as a script.
if __name__ == "__main__":
unittest.main()