mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
default strict mode for module update and update_modules (#2239)
This commit is contained in:
parent
52dc8c8cd5
commit
c763fe1be0
@ -193,7 +193,7 @@ class Module(dict):
|
||||
)
|
||||
|
||||
if len(weights) != 0:
|
||||
self.update(tree_unflatten(weights))
|
||||
self.update(tree_unflatten(weights), strict=False)
|
||||
return self
|
||||
|
||||
def save_weights(self, file: str):
|
||||
@ -291,7 +291,7 @@ class Module(dict):
|
||||
|
||||
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
|
||||
|
||||
def update(self, parameters: dict) -> Module:
|
||||
def update(self, parameters: dict, strict: bool = True) -> Module:
|
||||
"""Replace the parameters of this Module with the provided ones in the
|
||||
dict of dicts and lists.
|
||||
|
||||
@ -305,7 +305,9 @@ class Module(dict):
|
||||
|
||||
Args:
|
||||
parameters (dict): A complete or partial dictionary of the module's
|
||||
parameters.
|
||||
parameters.
|
||||
strict (bool): If ``True`` checks that ``parameters`` is a
|
||||
subset of the module's parameters. Default: ``True``.
|
||||
Returns:
|
||||
The module instance after updating the parameters.
|
||||
"""
|
||||
@ -317,21 +319,29 @@ class Module(dict):
|
||||
current_value = dst[k]
|
||||
new_value = parameters[k]
|
||||
if isinstance(current_value, mx.array):
|
||||
if strict and not isinstance(new_value, mx.array):
|
||||
raise ValueError(
|
||||
f"Received invalid type: {type(new_value).__name__}."
|
||||
)
|
||||
dst[k] = new_value
|
||||
elif isinstance(current_value, Module):
|
||||
current_value.update(new_value)
|
||||
elif isinstance(current_value, (dict, list)):
|
||||
else:
|
||||
apply(current_value, new_value)
|
||||
elif strict:
|
||||
raise ValueError(f'Module does not have parameter named "{k}".')
|
||||
elif isinstance(parameters, list):
|
||||
for i in range(len(parameters)):
|
||||
current_value = dst[i]
|
||||
new_value = parameters[i]
|
||||
if isinstance(current_value, mx.array):
|
||||
if strict and not isinstance(new_value, mx.array):
|
||||
raise ValueError(
|
||||
f"Received invalid type: {type(new_value).__name__}."
|
||||
)
|
||||
dst[i] = new_value
|
||||
elif isinstance(current_value, Module):
|
||||
current_value.update(new_value)
|
||||
elif isinstance(current_value, (dict, list)):
|
||||
else:
|
||||
apply(current_value, new_value)
|
||||
elif strict:
|
||||
raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
|
||||
|
||||
apply(self, parameters)
|
||||
return self
|
||||
@ -359,7 +369,7 @@ class Module(dict):
|
||||
self.update(self.filter_and_map(filter_fn, map_fn))
|
||||
return self
|
||||
|
||||
def update_modules(self, modules: dict) -> Module:
|
||||
def update_modules(self, modules: dict, strict: bool = True) -> Module:
|
||||
"""Replace the child modules of this :class:`Module` instance with the
|
||||
provided ones in the dict of dicts and lists.
|
||||
|
||||
@ -368,12 +378,14 @@ class Module(dict):
|
||||
programmatically swapping layers.
|
||||
|
||||
The passed in modules dictionary need not be a full dictionary
|
||||
similar to :meth:`parameters`. Only the provided locations will be
|
||||
similar to :meth:`modules`. Only the provided locations will be
|
||||
updated.
|
||||
|
||||
Args:
|
||||
modules (dict): A complete or partial dictionary of the modules
|
||||
modules (dict): A complete or partial dictionary of the module's
|
||||
submodules.
|
||||
strict (bool): If ``True`` checks that ``modules`` is a
|
||||
subset of the child modules of this instance. Default: ``True``.
|
||||
Returns:
|
||||
The module instance after updating the submodules.
|
||||
"""
|
||||
@ -388,6 +400,14 @@ class Module(dict):
|
||||
dst[k] = new_value
|
||||
elif isinstance(current_value, (dict, list)):
|
||||
apply(current_value, new_value)
|
||||
elif strict:
|
||||
raise ValueError(
|
||||
f"Received invalid type: {type(new_value).__name__}."
|
||||
)
|
||||
elif strict:
|
||||
raise ValueError(
|
||||
f'Module does not have sub-module named "{k}".'
|
||||
)
|
||||
elif isinstance(modules, list):
|
||||
for i in range(len(dst)):
|
||||
current_value = dst[i]
|
||||
@ -396,6 +416,12 @@ class Module(dict):
|
||||
dst[i] = new_value
|
||||
elif isinstance(current_value, (dict, list)):
|
||||
apply(current_value, new_value)
|
||||
elif strict:
|
||||
raise ValueError(
|
||||
f"Received invalid type: {type(new_value).__name__}."
|
||||
)
|
||||
elif strict:
|
||||
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
|
||||
|
||||
apply(self, modules)
|
||||
return self
|
||||
|
@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
x = mx.zeros((3,))
|
||||
mx.grad(loss_fn)(model)
|
||||
|
||||
def test_update(self):
|
||||
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||
|
||||
# Updating non-existent parameters
|
||||
with self.assertRaises(ValueError):
|
||||
updates = {"layers": [{"value": 0}]}
|
||||
m.update(updates)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
updates = {"layers": ["hello"]}
|
||||
m.update(updates)
|
||||
|
||||
# Wrong type
|
||||
with self.assertRaises(ValueError):
|
||||
updates = {"layers": [{"weight": "hi"}]}
|
||||
m.update(updates)
|
||||
|
||||
def test_update_modules(self):
|
||||
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||
|
||||
# Updating non-existent modules should not be allowed by default
|
||||
with self.assertRaises(ValueError):
|
||||
m = m.update_modules({"values": [0, 1]})
|
||||
|
||||
# Update wrong types
|
||||
with self.assertRaises(ValueError):
|
||||
m = m.update_modules({"layers": [0, 1]})
|
||||
|
||||
class MyModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.test = mx.array(1.0)
|
||||
self.list = [mx.array(1.0), mx.array(2.0)]
|
||||
|
||||
m = MyModule()
|
||||
with self.assertRaises(ValueError):
|
||||
m = m.update_modules({"test": "hi"})
|
||||
with self.assertRaises(ValueError):
|
||||
m = m.update_modules({"list": ["hi"]})
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
Loading…
Reference in New Issue
Block a user