mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	| @@ -19,7 +19,7 @@ Buffer malloc(size_t size) { | |||||||
| } | } | ||||||
|  |  | ||||||
| void free(Buffer buffer) { | void free(Buffer buffer) { | ||||||
|   return allocator().free(buffer); |   allocator().free(buffer); | ||||||
| } | } | ||||||
|  |  | ||||||
| Buffer CommonAllocator::malloc(size_t size, bool) { | Buffer CommonAllocator::malloc(size_t size, bool) { | ||||||
|   | |||||||
| @@ -214,6 +214,8 @@ array::~array() { | |||||||
|     if (do_detach) { |     if (do_detach) { | ||||||
|       for (auto& s : siblings()) { |       for (auto& s : siblings()) { | ||||||
|         for (auto& ss : s.siblings()) { |         for (auto& ss : s.siblings()) { | ||||||
|  |           // Set to null here to avoid descending into array destructor | ||||||
|  |           // for siblings | ||||||
|           ss.array_desc_ = nullptr; |           ss.array_desc_ = nullptr; | ||||||
|         } |         } | ||||||
|         s.array_desc_->siblings.clear(); |         s.array_desc_->siblings.clear(); | ||||||
| @@ -292,6 +294,14 @@ array::ArrayDesc::~ArrayDesc() { | |||||||
|     auto top = std::move(for_deletion.back()); |     auto top = std::move(for_deletion.back()); | ||||||
|     for_deletion.pop_back(); |     for_deletion.pop_back(); | ||||||
|     append_deletable_inputs(*top); |     append_deletable_inputs(*top); | ||||||
|  |  | ||||||
|  |     // Clear out possible siblings to break circular references | ||||||
|  |     for (auto& s : top->siblings) { | ||||||
|  |       // Set to null here to avoid descending into top-level | ||||||
|  |       // array destructor for siblings | ||||||
|  |       s.array_desc_ = nullptr; | ||||||
|  |     } | ||||||
|  |     top->siblings.clear(); | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ | |||||||
|  |  | ||||||
| import operator | import operator | ||||||
| import pickle | import pickle | ||||||
|  | import resource | ||||||
| import sys | import sys | ||||||
| import unittest | import unittest | ||||||
| import weakref | import weakref | ||||||
| @@ -1928,6 +1929,25 @@ class TestArray(mlx_tests.MLXTestCase): | |||||||
|             x = mx.sin(x) |             x = mx.sin(x) | ||||||
|         mx.eval(x) |         mx.eval(x) | ||||||
|  |  | ||||||
|  |     def test_siblings_without_eval(self): | ||||||
|  |         def get_mem(): | ||||||
|  |             return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss | ||||||
|  |  | ||||||
|  |         key = mx.array([1, 2]) | ||||||
|  |  | ||||||
|  |         def t(): | ||||||
|  |             a, b = mx.split(key, 2) | ||||||
|  |             a = mx.reshape(a, []) | ||||||
|  |             b = mx.reshape(b, []) | ||||||
|  |             return b | ||||||
|  |  | ||||||
|  |         t() | ||||||
|  |         expected = get_mem() | ||||||
|  |         for _ in range(100): | ||||||
|  |             t() | ||||||
|  |         used = get_mem() | ||||||
|  |         self.assertEqual(expected, used) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun