Cycle leak break (#1856)

* detect and break leaks in custom function

* detect and break leaks in custom function
This commit is contained in:
Awni Hannun
2025-02-11 14:45:02 -08:00
committed by GitHub
parent 142b77751d
commit 2a45056ba8
10 changed files with 396 additions and 49 deletions

View File

@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
import gc
import unittest
import mlx.core as mx
@@ -737,6 +738,38 @@ class TestAutograd(mlx_tests.MLXTestCase):
expected[4:-5:-2] = tan_b
self.assertTrue(mx.allclose(grad, expected))
def test_leaks(self):
for transform in [
mx.grad,
mx.value_and_grad,
mx.custom_function,
mx.checkpoint,
]:
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
else:
mem_pre = 0
def outer():
d = {}
def f(x):
return d["x"]
d["f"] = transform(f)
d["x"] = mx.array([0] * 1000)
for _ in range(5):
outer()
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
else:
mem_post = 0
self.assertEqual(mem_pre, mem_post)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import gc
import io
import unittest
from functools import partial
@@ -926,6 +927,32 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertEqual(out[0].shape, (3, 1, 4, 2))
self.assertEqual(out[1].shape, (2, 2, 5))
def test_leaks(self):
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
else:
mem_pre = 0
def outer():
d = {}
def f(x):
return d["x"]
d["f"] = mx.compile(f)
d["x"] = mx.array([0] * 1000)
for _ in range(5):
outer()
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
else:
mem_post = 0
self.assertEqual(mem_pre, mem_post)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc.
import gc
import os
import tempfile
import unittest
@@ -239,6 +240,33 @@ class TestExportImport(mlx_tests.MLXTestCase):
constants_size = constant.nbytes + 8192
self.assertTrue(os.path.getsize(path) < constants_size)
def test_leaks(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
else:
mem_pre = 0
def outer():
d = {}
def f(x):
return d["x"]
d["f"] = mx.exporter(path, f)
d["x"] = mx.array([0] * 1000)
for _ in range(5):
outer()
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
else:
mem_post = 0
self.assertEqual(mem_pre, mem_post)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import gc
import unittest
import mlx.core as mx
@@ -608,6 +609,32 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(fx.shape, (5, 6, 7))
self.assertEqual(fy.shape, (4, 5, 6, 7))
def test_leaks(self):
if mx.metal.is_available():
mem_pre = mx.metal.get_active_memory()
else:
mem_pre = 0
def outer():
d = {}
def f(x):
return d["x"]
d["f"] = mx.vmap(f)
d["x"] = mx.array([0] * 1000)
for _ in range(5):
outer()
gc.collect()
if mx.metal.is_available():
mem_post = mx.metal.get_active_memory()
else:
mem_post = 0
self.assertEqual(mem_pre, mem_post)
if __name__ == "__main__":
unittest.main()