From b40559124987412199e9c695978074cee063cc9c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 30 Jul 2025 09:37:44 -0700 Subject: [PATCH] fix circular reference (#2443) --- python/mlx/nn/layers/base.py | 66 +++++++++++++++++------------------- python/tests/test_nn.py | 17 ++++++++++ 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index e999438341..c3a517d163 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -399,41 +399,7 @@ class Module(dict): Returns: The module instance after updating the submodules. """ - - def apply(dst, modules): - if isinstance(modules, dict): - for k in modules: - if k in dst: - current_value = dst[k] - new_value = modules[k] - if self.is_module(current_value) and self.is_module(new_value): - dst[k] = new_value - elif isinstance(current_value, (dict, list)): - apply(current_value, new_value) - elif strict and new_value != {}: - raise ValueError( - f"Received invalid type: {type(new_value).__name__}." - ) - elif strict: - raise ValueError( - f'Module does not have sub-module named "{k}".' - ) - elif isinstance(modules, list): - for i in range(len(modules)): - current_value = dst[i] - new_value = modules[i] - if self.is_module(current_value) and self.is_module(new_value): - dst[i] = new_value - elif isinstance(current_value, (dict, list)): - apply(current_value, new_value) - elif strict and new_value != {}: - raise ValueError( - f"Received invalid type: {type(new_value).__name__}." - ) - elif strict: - raise ValueError(f"Received invalid type: {type(modules).__name__}.") - - apply(self, modules) + _update_modules(self, modules, strict) return self def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module: @@ -639,6 +605,36 @@ class Module(dict): self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x) +def _update_modules(dst, modules, strict): + if isinstance(modules, dict): + for k in modules: + if k in dst: + current_value = dst[k] + new_value = modules[k] + if Module.is_module(current_value) and Module.is_module(new_value): + dst[k] = new_value + elif isinstance(current_value, (dict, list)): + _update_modules(current_value, new_value, strict) + elif strict and new_value != {}: + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) + elif strict: + raise ValueError(f'Module does not have sub-module named "{k}".') + elif isinstance(modules, list): + for i in range(len(modules)): + current_value = dst[i] + new_value = modules[i] + if Module.is_module(current_value) and Module.is_module(new_value): + dst[i] = new_value + elif isinstance(current_value, (dict, list)): + _update_modules(current_value, new_value, strict) + elif strict and new_value != {}: + raise ValueError(f"Received invalid type: {type(new_value).__name__}.") + elif strict: + raise ValueError(f"Received invalid type: {type(modules).__name__}.") + + def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn): if is_leaf_fn(model, value_key, value): return map_fn(value) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index ae3fae4da1..a771020875 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -279,6 +279,23 @@ class TestBase(mlx_tests.MLXTestCase): del m.weight self.assertFalse(hasattr(m, "weight")) + def test_circular_leaks(self): + y = mx.random.uniform(1) + mx.eval(y) + + def make_and_update(): + model = nn.Linear(1024, 512) + mx.eval(model.parameters()) + leaves = {} + model.update_modules(leaves) + + mx.synchronize() + pre = mx.get_active_memory() + make_and_update() + mx.synchronize() + post = mx.get_active_memory() + self.assertEqual(pre, post) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self):