mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	| @@ -19,7 +19,7 @@ Buffer malloc(size_t size) { | ||||
| } | ||||
|  | ||||
| void free(Buffer buffer) { | ||||
|   return allocator().free(buffer); | ||||
|   allocator().free(buffer); | ||||
| } | ||||
|  | ||||
| Buffer CommonAllocator::malloc(size_t size, bool) { | ||||
|   | ||||
| @@ -214,6 +214,8 @@ array::~array() { | ||||
|     if (do_detach) { | ||||
|       for (auto& s : siblings()) { | ||||
|         for (auto& ss : s.siblings()) { | ||||
|           // Set to null here to avoid descending into array destructor | ||||
|           // for siblings | ||||
|           ss.array_desc_ = nullptr; | ||||
|         } | ||||
|         s.array_desc_->siblings.clear(); | ||||
| @@ -292,6 +294,14 @@ array::ArrayDesc::~ArrayDesc() { | ||||
|     auto top = std::move(for_deletion.back()); | ||||
|     for_deletion.pop_back(); | ||||
|     append_deletable_inputs(*top); | ||||
|  | ||||
|     // Clear out possible siblings to break circular references | ||||
|     for (auto& s : top->siblings) { | ||||
|       // Set to null here to avoid descending into top-level | ||||
|       // array destructor for siblings | ||||
|       s.array_desc_ = nullptr; | ||||
|     } | ||||
|     top->siblings.clear(); | ||||
|   } | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| import operator | ||||
| import pickle | ||||
| import resource | ||||
| import sys | ||||
| import unittest | ||||
| import weakref | ||||
| @@ -1928,6 +1929,25 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|             x = mx.sin(x) | ||||
|         mx.eval(x) | ||||
|  | ||||
|     def test_siblings_without_eval(self): | ||||
|         def get_mem(): | ||||
|             return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss | ||||
|  | ||||
|         key = mx.array([1, 2]) | ||||
|  | ||||
|         def t(): | ||||
|             a, b = mx.split(key, 2) | ||||
|             a = mx.reshape(a, []) | ||||
|             b = mx.reshape(b, []) | ||||
|             return b | ||||
|  | ||||
|         t() | ||||
|         expected = get_mem() | ||||
|         for _ in range(100): | ||||
|             t() | ||||
|         used = get_mem() | ||||
|         self.assertEqual(expected, used) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun