mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 13:51:13 +08:00
parent
9d7fa6b8e6
commit
bf481e8e5d
@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
|
||||
}
|
||||
|
||||
void free(Buffer buffer) {
|
||||
return allocator().free(buffer);
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||
|
@ -214,6 +214,8 @@ array::~array() {
|
||||
if (do_detach) {
|
||||
for (auto& s : siblings()) {
|
||||
for (auto& ss : s.siblings()) {
|
||||
// Set to null here to avoid descending into array destructor
|
||||
// for siblings
|
||||
ss.array_desc_ = nullptr;
|
||||
}
|
||||
s.array_desc_->siblings.clear();
|
||||
@ -292,6 +294,14 @@ array::ArrayDesc::~ArrayDesc() {
|
||||
auto top = std::move(for_deletion.back());
|
||||
for_deletion.pop_back();
|
||||
append_deletable_inputs(*top);
|
||||
|
||||
// Clear out possible siblings to break circular references
|
||||
for (auto& s : top->siblings) {
|
||||
// Set to null here to avoid descending into top-level
|
||||
// array destructor for siblings
|
||||
s.array_desc_ = nullptr;
|
||||
}
|
||||
top->siblings.clear();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import operator
|
||||
import pickle
|
||||
import resource
|
||||
import sys
|
||||
import unittest
|
||||
import weakref
|
||||
@ -1928,6 +1929,25 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.sin(x)
|
||||
mx.eval(x)
|
||||
|
||||
def test_siblings_without_eval(self):
|
||||
def get_mem():
|
||||
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
|
||||
key = mx.array([1, 2])
|
||||
|
||||
def t():
|
||||
a, b = mx.split(key, 2)
|
||||
a = mx.reshape(a, [])
|
||||
b = mx.reshape(b, [])
|
||||
return b
|
||||
|
||||
t()
|
||||
expected = get_mem()
|
||||
for _ in range(100):
|
||||
t()
|
||||
used = get_mem()
|
||||
self.assertEqual(expected, used)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user