mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
This reverts commit d274ae77f2
.
This commit is contained in:
parent
5cd97f7ffe
commit
428f589364
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/event.h"
|
#include "mlx/backend/metal/event.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
@ -53,34 +54,16 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
abort_with_exception(error);
|
abort_with_exception(error);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||||
// Use a set to de-dup input buffers
|
|
||||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
if (in.data_shared_ptr() != nullptr) {
|
buffers.push_back(in.data_shared_ptr());
|
||||||
buffers.insert(in.data_shared_ptr());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (auto& s : arr.siblings()) {
|
for (auto& s : arr.siblings()) {
|
||||||
buffers.insert(s.data_shared_ptr());
|
buffers.push_back(s.data_shared_ptr());
|
||||||
}
|
}
|
||||||
if (!arr.is_tracer()) {
|
if (!arr.is_tracer()) {
|
||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
// Erase any input buffers that have a use count > 1
|
|
||||||
// to keep them elligible for donation
|
|
||||||
for (auto it = buffers.begin(); it != buffers.end();) {
|
|
||||||
// Always hold buffers which could belong to outputs
|
|
||||||
// TODO: this shouldn't be necessary, but metal validation
|
|
||||||
// complains if buffers get released before the command buffer is
|
|
||||||
// finished, even if the ready signal has fired
|
|
||||||
if (!signal && it->use_count() > 1) {
|
|
||||||
it = buffers.erase(it);
|
|
||||||
} else {
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& out : outputs) {
|
for (auto& out : outputs) {
|
||||||
out.set_status(array::Status::evaluated);
|
out.set_status(array::Status::evaluated);
|
||||||
}
|
}
|
||||||
|
@ -174,25 +174,6 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
post = mx.metal.get_peak_memory()
|
post = mx.metal.get_peak_memory()
|
||||||
self.assertEqual(pre, post)
|
self.assertEqual(pre, post)
|
||||||
|
|
||||||
def test_donation_multiple_inputs(self):
|
|
||||||
def fun(its, x, y):
|
|
||||||
for _ in range(its):
|
|
||||||
a = x + y # y should donate
|
|
||||||
b = x + a # x should donate
|
|
||||||
x, y = a, b
|
|
||||||
return x, y
|
|
||||||
|
|
||||||
x = mx.zeros((128, 128))
|
|
||||||
y = mx.zeros((128, 128))
|
|
||||||
mx.metal.reset_peak_memory()
|
|
||||||
a, b = fun(2, x, y)
|
|
||||||
mx.eval(a, b)
|
|
||||||
mem2 = mx.metal.get_peak_memory()
|
|
||||||
a, b = fun(10, x, y)
|
|
||||||
mx.eval(a, b)
|
|
||||||
mem10 = mx.metal.get_peak_memory()
|
|
||||||
self.assertEqual(mem2, mem10)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user