mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives
* hopefully an actual fix
This commit is contained in:
parent
df124e018a
commit
1fba87b0df
@ -175,10 +175,11 @@ array::~array() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore arrays that will be detached
|
||||
if (status() != array::Status::unscheduled) {
|
||||
// Ignore arrays that might be detached during eval
|
||||
if (status() == array::Status::scheduled) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Break circular reference for non-detached arrays with siblings
|
||||
if (auto n = siblings().size(); n > 0) {
|
||||
bool do_detach = true;
|
||||
|
@ -75,6 +75,9 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
for (auto& out : outputs) {
|
||||
out.set_status(array::Status::available);
|
||||
}
|
||||
|
||||
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
|
||||
d.end_encoding(s.index);
|
||||
|
@ -141,10 +141,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
tape.pop();
|
||||
|
||||
// Set the status of the array and siblings.
|
||||
auto status = async ? array::Status::scheduled : array::Status::available;
|
||||
arr.set_status(status);
|
||||
arr.set_status(array::Status::scheduled);
|
||||
for (auto& s : arr.siblings()) {
|
||||
s.set_status(status);
|
||||
s.set_status(array::Status::scheduled);
|
||||
}
|
||||
|
||||
auto stream = arr.primitive().stream();
|
||||
@ -170,6 +169,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
for (auto& out : outputs) {
|
||||
out.set_status(array::Status::available);
|
||||
}
|
||||
|
||||
if (signal) {
|
||||
arr.event().signal();
|
||||
}
|
||||
|
@ -103,6 +103,23 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
z = mx.add(y, x, stream=mx.cpu)
|
||||
self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0)))
|
||||
|
||||
def test_multi_output_eval_during_transform(self):
|
||||
x = mx.random.uniform(shape=(1024,))
|
||||
y = mx.ones((1024,))
|
||||
mx.eval(x, y)
|
||||
|
||||
def fn(x):
|
||||
a, b = mx.divmod(x, x)
|
||||
mx.eval(a)
|
||||
return a
|
||||
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
if mx.metal.is_available():
|
||||
peak_mem = mx.metal.get_peak_memory()
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user