diff --git a/mlx/backend/cpu/eval.cpp b/mlx/backend/cpu/eval.cpp index b23c8d561..3b7cd7977 100644 --- a/mlx/backend/cpu/eval.cpp +++ b/mlx/backend/cpu/eval.cpp @@ -20,21 +20,16 @@ void eval(array& arr) { } arr.primitive().eval_cpu(arr.inputs(), outputs); } +} - std::unordered_set<std::shared_ptr<array::Data>> buffers; - for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); - } - // Remove the output if it was donated to by an input - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } +void finalize( + Stream s, + std::unordered_set<std::shared_ptr<array::Data>> retain_buffers) { auto& encoder = cpu::get_command_encoder(s); - encoder.dispatch([buffers = std::move(buffers), - temps = std::move(encoder.temporaries())]() {}); + encoder.dispatch([s, + buffers = std::move(retain_buffers), + temps = std::move(encoder.temporaries())]() { + }); } } // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/eval.h b/mlx/backend/cpu/eval.h index 20156d61b..4affe0bf1 100644 --- a/mlx/backend/cpu/eval.h +++ b/mlx/backend/cpu/eval.h @@ -2,11 +2,16 @@ #pragma once +#include <unordered_set> + #include "mlx/array.h" #include "mlx/stream.h" namespace mlx::core::cpu { void eval(array& arr); +void finalize( + Stream s, + std::unordered_set<std::shared_ptr<array::Data>> retain_buffers); } // namespace mlx::core::cpu diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 930e570e2..b95a6de5e 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. 
- #include #include diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index a9a1bc4f6..178f0a840 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -40,46 +40,32 @@ void eval(array& arr) { debug_set_primitive_buffer_label(command_buffer, arr.primitive()); arr.primitive().eval_gpu(arr.inputs(), outputs); } - std::unordered_set<std::shared_ptr<array::Data>> buffers; - for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); - } - // Remove the output if it was donated to by an input - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } +} - if (d.command_buffer_needs_commit(s.index)) { +void finalize( + Stream s, + std::unordered_set<std::shared_ptr<array::Data>> retain_buffers, + bool force_commit) { + auto pool = new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto command_buffer = d.get_command_buffer(s.index); + if (d.command_buffer_needs_commit(s.index) || force_commit) { d.end_encoding(s.index); scheduler::notify_new_task(s); command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + [s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) { scheduler::notify_task_completion(s); check_error(cbuf); }); d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); } else { command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + [s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); } } -void finalize(Stream s) { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - d.end_encoding(s.index); - cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); -} - void synchronize(Stream s) { auto pool = new_scoped_memory_pool(); 
auto& d = metal::device(s.device); diff --git a/mlx/backend/metal/metal_impl.h b/mlx/backend/metal/metal_impl.h index 9ca8d2f80..7604c248e 100644 --- a/mlx/backend/metal/metal_impl.h +++ b/mlx/backend/metal/metal_impl.h @@ -4,6 +4,7 @@ #include <future> #include <memory> +#include <unordered_set> #include "mlx/array.h" #include "mlx/stream.h" @@ -15,7 +16,10 @@ void new_stream(Stream stream); std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool(); void eval(array& arr); -void finalize(Stream s); +void finalize( + Stream s, + std::unordered_set<std::shared_ptr<array::Data>> retain_buffers, + bool force_commit); void synchronize(Stream s); } // namespace mlx::core::metal diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index ef9af8800..8bc474376 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -21,7 +21,10 @@ void eval(array&) { "[metal::eval] Cannot eval on GPU without metal backend"); } -void finalize(Stream) { +void finalize( + Stream, + std::unordered_set<std::shared_ptr<array::Data>> retain_buffers, + bool) { throw std::runtime_error( "[metal::finalize] Cannot finalize GPU without metal backend"); } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 54f3b302b..59ef172b3 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -160,6 +160,7 @@ array eval_impl(std::vector<array> outputs, bool async) { } auto it = cache.find(in.id()); + it->second -= 1; if (it->second != 0) { @@ -180,6 +181,8 @@ array eval_impl(std::vector<array> outputs, bool async) { } } + std::unordered_map<std::uintptr_t, std::weak_ptr<array::Data>> + unretained_buffers; while (!tape.empty()) { auto arr = std::move(tape.back()); tape.pop_back(); @@ -225,7 +228,7 @@ array eval_impl(std::vector<array> outputs, bool async) { // Commit any open streams for (auto& [_, e] : events) { if (e.stream().device == Device::gpu) { - metal::finalize(e.stream()); + metal::finalize(e.stream(), {}, true); } } scheduler::wait_for_one(); @@ -246,24 +249,59 @@ array eval_impl(std::vector<array> outputs, bool async) { }; arr.set_status(array::Status::evaluated); - // TODO Maybe always want the fence coherent kernel 
in the same cbuf - as the other kernels? + + std::unordered_set<std::shared_ptr<array::Data>> retain_buffers; + maybe_update_fence(arr); for (auto& sib : arr.siblings()) { sib.set_status(array::Status::evaluated); maybe_update_fence(sib); + retain_buffers.insert(sib.data_shared_ptr()); } + + for (auto& in : arr.inputs()) { + retain_buffers.insert(in.data_shared_ptr()); + } + if (!arr.is_tracer()) { arr.detach(); } - } + for (auto it = retain_buffers.begin(); it != retain_buffers.end();) { + if (it->use_count() > 1) { + // At this point the buffer must be in one of two states: + // 1. Held by another array + // 2. Held from a previous async_eval + unretained_buffers.emplace(std::uintptr_t(it->get()), *it); + it = retain_buffers.erase(it); + } else { + unretained_buffers.erase(std::uintptr_t(it->get())); + ++it; + } + } + if (stream.device == Device::gpu) { + metal::finalize(stream, std::move(retain_buffers), false); + } else { + cpu::finalize(stream, std::move(retain_buffers)); + } + } // Signal the event in its stream for (auto& [_, e] : events) { auto s = e.stream(); e.signal(s); + std::unordered_set<std::shared_ptr<array::Data>> retain; + if (s == stream) { + for (auto& [_, b] : unretained_buffers) { + auto ptr = b.lock(); + if (ptr) { + retain.insert(ptr); + } + } + } if (s.device == Device::gpu) { - metal::finalize(s); + metal::finalize(s, std::move(retain), true); + } else { + cpu::finalize(s, std::move(retain)); } } diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index fcd424343..c9f53932e 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -195,6 +195,36 @@ class TestEval(mlx_tests.MLXTestCase): mx.eval(z) mx.set_memory_limit(old_limit) + def test_donation_multiple_inputs(self): + def fun(its, x, y): + for _ in range(its): + a = x + y # y should donate + b = x + a # x should donate + x, y = a, b + return x, y + + x = mx.zeros((128, 128)) + y = mx.zeros((128, 128)) + mx.reset_peak_memory() + a, b = fun(2, x, y) + mx.eval(a, b) + mx.synchronize() + mem2 = 
mx.get_peak_memory() + a, b = fun(10, x, y) + mx.eval(a, b) + mx.synchronize() + mem10 = mx.get_peak_memory() + self.assertEqual(mem2, mem10) + + def test_async_with_delete(self): + a = mx.ones((5, 5)) + for _ in range(100): + a = mx.abs(a) + mx.async_eval(a) + del a + mx.clear_cache() + mx.synchronize() + if __name__ == "__main__": unittest.main()