mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 11:28:12 +08:00
Cycle leak break (#1856)
* detect and break leaks in custom function * detect and break leaks in custom function
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -737,6 +738,38 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
expected[4:-5:-2] = tan_b
|
||||
self.assertTrue(mx.allclose(grad, expected))
|
||||
|
||||
def test_leaks(self):
|
||||
for transform in [
|
||||
mx.grad,
|
||||
mx.value_and_grad,
|
||||
mx.custom_function,
|
||||
mx.checkpoint,
|
||||
]:
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = transform(f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user