mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:29:35 +08:00
Cycle leak break (#1856)
* detect and break leaks in custom function * detect and break leaks in custom function
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -239,6 +240,33 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
constants_size = constant.nbytes + 8192
|
||||
self.assertTrue(os.path.getsize(path) < constants_size)
|
||||
|
||||
def test_leaks(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = mx.exporter(path, f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user