default strict mode for module update and update_modules (#2239)

This commit is contained in:
Awni Hannun 2025-06-05 15:27:02 -07:00 committed by GitHub
parent 52dc8c8cd5
commit c763fe1be0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 78 additions and 12 deletions

View File

@ -193,7 +193,7 @@ class Module(dict):
) )
if len(weights) != 0: if len(weights) != 0:
self.update(tree_unflatten(weights)) self.update(tree_unflatten(weights), strict=False)
return self return self
def save_weights(self, file: str): def save_weights(self, file: str):
@ -291,7 +291,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict) -> Module: def update(self, parameters: dict, strict: bool = True) -> Module:
"""Replace the parameters of this Module with the provided ones in the """Replace the parameters of this Module with the provided ones in the
dict of dicts and lists. dict of dicts and lists.
@ -306,6 +306,8 @@ class Module(dict):
Args: Args:
parameters (dict): A complete or partial dictionary of the modules parameters (dict): A complete or partial dictionary of the modules
parameters. parameters.
strict (bool): If ``True`` checks that ``parameters`` is a
subset of the module's parameters. Default: ``True``.
Returns: Returns:
The module instance after updating the parameters. The module instance after updating the parameters.
""" """
@ -317,21 +319,29 @@ class Module(dict):
current_value = dst[k] current_value = dst[k]
new_value = parameters[k] new_value = parameters[k]
if isinstance(current_value, mx.array): if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[k] = new_value dst[k] = new_value
elif isinstance(current_value, Module): else:
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict:
raise ValueError(f'Module does not have parameter named "{k}".')
elif isinstance(parameters, list): elif isinstance(parameters, list):
for i in range(len(parameters)): for i in range(len(parameters)):
current_value = dst[i] current_value = dst[i]
new_value = parameters[i] new_value = parameters[i]
if isinstance(current_value, mx.array): if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[i] = new_value dst[i] = new_value
elif isinstance(current_value, Module): else:
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict:
raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
apply(self, parameters) apply(self, parameters)
return self return self
@ -359,7 +369,7 @@ class Module(dict):
self.update(self.filter_and_map(filter_fn, map_fn)) self.update(self.filter_and_map(filter_fn, map_fn))
return self return self
def update_modules(self, modules: dict) -> Module: def update_modules(self, modules: dict, strict: bool = True) -> Module:
"""Replace the child modules of this :class:`Module` instance with the """Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists. provided ones in the dict of dicts and lists.
@ -368,12 +378,14 @@ class Module(dict):
programmatically swapping layers. programmatically swapping layers.
The passed in parameters dictionary need not be a full dictionary The passed in parameters dictionary need not be a full dictionary
similar to :meth:`parameters`. Only the provided locations will be similar to :meth:`modules`. Only the provided locations will be
updated. updated.
Args: Args:
modules (dict): A complete or partial dictionary of the modules modules (dict): A complete or partial dictionary of the module's
submodules. submodules.
strict (bool): If ``True`` checks that ``modules`` is a
subset of the child modules of this instance. Default: ``True``.
Returns: Returns:
The module instance after updating the submodules. The module instance after updating the submodules.
""" """
@ -388,6 +400,14 @@ class Module(dict):
dst[k] = new_value dst[k] = new_value
elif isinstance(current_value, (dict, list)): elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list): elif isinstance(modules, list):
for i in range(len(dst)): for i in range(len(dst)):
current_value = dst[i] current_value = dst[i]
@ -396,6 +416,12 @@ class Module(dict):
dst[i] = new_value dst[i] = new_value
elif isinstance(current_value, (dict, list)): elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
apply(self, modules) apply(self, modules)
return self return self

View File

@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
x = mx.zeros((3,)) x = mx.zeros((3,))
mx.grad(loss_fn)(model) mx.grad(loss_fn)(model)
def test_update(self):
    """Check that ``update`` raises ``ValueError`` for malformed parameter trees.

    ``update`` is called without a ``strict`` argument, so the default
    mode is what is being exercised here.
    """
    model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))

    rejected_trees = (
        # Updating non-existent parameters
        {"layers": [{"value": 0}]},
        {"layers": ["hello"]},
        # Wrong type
        {"layers": [{"weight": "hi"}]},
    )
    # Every tree is applied to the same model instance, in order.
    for tree in rejected_trees:
        with self.assertRaises(ValueError):
            model.update(tree)
def test_update_modules(self):
    """Check that ``update_modules`` raises ``ValueError`` for bad inputs.

    ``update_modules`` is called without a ``strict`` argument, so the
    default mode is what is being exercised here.
    """
    sequential = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))

    # Updating non-existent modules should not be allowed by default
    with self.assertRaises(ValueError):
        sequential.update_modules({"values": [0, 1]})

    # Update wrong types
    with self.assertRaises(ValueError):
        sequential.update_modules({"layers": [0, 1]})

    class ArrayHolder(nn.Module):
        # Holds only arrays: one directly and two inside a list.
        def __init__(self):
            super().__init__()
            self.test = mx.array(1.0)
            self.list = [mx.array(1.0), mx.array(2.0)]

    holder = ArrayHolder()
    # Strings are offered for both the array attribute and the list entry.
    for bogus in ({"test": "hi"}, {"list": ["hi"]}):
        with self.assertRaises(ValueError):
            holder.update_modules(bogus)
class TestLayers(mlx_tests.MLXTestCase): class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):