diff --git a/mlx/allocator.cpp b/mlx/allocator.cpp index a111b5c22..8d7273a78 100644 --- a/mlx/allocator.cpp +++ b/mlx/allocator.cpp @@ -19,7 +19,7 @@ Buffer malloc(size_t size) { } void free(Buffer buffer) { - return allocator().free(buffer); + allocator().free(buffer); } Buffer CommonAllocator::malloc(size_t size, bool) { diff --git a/mlx/array.cpp b/mlx/array.cpp index 2c70b9c8e..8bf007688 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -214,6 +214,8 @@ array::~array() { if (do_detach) { for (auto& s : siblings()) { for (auto& ss : s.siblings()) { + // Set to null here to avoid descending into array destructor + // for siblings ss.array_desc_ = nullptr; } s.array_desc_->siblings.clear(); @@ -292,6 +294,14 @@ array::ArrayDesc::~ArrayDesc() { auto top = std::move(for_deletion.back()); for_deletion.pop_back(); append_deletable_inputs(*top); + + // Clear out possible siblings to break circular references + for (auto& s : top->siblings) { + // Set to null here to avoid descending into top-level + // array destructor for siblings + s.array_desc_ = nullptr; + } + top->siblings.clear(); } } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index db0e26aa9..b0c915de1 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2,6 +2,7 @@ import operator import pickle +import resource import sys import unittest import weakref @@ -1928,6 +1929,25 @@ class TestArray(mlx_tests.MLXTestCase): x = mx.sin(x) mx.eval(x) + def test_siblings_without_eval(self): + def get_mem(): + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + key = mx.array([1, 2]) + + def t(): + a, b = mx.split(key, 2) + a = mx.reshape(a, []) + b = mx.reshape(b, []) + return b + + t() + expected = get_mem() + for _ in range(100): + t() + used = get_mem() + self.assertEqual(expected, used) + if __name__ == "__main__": unittest.main()