fix circular reference (#2443)

This commit is contained in:
Awni Hannun
2025-07-30 09:37:44 -07:00
committed by GitHub
parent 3bf81ed1bd
commit b405591249
2 changed files with 48 additions and 35 deletions

View File

@@ -279,6 +279,23 @@ class TestBase(mlx_tests.MLXTestCase):
del m.weight
self.assertFalse(hasattr(m, "weight"))
def test_circular_leaks(self):
y = mx.random.uniform(1)
mx.eval(y)
def make_and_update():
model = nn.Linear(1024, 512)
mx.eval(model.parameters())
leaves = {}
model.update_modules(leaves)
mx.synchronize()
pre = mx.get_active_memory()
make_and_update()
mx.synchronize()
post = mx.get_active_memory()
self.assertEqual(pre, post)
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):