Cycle leak break (#1856)

* detect and break leaks in custom function

* detect and break leaks in custom function
This commit is contained in:
Awni Hannun
2025-02-11 14:45:02 -08:00
committed by GitHub
parent 142b77751d
commit 2a45056ba8
10 changed files with 396 additions and 49 deletions

View File

@@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import gc
import io
import unittest
from functools import partial
@@ -926,6 +927,32 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertEqual(out[0].shape, (3, 1, 4, 2))
self.assertEqual(out[1].shape, (2, 2, 5))
def test_leaks(self):
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
else:
mem_pre = 0
def outer():
d = {}
def f(x):
return d["x"]
d["f"] = mx.compile(f)
d["x"] = mx.array([0] * 1000)
for _ in range(5):
outer()
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
else:
mem_post = 0
self.assertEqual(mem_pre, mem_post)
if __name__ == "__main__":
unittest.main()