mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	default strict mode for module update and update_modules (#2239)
This commit is contained in:
		| @@ -193,7 +193,7 @@ class Module(dict): | ||||
|                     ) | ||||
|  | ||||
|         if len(weights) != 0: | ||||
|             self.update(tree_unflatten(weights)) | ||||
|             self.update(tree_unflatten(weights), strict=False) | ||||
|         return self | ||||
|  | ||||
|     def save_weights(self, file: str): | ||||
| @@ -291,7 +291,7 @@ class Module(dict): | ||||
|  | ||||
|         return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) | ||||
|  | ||||
|     def update(self, parameters: dict) -> Module: | ||||
|     def update(self, parameters: dict, strict: bool = True) -> Module: | ||||
|         """Replace the parameters of this Module with the provided ones in the | ||||
|         dict of dicts and lists. | ||||
|  | ||||
| @@ -305,7 +305,9 @@ class Module(dict): | ||||
|  | ||||
|         Args: | ||||
|             parameters (dict): A complete or partial dictionary of the module's | ||||
|                                parameters. | ||||
|                 parameters. | ||||
|             strict (bool): If ``True`` checks that ``parameters`` is a | ||||
|                 subset of the module's parameters. Default: ``True``. | ||||
|         Returns: | ||||
|             The module instance after updating the parameters. | ||||
|         """ | ||||
| @@ -317,21 +319,29 @@ class Module(dict): | ||||
|                         current_value = dst[k] | ||||
|                         new_value = parameters[k] | ||||
|                         if isinstance(current_value, mx.array): | ||||
|                             if strict and not isinstance(new_value, mx.array): | ||||
|                                 raise ValueError( | ||||
|                                     f"Received invalid type: {type(new_value).__name__}." | ||||
|                                 ) | ||||
|                             dst[k] = new_value | ||||
|                         elif isinstance(current_value, Module): | ||||
|                             current_value.update(new_value) | ||||
|                         elif isinstance(current_value, (dict, list)): | ||||
|                         else: | ||||
|                             apply(current_value, new_value) | ||||
|                     elif strict: | ||||
|                         raise ValueError(f'Module does not have parameter named "{k}".') | ||||
|             elif isinstance(parameters, list): | ||||
|                 for i in range(len(parameters)): | ||||
|                     current_value = dst[i] | ||||
|                     new_value = parameters[i] | ||||
|                     if isinstance(current_value, mx.array): | ||||
|                         if strict and not isinstance(new_value, mx.array): | ||||
|                             raise ValueError( | ||||
|                                 f"Received invalid type: {type(new_value).__name__}." | ||||
|                             ) | ||||
|                         dst[i] = new_value | ||||
|                     elif isinstance(current_value, Module): | ||||
|                         current_value.update(new_value) | ||||
|                     elif isinstance(current_value, (dict, list)): | ||||
|                     else: | ||||
|                         apply(current_value, new_value) | ||||
|             elif strict: | ||||
|                 raise ValueError(f"Received invalid type: {type(parameters).__name__}.") | ||||
|  | ||||
|         apply(self, parameters) | ||||
|         return self | ||||
| @@ -359,7 +369,7 @@ class Module(dict): | ||||
|         self.update(self.filter_and_map(filter_fn, map_fn)) | ||||
|         return self | ||||
|  | ||||
|     def update_modules(self, modules: dict) -> Module: | ||||
|     def update_modules(self, modules: dict, strict: bool = True) -> Module: | ||||
|         """Replace the child modules of this :class:`Module` instance with the | ||||
|         provided ones in the dict of dicts and lists. | ||||
|  | ||||
| @@ -368,12 +378,14 @@ class Module(dict): | ||||
|         programmatically swapping layers. | ||||
|  | ||||
|         The passed in parameters dictionary need not be a full dictionary | ||||
|         similar to :meth:`parameters`. Only the provided locations will be | ||||
|         similar to :meth:`modules`. Only the provided locations will be | ||||
|         updated. | ||||
|  | ||||
|         Args: | ||||
|             modules (dict): A complete or partial dictionary of the modules | ||||
|             modules (dict): A complete or partial dictionary of the module's | ||||
|                 submodules. | ||||
|             strict (bool): If ``True`` checks that ``modules`` is a | ||||
|                 subset of the child modules of this instance. Default: ``True``. | ||||
|         Returns: | ||||
|             The module instance after updating the submodules. | ||||
|         """ | ||||
| @@ -388,6 +400,14 @@ class Module(dict): | ||||
|                             dst[k] = new_value | ||||
|                         elif isinstance(current_value, (dict, list)): | ||||
|                             apply(current_value, new_value) | ||||
|                         elif strict: | ||||
|                             raise ValueError( | ||||
|                                 f"Received invalid type: {type(new_value).__name__}." | ||||
|                             ) | ||||
|                     elif strict: | ||||
|                         raise ValueError( | ||||
|                             f'Module does not have sub-module named "{k}".' | ||||
|                         ) | ||||
|             elif isinstance(modules, list): | ||||
|                 for i in range(len(dst)): | ||||
|                     current_value = dst[i] | ||||
| @@ -396,6 +416,12 @@ class Module(dict): | ||||
|                         dst[i] = new_value | ||||
|                     elif isinstance(current_value, (dict, list)): | ||||
|                         apply(current_value, new_value) | ||||
|                     elif strict: | ||||
|                         raise ValueError( | ||||
|                             f"Received invalid type: {type(new_value).__name__}." | ||||
|                         ) | ||||
|             elif strict: | ||||
|                 raise ValueError(f"Received invalid type: {type(modules).__name__}.") | ||||
|  | ||||
|         apply(self, modules) | ||||
|         return self | ||||
|   | ||||
| @@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase): | ||||
|         x = mx.zeros((3,)) | ||||
|         mx.grad(loss_fn)(model) | ||||
|  | ||||
|     def test_update(self): | ||||
|         m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) | ||||
|  | ||||
|         # Updating non-existent parameters | ||||
|         with self.assertRaises(ValueError): | ||||
|             updates = {"layers": [{"value": 0}]} | ||||
|             m.update(updates) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             updates = {"layers": ["hello"]} | ||||
|             m.update(updates) | ||||
|  | ||||
|         # Wrong type | ||||
|         with self.assertRaises(ValueError): | ||||
|             updates = {"layers": [{"weight": "hi"}]} | ||||
|             m.update(updates) | ||||
|  | ||||
|     def test_update_modules(self): | ||||
|         m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) | ||||
|  | ||||
|         # Updating non-existent modules should not be allowed by default | ||||
|         with self.assertRaises(ValueError): | ||||
|             m = m.update_modules({"values": [0, 1]}) | ||||
|  | ||||
|         # Update wrong types | ||||
|         with self.assertRaises(ValueError): | ||||
|             m = m.update_modules({"layers": [0, 1]}) | ||||
|  | ||||
|         class MyModule(nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.test = mx.array(1.0) | ||||
|                 self.list = [mx.array(1.0), mx.array(2.0)] | ||||
|  | ||||
|         m = MyModule() | ||||
|         with self.assertRaises(ValueError): | ||||
|             m = m.update_modules({"test": "hi"}) | ||||
|         with self.assertRaises(ValueError): | ||||
|             m = m.update_modules({"list": ["hi"]}) | ||||
|  | ||||
|  | ||||
| class TestLayers(mlx_tests.MLXTestCase): | ||||
|     def test_identity(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun