mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 22:11:15 +08:00
parent
9d7fa6b8e6
commit
bf481e8e5d
@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void free(Buffer buffer) {
|
void free(Buffer buffer) {
|
||||||
return allocator().free(buffer);
|
allocator().free(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||||
|
@ -214,6 +214,8 @@ array::~array() {
|
|||||||
if (do_detach) {
|
if (do_detach) {
|
||||||
for (auto& s : siblings()) {
|
for (auto& s : siblings()) {
|
||||||
for (auto& ss : s.siblings()) {
|
for (auto& ss : s.siblings()) {
|
||||||
|
// Set to null here to avoid descending into array destructor
|
||||||
|
// for siblings
|
||||||
ss.array_desc_ = nullptr;
|
ss.array_desc_ = nullptr;
|
||||||
}
|
}
|
||||||
s.array_desc_->siblings.clear();
|
s.array_desc_->siblings.clear();
|
||||||
@ -292,6 +294,14 @@ array::ArrayDesc::~ArrayDesc() {
|
|||||||
auto top = std::move(for_deletion.back());
|
auto top = std::move(for_deletion.back());
|
||||||
for_deletion.pop_back();
|
for_deletion.pop_back();
|
||||||
append_deletable_inputs(*top);
|
append_deletable_inputs(*top);
|
||||||
|
|
||||||
|
// Clear out possible siblings to break circular references
|
||||||
|
for (auto& s : top->siblings) {
|
||||||
|
// Set to null here to avoid descending into top-level
|
||||||
|
// array destructor for siblings
|
||||||
|
s.array_desc_ = nullptr;
|
||||||
|
}
|
||||||
|
top->siblings.clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import operator
|
import operator
|
||||||
import pickle
|
import pickle
|
||||||
|
import resource
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
import weakref
|
import weakref
|
||||||
@ -1928,6 +1929,25 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
x = mx.sin(x)
|
x = mx.sin(x)
|
||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
|
|
||||||
|
def test_siblings_without_eval(self):
|
||||||
|
def get_mem():
|
||||||
|
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||||
|
|
||||||
|
key = mx.array([1, 2])
|
||||||
|
|
||||||
|
def t():
|
||||||
|
a, b = mx.split(key, 2)
|
||||||
|
a = mx.reshape(a, [])
|
||||||
|
b = mx.reshape(b, [])
|
||||||
|
return b
|
||||||
|
|
||||||
|
t()
|
||||||
|
expected = get_mem()
|
||||||
|
for _ in range(100):
|
||||||
|
t()
|
||||||
|
used = get_mem()
|
||||||
|
self.assertEqual(expected, used)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user