From d274ae77f2083cef3af3ebd01465c26f9c958cb1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 12 Feb 2025 19:41:37 -0800 Subject: [PATCH] More buffer donation in some cases (#1858) * more donation * fix * add test --- mlx/backend/metal/metal.cpp | 25 +++++++++++++++++++++---- python/tests/test_eval.py | 19 +++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index a8bc716c1..20ad0fe1d 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -1,6 +1,5 @@ // Copyright © 2023-2024 Apple Inc. #include - #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/event.h" #include "mlx/backend/metal/utils.h" @@ -54,16 +53,34 @@ std::function make_task(array arr, bool signal) { abort_with_exception(error); } } - std::vector> buffers; + + // Use a set to de-dup input buffers + std::unordered_set> buffers; for (auto& in : arr.inputs()) { - buffers.push_back(in.data_shared_ptr()); + if (in.data_shared_ptr() != nullptr) { + buffers.insert(in.data_shared_ptr()); + } } for (auto& s : arr.siblings()) { - buffers.push_back(s.data_shared_ptr()); + buffers.insert(s.data_shared_ptr()); } if (!arr.is_tracer()) { arr.detach(); } + // Erase any input buffers that have a use count > 1 + // to keep them eligible for donation + for (auto it = buffers.begin(); it != buffers.end();) { + // Always hold buffers which could belong to outputs + // TODO: this shouldn't be necessary, but metal validation + // complains if buffers get released before the command buffer is + // finished, even if the ready signal has fired + if (!signal && it->use_count() > 1) { + it = buffers.erase(it); + } else { + ++it; + } + } + for (auto& out : outputs) { out.set_status(array::Status::evaluated); } diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 37e31f80b..8e9151d11 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -174,6 +174,25 @@ class 
TestEval(mlx_tests.MLXTestCase): post = mx.metal.get_peak_memory() self.assertEqual(pre, post) + def test_donation_multiple_inputs(self): + def fun(its, x, y): + for _ in range(its): + a = x + y # y should donate + b = x + a # x should donate + x, y = a, b + return x, y + + x = mx.zeros((128, 128)) + y = mx.zeros((128, 128)) + mx.metal.reset_peak_memory() + a, b = fun(2, x, y) + mx.eval(a, b) + mem2 = mx.metal.get_peak_memory() + a, b = fun(10, x, y) + mx.eval(a, b) + mem10 = mx.metal.get_peak_memory() + self.assertEqual(mem2, mem10) + if __name__ == "__main__": unittest.main()