Revert "More buffer donation in some cases (#1858)" (#1863)

This reverts commit d274ae77f2.
This commit is contained in:
Awni Hannun 2025-02-13 14:21:44 -08:00 committed by GitHub
parent 5cd97f7ffe
commit 428f589364
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 40 deletions

View File

@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <memory> #include <memory>
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h" #include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
@ -53,34 +54,16 @@ std::function<void()> make_task(array arr, bool signal) {
abort_with_exception(error); abort_with_exception(error);
} }
} }
std::vector<std::shared_ptr<array::Data>> buffers;
// Use a set to de-dup input buffers
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) { for (auto& in : arr.inputs()) {
if (in.data_shared_ptr() != nullptr) { buffers.push_back(in.data_shared_ptr());
buffers.insert(in.data_shared_ptr());
}
} }
for (auto& s : arr.siblings()) { for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr()); buffers.push_back(s.data_shared_ptr());
} }
if (!arr.is_tracer()) { if (!arr.is_tracer()) {
arr.detach(); arr.detach();
} }
// Erase any input buffers that have a use count > 1
// to keep them elligible for donation
for (auto it = buffers.begin(); it != buffers.end();) {
// Always hold buffers which could belong to outputs
// TODO: this shouldn't be necessary, but metal validation
// complains if buffers get released before the command buffer is
// finished, even if the ready signal has fired
if (!signal && it->use_count() > 1) {
it = buffers.erase(it);
} else {
++it;
}
}
for (auto& out : outputs) { for (auto& out : outputs) {
out.set_status(array::Status::evaluated); out.set_status(array::Status::evaluated);
} }

View File

@ -174,25 +174,6 @@ class TestEval(mlx_tests.MLXTestCase):
post = mx.metal.get_peak_memory() post = mx.metal.get_peak_memory()
self.assertEqual(pre, post) self.assertEqual(pre, post)
def test_donation_multiple_inputs(self):
def fun(its, x, y):
for _ in range(its):
a = x + y # y should donate
b = x + a # x should donate
x, y = a, b
return x, y
x = mx.zeros((128, 128))
y = mx.zeros((128, 128))
mx.metal.reset_peak_memory()
a, b = fun(2, x, y)
mx.eval(a, b)
mem2 = mx.metal.get_peak_memory()
a, b = fun(10, x, y)
mx.eval(a, b)
mem10 = mx.metal.get_peak_memory()
self.assertEqual(mem2, mem10)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()