Fix leak for multi-output primitives which are never detached (#1059)

* fix multi output leak

* ignore arrays that will be detached

* add some comments

* stray print
This commit is contained in:
Awni Hannun
2024-05-01 07:31:45 -07:00
committed by GitHub
parent 19bef39f5c
commit 7f7b9662ea
5 changed files with 59 additions and 15 deletions

View File

@@ -1694,6 +1694,22 @@ class TestArray(mlx_tests.MLXTestCase):
b = pickle.loads(pickle.dumps(a))
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_multi_output_leak(self):
def fun():
a = mx.zeros((2**20))
mx.eval(a)
b, c = mx.divmod(a, a)
del b, c
fun()
mx.synchronize()
peak_1 = mx.metal.get_peak_memory()
fun()
mx.synchronize()
peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2)
if __name__ == "__main__":
unittest.main()