default strict mode for module update and update_modules (#2239)

This commit is contained in:
Awni Hannun
2025-06-05 15:27:02 -07:00
committed by GitHub
parent 52dc8c8cd5
commit c763fe1be0
2 changed files with 78 additions and 12 deletions

View File

@@ -193,7 +193,7 @@ class Module(dict):
)
if len(weights) != 0:
self.update(tree_unflatten(weights))
self.update(tree_unflatten(weights), strict=False)
return self
def save_weights(self, file: str):
@@ -291,7 +291,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict) -> Module:
def update(self, parameters: dict, strict: bool = True) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.
@@ -305,7 +305,9 @@ class Module(dict):
Args:
parameters (dict): A complete or partial dictionary of the module's
parameters.
parameters.
strict (bool): If ``True``, checks that ``parameters`` is a
subset of the module's parameters. Default: ``True``.
Returns:
The module instance after updating the parameters.
"""
@@ -317,21 +319,29 @@ class Module(dict):
current_value = dst[k]
new_value = parameters[k]
if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[k] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
else:
apply(current_value, new_value)
elif strict:
raise ValueError(f'Module does not have parameter named "{k}".')
elif isinstance(parameters, list):
for i in range(len(parameters)):
current_value = dst[i]
new_value = parameters[i]
if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[i] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
else:
apply(current_value, new_value)
elif strict:
raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
apply(self, parameters)
return self
@@ -359,7 +369,7 @@ class Module(dict):
self.update(self.filter_and_map(filter_fn, map_fn))
return self
def update_modules(self, modules: dict) -> Module:
def update_modules(self, modules: dict, strict: bool = True) -> Module:
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.
@@ -368,12 +378,14 @@ class Module(dict):
programmatically swapping layers.
The passed-in modules dictionary need not be a full dictionary
similar to :meth:`parameters`. Only the provided locations will be
similar to :meth:`modules`. Only the provided locations will be
updated.
Args:
modules (dict): A complete or partial dictionary of the modules
modules (dict): A complete or partial dictionary of the module's
submodules.
strict (bool): If ``True``, checks that ``modules`` is a
subset of the child modules of this instance. Default: ``True``.
Returns:
The module instance after updating the submodules.
"""
@@ -388,6 +400,14 @@ class Module(dict):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(dst)):
current_value = dst[i]
@@ -396,6 +416,12 @@ class Module(dict):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
apply(self, modules)
return self