mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Cycle leak break (#1856)
* detect and break leaks in custom function * detect and break leaks in custom function
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -737,6 +738,38 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
expected[4:-5:-2] = tan_b
|
||||
self.assertTrue(mx.allclose(grad, expected))
|
||||
|
||||
def test_leaks(self):
|
||||
for transform in [
|
||||
mx.grad,
|
||||
mx.value_and_grad,
|
||||
mx.custom_function,
|
||||
mx.checkpoint,
|
||||
]:
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = transform(f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import io
|
||||
import unittest
|
||||
from functools import partial
|
||||
@@ -926,6 +927,32 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(out[0].shape, (3, 1, 4, 2))
|
||||
self.assertEqual(out[1].shape, (2, 2, 5))
|
||||
|
||||
def test_leaks(self):
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = mx.compile(f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -239,6 +240,33 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
constants_size = constant.nbytes + 8192
|
||||
self.assertTrue(os.path.getsize(path) < constants_size)
|
||||
|
||||
def test_leaks(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = mx.exporter(path, f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -608,6 +609,32 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(fx.shape, (5, 6, 7))
|
||||
self.assertEqual(fy.shape, (4, 5, 6, 7))
|
||||
|
||||
def test_leaks(self):
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = mx.vmap(f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user