From 428f589364bb405f784b552d8f0d7a7542f383ad Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 13 Feb 2025 14:21:44 -0800 Subject: [PATCH] Revert "More buffer donation in some cases (#1858)" (#1863) This reverts commit d274ae77f2083cef3af3ebd01465c26f9c958cb1. --- mlx/backend/metal/metal.cpp | 25 ++++--------------------- python/tests/test_eval.py | 19 ------------------- 2 files changed, 4 insertions(+), 40 deletions(-) diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 20ad0fe1d..a8bc716c1 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -1,5 +1,6 @@ // Copyright © 2023-2024 Apple Inc. #include + #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/event.h" #include "mlx/backend/metal/utils.h" @@ -53,34 +54,16 @@ std::function make_task(array arr, bool signal) { abort_with_exception(error); } } - - // Use a set to de-dup input buffers - std::unordered_set> buffers; + std::vector> buffers; for (auto& in : arr.inputs()) { - if (in.data_shared_ptr() != nullptr) { - buffers.insert(in.data_shared_ptr()); - } + buffers.push_back(in.data_shared_ptr()); } for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); + buffers.push_back(s.data_shared_ptr()); } if (!arr.is_tracer()) { arr.detach(); } - // Erase any input buffers that have a use count > 1 - // to keep them elligible for donation - for (auto it = buffers.begin(); it != buffers.end();) { - // Always hold buffers which could belong to outputs - // TODO: this shouldn't be necessary, but metal validation - // complains if buffers get released before the command buffer is - // finished, even if the ready signal has fired - if (!signal && it->use_count() > 1) { - it = buffers.erase(it); - } else { - ++it; - } - } - for (auto& out : outputs) { out.set_status(array::Status::evaluated); } diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 8e9151d11..37e31f80b 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -174,25 +174,6 @@ class TestEval(mlx_tests.MLXTestCase): post = mx.metal.get_peak_memory() self.assertEqual(pre, post) - def test_donation_multiple_inputs(self): - def fun(its, x, y): - for _ in range(its): - a = x + y # y should donate - b = x + a # x should donate - x, y = a, b - return x, y - - x = mx.zeros((128, 128)) - y = mx.zeros((128, 128)) - mx.metal.reset_peak_memory() - a, b = fun(2, x, y) - mx.eval(a, b) - mem2 = mx.metal.get_peak_memory() - a, b = fun(10, x, y) - mx.eval(a, b) - mem10 = mx.metal.get_peak_memory() - self.assertEqual(mem2, mem10) - if __name__ == "__main__": unittest.main()