default strict mode for module update and update_modules (#2239)

This commit is contained in:
Awni Hannun 2025-06-05 15:27:02 -07:00 committed by GitHub
parent 52dc8c8cd5
commit c763fe1be0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 78 additions and 12 deletions

View File

@ -193,7 +193,7 @@ class Module(dict):
)
if len(weights) != 0:
self.update(tree_unflatten(weights))
self.update(tree_unflatten(weights), strict=False)
return self
def save_weights(self, file: str):
@ -291,7 +291,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict) -> Module:
def update(self, parameters: dict, strict: bool = True) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.
@ -305,7 +305,9 @@ class Module(dict):
Args:
parameters (dict): A complete or partial dictionary of the modules
parameters.
parameters.
strict (bool): If ``True`` checks that ``parameters`` is a
subset of the module's parameters. Default: ``True``.
Returns:
The module instance after updating the parameters.
"""
@ -317,21 +319,29 @@ class Module(dict):
current_value = dst[k]
new_value = parameters[k]
if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[k] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
else:
apply(current_value, new_value)
elif strict:
raise ValueError(f'Module does not have parameter named "{k}".')
elif isinstance(parameters, list):
for i in range(len(parameters)):
current_value = dst[i]
new_value = parameters[i]
if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[i] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
else:
apply(current_value, new_value)
elif strict:
raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
apply(self, parameters)
return self
@ -359,7 +369,7 @@ class Module(dict):
self.update(self.filter_and_map(filter_fn, map_fn))
return self
def update_modules(self, modules: dict) -> Module:
def update_modules(self, modules: dict, strict: bool = True) -> Module:
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.
@ -368,12 +378,14 @@ class Module(dict):
programmatically swapping layers.
The passed in parameters dictionary need not be a full dictionary
similar to :meth:`parameters`. Only the provided locations will be
similar to :meth:`modules`. Only the provided locations will be
updated.
Args:
modules (dict): A complete or partial dictionary of the modules
modules (dict): A complete or partial dictionary of the module's
submodules.
strict (bool): If ``True`` checks that ``modules`` is a
subset of the child modules of this instance. Default: ``True``.
Returns:
The module instance after updating the submodules.
"""
@ -388,6 +400,14 @@ class Module(dict):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(dst)):
current_value = dst[i]
@ -396,6 +416,12 @@ class Module(dict):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
apply(self, modules)
return self

View File

@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
x = mx.zeros((3,))
mx.grad(loss_fn)(model)
def test_update(self):
    """Check that ``update`` with its default arguments raises on bad input.

    Each payload below is passed to ``update`` without a ``strict``
    argument and is expected to raise ``ValueError``.
    """
    model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))

    # A key that is not one of the layer's parameters.
    unknown_key = {"layers": [{"value": 0}]}
    with self.assertRaises(ValueError):
        model.update(unknown_key)

    # A plain string where the nested update for a layer is expected.
    not_a_container = {"layers": ["hello"]}
    with self.assertRaises(ValueError):
        model.update(not_a_container)

    # Wrong type: a string supplied for an existing array parameter.
    wrong_leaf_type = {"layers": [{"weight": "hi"}]}
    with self.assertRaises(ValueError):
        model.update(wrong_leaf_type)
def test_update_modules(self):
    """Check that ``update_modules`` with default arguments raises on bad input.

    Covers an unknown sub-module name, non-module replacements for
    existing layers, and attempts to replace array attributes.
    """
    sequential = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))

    # Updating non-existent modules should not be allowed by default.
    with self.assertRaises(ValueError):
        sequential.update_modules({"values": [0, 1]})

    # Integers are not valid replacements for the existing layers.
    with self.assertRaises(ValueError):
        sequential.update_modules({"layers": [0, 1]})

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.test = mx.array(1.0)
            self.list = [mx.array(1.0), mx.array(2.0)]

    holder = MyModule()

    # ``test`` holds an array, so a string is rejected.
    with self.assertRaises(ValueError):
        holder.update_modules({"test": "hi"})

    # The same applies to an entry inside a list of arrays.
    with self.assertRaises(ValueError):
        holder.update_modules({"list": ["hi"]})
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):