mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
default strict mode for module update and update_modules (#2239)
This commit is contained in:
parent
52dc8c8cd5
commit
c763fe1be0
@ -193,7 +193,7 @@ class Module(dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(weights) != 0:
|
if len(weights) != 0:
|
||||||
self.update(tree_unflatten(weights))
|
self.update(tree_unflatten(weights), strict=False)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def save_weights(self, file: str):
|
def save_weights(self, file: str):
|
||||||
@ -291,7 +291,7 @@ class Module(dict):
|
|||||||
|
|
||||||
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
|
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
|
||||||
|
|
||||||
def update(self, parameters: dict) -> Module:
|
def update(self, parameters: dict, strict: bool = True) -> Module:
|
||||||
"""Replace the parameters of this Module with the provided ones in the
|
"""Replace the parameters of this Module with the provided ones in the
|
||||||
dict of dicts and lists.
|
dict of dicts and lists.
|
||||||
|
|
||||||
@ -305,7 +305,9 @@ class Module(dict):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
parameters (dict): A complete or partial dictionary of the module's
|
parameters (dict): A complete or partial dictionary of the module's
|
||||||
parameters.
|
parameters.
|
||||||
|
strict (bool): If ``True`` checks that ``parameters`` is a
|
||||||
|
subset of the module's parameters. Default: ``True``.
|
||||||
Returns:
|
Returns:
|
||||||
The module instance after updating the parameters.
|
The module instance after updating the parameters.
|
||||||
"""
|
"""
|
||||||
@ -317,21 +319,29 @@ class Module(dict):
|
|||||||
current_value = dst[k]
|
current_value = dst[k]
|
||||||
new_value = parameters[k]
|
new_value = parameters[k]
|
||||||
if isinstance(current_value, mx.array):
|
if isinstance(current_value, mx.array):
|
||||||
|
if strict and not isinstance(new_value, mx.array):
|
||||||
|
raise ValueError(
|
||||||
|
f"Received invalid type: {type(new_value).__name__}."
|
||||||
|
)
|
||||||
dst[k] = new_value
|
dst[k] = new_value
|
||||||
elif isinstance(current_value, Module):
|
else:
|
||||||
current_value.update(new_value)
|
|
||||||
elif isinstance(current_value, (dict, list)):
|
|
||||||
apply(current_value, new_value)
|
apply(current_value, new_value)
|
||||||
|
elif strict:
|
||||||
|
raise ValueError(f'Module does not have parameter named "{k}".')
|
||||||
elif isinstance(parameters, list):
|
elif isinstance(parameters, list):
|
||||||
for i in range(len(parameters)):
|
for i in range(len(parameters)):
|
||||||
current_value = dst[i]
|
current_value = dst[i]
|
||||||
new_value = parameters[i]
|
new_value = parameters[i]
|
||||||
if isinstance(current_value, mx.array):
|
if isinstance(current_value, mx.array):
|
||||||
|
if strict and not isinstance(new_value, mx.array):
|
||||||
|
raise ValueError(
|
||||||
|
f"Received invalid type: {type(new_value).__name__}."
|
||||||
|
)
|
||||||
dst[i] = new_value
|
dst[i] = new_value
|
||||||
elif isinstance(current_value, Module):
|
else:
|
||||||
current_value.update(new_value)
|
|
||||||
elif isinstance(current_value, (dict, list)):
|
|
||||||
apply(current_value, new_value)
|
apply(current_value, new_value)
|
||||||
|
elif strict:
|
||||||
|
raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
|
||||||
|
|
||||||
apply(self, parameters)
|
apply(self, parameters)
|
||||||
return self
|
return self
|
||||||
@ -359,7 +369,7 @@ class Module(dict):
|
|||||||
self.update(self.filter_and_map(filter_fn, map_fn))
|
self.update(self.filter_and_map(filter_fn, map_fn))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def update_modules(self, modules: dict) -> Module:
|
def update_modules(self, modules: dict, strict: bool = True) -> Module:
|
||||||
"""Replace the child modules of this :class:`Module` instance with the
|
"""Replace the child modules of this :class:`Module` instance with the
|
||||||
provided ones in the dict of dicts and lists.
|
provided ones in the dict of dicts and lists.
|
||||||
|
|
||||||
@ -368,12 +378,14 @@ class Module(dict):
|
|||||||
programmatically swapping layers.
|
programmatically swapping layers.
|
||||||
|
|
||||||
The passed in parameters dictionary need not be a full dictionary
|
The passed in parameters dictionary need not be a full dictionary
|
||||||
similar to :meth:`parameters`. Only the provided locations will be
|
similar to :meth:`modules`. Only the provided locations will be
|
||||||
updated.
|
updated.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
modules (dict): A complete or partial dictionary of the modules
|
modules (dict): A complete or partial dictionary of the module's
|
||||||
submodules.
|
submodules.
|
||||||
|
strict (bool): If ``True`` checks that ``modules`` is a
|
||||||
|
subset of the child modules of this instance. Default: ``True``.
|
||||||
Returns:
|
Returns:
|
||||||
The module instance after updating the submodules.
|
The module instance after updating the submodules.
|
||||||
"""
|
"""
|
||||||
@ -388,6 +400,14 @@ class Module(dict):
|
|||||||
dst[k] = new_value
|
dst[k] = new_value
|
||||||
elif isinstance(current_value, (dict, list)):
|
elif isinstance(current_value, (dict, list)):
|
||||||
apply(current_value, new_value)
|
apply(current_value, new_value)
|
||||||
|
elif strict:
|
||||||
|
raise ValueError(
|
||||||
|
f"Received invalid type: {type(new_value).__name__}."
|
||||||
|
)
|
||||||
|
elif strict:
|
||||||
|
raise ValueError(
|
||||||
|
f'Module does not have sub-module named "{k}".'
|
||||||
|
)
|
||||||
elif isinstance(modules, list):
|
elif isinstance(modules, list):
|
||||||
for i in range(len(dst)):
|
for i in range(len(dst)):
|
||||||
current_value = dst[i]
|
current_value = dst[i]
|
||||||
@ -396,6 +416,12 @@ class Module(dict):
|
|||||||
dst[i] = new_value
|
dst[i] = new_value
|
||||||
elif isinstance(current_value, (dict, list)):
|
elif isinstance(current_value, (dict, list)):
|
||||||
apply(current_value, new_value)
|
apply(current_value, new_value)
|
||||||
|
elif strict:
|
||||||
|
raise ValueError(
|
||||||
|
f"Received invalid type: {type(new_value).__name__}."
|
||||||
|
)
|
||||||
|
elif strict:
|
||||||
|
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
|
||||||
|
|
||||||
apply(self, modules)
|
apply(self, modules)
|
||||||
return self
|
return self
|
||||||
|
@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
x = mx.zeros((3,))
|
x = mx.zeros((3,))
|
||||||
mx.grad(loss_fn)(model)
|
mx.grad(loss_fn)(model)
|
||||||
|
|
||||||
|
def test_update(self):
|
||||||
|
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||||
|
|
||||||
|
# Updating non-existent parameters
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
updates = {"layers": [{"value": 0}]}
|
||||||
|
m.update(updates)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
updates = {"layers": ["hello"]}
|
||||||
|
m.update(updates)
|
||||||
|
|
||||||
|
# Wrong type
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
updates = {"layers": [{"weight": "hi"}]}
|
||||||
|
m.update(updates)
|
||||||
|
|
||||||
|
def test_update_modules(self):
|
||||||
|
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||||
|
|
||||||
|
# Updating non-existent modules should not be allowed by default
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
m = m.update_modules({"values": [0, 1]})
|
||||||
|
|
||||||
|
# Update wrong types
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
m = m.update_modules({"layers": [0, 1]})
|
||||||
|
|
||||||
|
class MyModule(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.test = mx.array(1.0)
|
||||||
|
self.list = [mx.array(1.0), mx.array(2.0)]
|
||||||
|
|
||||||
|
m = MyModule()
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
m = m.update_modules({"test": "hi"})
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
m = m.update_modules({"list": ["hi"]})
|
||||||
|
|
||||||
|
|
||||||
class TestLayers(mlx_tests.MLXTestCase):
|
class TestLayers(mlx_tests.MLXTestCase):
|
||||||
def test_identity(self):
|
def test_identity(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user