fix circular reference (#2443)

This commit is contained in:
Awni Hannun 2025-07-30 09:37:44 -07:00 committed by GitHub
parent 3bf81ed1bd
commit b405591249
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 48 additions and 35 deletions

View File

@ -399,41 +399,7 @@ class Module(dict):
Returns:
The module instance after updating the submodules.
"""
def apply(dst, modules):
if isinstance(modules, dict):
for k in modules:
if k in dst:
current_value = dst[k]
new_value = modules[k]
if self.is_module(current_value) and self.is_module(new_value):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict and new_value != {}:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(modules)):
current_value = dst[i]
new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict and new_value != {}:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
apply(self, modules)
_update_modules(self, modules, strict)
return self
def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
@ -639,6 +605,36 @@ class Module(dict):
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
def _update_modules(dst, modules, strict):
if isinstance(modules, dict):
for k in modules:
if k in dst:
current_value = dst[k]
new_value = modules[k]
if Module.is_module(current_value) and Module.is_module(new_value):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
_update_modules(current_value, new_value, strict)
elif strict and new_value != {}:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f'Module does not have sub-module named "{k}".')
elif isinstance(modules, list):
for i in range(len(modules)):
current_value = dst[i]
new_value = modules[i]
if Module.is_module(current_value) and Module.is_module(new_value):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
_update_modules(current_value, new_value, strict)
elif strict and new_value != {}:
raise ValueError(f"Received invalid type: {type(new_value).__name__}.")
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)

View File

@ -279,6 +279,23 @@ class TestBase(mlx_tests.MLXTestCase):
del m.weight
self.assertFalse(hasattr(m, "weight"))
def test_circular_leaks(self):
y = mx.random.uniform(1)
mx.eval(y)
def make_and_update():
model = nn.Linear(1024, 512)
mx.eval(model.parameters())
leaves = {}
model.update_modules(leaves)
mx.synchronize()
pre = mx.get_active_memory()
make_and_update()
mx.synchronize()
post = mx.get_active_memory()
self.assertEqual(pre, post)
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):