From c763fe1be0f1158e0f53f6f6a28f56d69b1c7fe8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 5 Jun 2025 15:27:02 -0700 Subject: [PATCH] default strict mode for module update and update_modules (#2239) --- python/mlx/nn/layers/base.py | 50 +++++++++++++++++++++++++++--------- python/tests/test_nn.py | 40 +++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index b35c58478..783ef446d 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -193,7 +193,7 @@ class Module(dict): ) if len(weights) != 0: - self.update(tree_unflatten(weights)) + self.update(tree_unflatten(weights), strict=False) return self def save_weights(self, file: str): @@ -291,7 +291,7 @@ class Module(dict): return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) - def update(self, parameters: dict) -> Module: + def update(self, parameters: dict, strict: bool = True) -> Module: """Replace the parameters of this Module with the provided ones in the dict of dicts and lists. @@ -305,7 +305,9 @@ class Module(dict): Args: parameters (dict): A complete or partial dictionary of the modules - parameters. + parameters. + strict (bool): If ``True`` checks that ``parameters`` is a + subset of the module's parameters. Default: ``True``. Returns: The module instance after updating the parameters. """ @@ -317,21 +319,29 @@ class Module(dict): current_value = dst[k] new_value = parameters[k] if isinstance(current_value, mx.array): + if strict and not isinstance(new_value, mx.array): + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) dst[k] = new_value - elif isinstance(current_value, Module): - current_value.update(new_value) - elif isinstance(current_value, (dict, list)): + else: apply(current_value, new_value) + elif strict: + raise ValueError(f'Module does not have parameter named "{k}".') elif isinstance(parameters, list): for i in range(len(parameters)): current_value = dst[i] new_value = parameters[i] if isinstance(current_value, mx.array): + if strict and not isinstance(new_value, mx.array): + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) dst[i] = new_value - elif isinstance(current_value, Module): - current_value.update(new_value) - elif isinstance(current_value, (dict, list)): + else: apply(current_value, new_value) + elif strict: + raise ValueError(f"Received invalid type: {type(parameters).__name__}.") apply(self, parameters) return self @@ -359,7 +369,7 @@ class Module(dict): self.update(self.filter_and_map(filter_fn, map_fn)) return self - def update_modules(self, modules: dict) -> Module: + def update_modules(self, modules: dict, strict: bool = True) -> Module: """Replace the child modules of this :class:`Module` instance with the provided ones in the dict of dicts and lists. @@ -368,12 +378,14 @@ class Module(dict): programmatically swapping layers. The passed in parameters dictionary need not be a full dictionary - similar to :meth:`parameters`. Only the provided locations will be + similar to :meth:`modules`. Only the provided locations will be updated. Args: - modules (dict): A complete or partial dictionary of the modules + modules (dict): A complete or partial dictionary of the module's submodules. + strict (bool): If ``True`` checks that ``modules`` is a + subset of the child modules of this instance. Default: ``True``. Returns: The module instance after updating the submodules. """ @@ -388,6 +400,14 @@ class Module(dict): dst[k] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) + elif strict: + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) + elif strict: + raise ValueError( + f'Module does not have sub-module named "{k}".' + ) elif isinstance(modules, list): for i in range(len(dst)): current_value = dst[i] @@ -396,6 +416,12 @@ class Module(dict): dst[i] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) + elif strict: + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) + elif strict: + raise ValueError(f"Received invalid type: {type(modules).__name__}.") apply(self, modules) return self diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 826d53d96..13e31ad96 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase): x = mx.zeros((3,)) mx.grad(loss_fn)(model) + def test_update(self): + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + + # Updating non-existent parameters + with self.assertRaises(ValueError): + updates = {"layers": [{"value": 0}]} + m.update(updates) + + with self.assertRaises(ValueError): + updates = {"layers": ["hello"]} + m.update(updates) + + # Wronge type + with self.assertRaises(ValueError): + updates = {"layers": [{"weight": "hi"}]} + m.update(updates) + + def test_update_modules(self): + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + + # Updating non-existent modules should not be allowed by default + with self.assertRaises(ValueError): + m = m.update_modules({"values": [0, 1]}) + + # Update wrong types + with self.assertRaises(ValueError): + m = m.update_modules({"layers": [0, 1]}) + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.test = mx.array(1.0) + self.list = [mx.array(1.0), mx.array(2.0)] + + m = MyModule() + with self.assertRaises(ValueError): + m = m.update_modules({"test": "hi"}) + with self.assertRaises(ValueError): + m = m.update_modules({"list": ["hi"]}) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self):