mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-03 17:32:40 +08:00
fix circular reference (#2443)
This commit is contained in:
parent
3bf81ed1bd
commit
b405591249
@ -399,41 +399,7 @@ class Module(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
The module instance after updating the submodules.
|
The module instance after updating the submodules.
|
||||||
"""
|
"""
|
||||||
|
_update_modules(self, modules, strict)
|
||||||
def apply(dst, modules):
|
|
||||||
if isinstance(modules, dict):
|
|
||||||
for k in modules:
|
|
||||||
if k in dst:
|
|
||||||
current_value = dst[k]
|
|
||||||
new_value = modules[k]
|
|
||||||
if self.is_module(current_value) and self.is_module(new_value):
|
|
||||||
dst[k] = new_value
|
|
||||||
elif isinstance(current_value, (dict, list)):
|
|
||||||
apply(current_value, new_value)
|
|
||||||
elif strict and new_value != {}:
|
|
||||||
raise ValueError(
|
|
||||||
f"Received invalid type: {type(new_value).__name__}."
|
|
||||||
)
|
|
||||||
elif strict:
|
|
||||||
raise ValueError(
|
|
||||||
f'Module does not have sub-module named "{k}".'
|
|
||||||
)
|
|
||||||
elif isinstance(modules, list):
|
|
||||||
for i in range(len(modules)):
|
|
||||||
current_value = dst[i]
|
|
||||||
new_value = modules[i]
|
|
||||||
if self.is_module(current_value) and self.is_module(new_value):
|
|
||||||
dst[i] = new_value
|
|
||||||
elif isinstance(current_value, (dict, list)):
|
|
||||||
apply(current_value, new_value)
|
|
||||||
elif strict and new_value != {}:
|
|
||||||
raise ValueError(
|
|
||||||
f"Received invalid type: {type(new_value).__name__}."
|
|
||||||
)
|
|
||||||
elif strict:
|
|
||||||
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
|
|
||||||
|
|
||||||
apply(self, modules)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
|
def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
|
||||||
@ -639,6 +605,36 @@ class Module(dict):
|
|||||||
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
|
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
|
||||||
|
|
||||||
|
|
||||||
|
def _update_modules(dst, modules, strict):
|
||||||
|
if isinstance(modules, dict):
|
||||||
|
for k in modules:
|
||||||
|
if k in dst:
|
||||||
|
current_value = dst[k]
|
||||||
|
new_value = modules[k]
|
||||||
|
if Module.is_module(current_value) and Module.is_module(new_value):
|
||||||
|
dst[k] = new_value
|
||||||
|
elif isinstance(current_value, (dict, list)):
|
||||||
|
_update_modules(current_value, new_value, strict)
|
||||||
|
elif strict and new_value != {}:
|
||||||
|
raise ValueError(
|
||||||
|
f"Received invalid type: {type(new_value).__name__}."
|
||||||
|
)
|
||||||
|
elif strict:
|
||||||
|
raise ValueError(f'Module does not have sub-module named "{k}".')
|
||||||
|
elif isinstance(modules, list):
|
||||||
|
for i in range(len(modules)):
|
||||||
|
current_value = dst[i]
|
||||||
|
new_value = modules[i]
|
||||||
|
if Module.is_module(current_value) and Module.is_module(new_value):
|
||||||
|
dst[i] = new_value
|
||||||
|
elif isinstance(current_value, (dict, list)):
|
||||||
|
_update_modules(current_value, new_value, strict)
|
||||||
|
elif strict and new_value != {}:
|
||||||
|
raise ValueError(f"Received invalid type: {type(new_value).__name__}.")
|
||||||
|
elif strict:
|
||||||
|
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
|
||||||
|
|
||||||
|
|
||||||
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
|
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
|
||||||
if is_leaf_fn(model, value_key, value):
|
if is_leaf_fn(model, value_key, value):
|
||||||
return map_fn(value)
|
return map_fn(value)
|
||||||
|
@ -279,6 +279,23 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
del m.weight
|
del m.weight
|
||||||
self.assertFalse(hasattr(m, "weight"))
|
self.assertFalse(hasattr(m, "weight"))
|
||||||
|
|
||||||
|
def test_circular_leaks(self):
|
||||||
|
y = mx.random.uniform(1)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
def make_and_update():
|
||||||
|
model = nn.Linear(1024, 512)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
leaves = {}
|
||||||
|
model.update_modules(leaves)
|
||||||
|
|
||||||
|
mx.synchronize()
|
||||||
|
pre = mx.get_active_memory()
|
||||||
|
make_and_update()
|
||||||
|
mx.synchronize()
|
||||||
|
post = mx.get_active_memory()
|
||||||
|
self.assertEqual(pre, post)
|
||||||
|
|
||||||
|
|
||||||
class TestLayers(mlx_tests.MLXTestCase):
|
class TestLayers(mlx_tests.MLXTestCase):
|
||||||
def test_identity(self):
|
def test_identity(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user