mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
Cycle leak break (#1856)
* detect and break leaks in custom function * detect and break leaks in custom function
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import io
|
||||
import unittest
|
||||
from functools import partial
|
||||
@@ -926,6 +927,32 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(out[0].shape, (3, 1, 4, 2))
|
||||
self.assertEqual(out[1].shape, (2, 2, 5))
|
||||
|
||||
def test_leaks(self):
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = mx.compile(f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user