Fix sibling leak (#1590)

* add test

* fix + test

* fix fix
This commit is contained in:
Awni Hannun 2024-11-18 19:17:01 -08:00 committed by GitHub
parent 9d7fa6b8e6
commit bf481e8e5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 1 deletions

View File

@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
}
void free(Buffer buffer) {
return allocator().free(buffer);
allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {

View File

@ -214,6 +214,8 @@ array::~array() {
if (do_detach) {
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
// Set to null here to avoid descending into array destructor
// for siblings
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
@ -292,6 +294,14 @@ array::ArrayDesc::~ArrayDesc() {
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
append_deletable_inputs(*top);
// Clear out possible siblings to break circular references
for (auto& s : top->siblings) {
// Set to null here to avoid descending into top-level
// array destructor for siblings
s.array_desc_ = nullptr;
}
top->siblings.clear();
}
}

View File

@ -2,6 +2,7 @@
import operator
import pickle
import resource
import sys
import unittest
import weakref
@ -1928,6 +1929,25 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.sin(x)
mx.eval(x)
def test_siblings_without_eval(self):
def get_mem():
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
key = mx.array([1, 2])
def t():
a, b = mx.split(key, 2)
a = mx.reshape(a, [])
b = mx.reshape(b, [])
return b
t()
expected = get_mem()
for _ in range(100):
t()
used = get_mem()
self.assertEqual(expected, used)
if __name__ == "__main__":
unittest.main()