mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
fix circular reference (#2443)
This commit is contained in:
@@ -279,6 +279,23 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
del m.weight
|
||||
self.assertFalse(hasattr(m, "weight"))
|
||||
|
||||
def test_circular_leaks(self):
|
||||
y = mx.random.uniform(1)
|
||||
mx.eval(y)
|
||||
|
||||
def make_and_update():
|
||||
model = nn.Linear(1024, 512)
|
||||
mx.eval(model.parameters())
|
||||
leaves = {}
|
||||
model.update_modules(leaves)
|
||||
|
||||
mx.synchronize()
|
||||
pre = mx.get_active_memory()
|
||||
make_and_update()
|
||||
mx.synchronize()
|
||||
post = mx.get_active_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
Reference in New Issue
Block a user