From d6112515024f930131134d0c0e91d50493349a0e Mon Sep 17 00:00:00 2001 From: AmirHossein_Razlighi <79264971+amirhossein-razlighi@users.noreply.github.com> Date: Thu, 28 Mar 2024 06:28:29 +0330 Subject: [PATCH] Support Chaining for some of functionalities of `nn.Module` (#885) (#897) * add chaining support for some of the functionalities of "nn.Module" * reformat * change the return types * remove return types * add return type with forward referencing * add tests for chaining * add name to contributors * Update python/mlx/nn/layers/base.py Co-authored-by: Awni Hannun * Update python/mlx/nn/layers/base.py Co-authored-by: Awni Hannun * update docstring * update docstrings --------- Co-authored-by: Awni Hannun --- ACKNOWLEDGMENTS.md | 1 + python/mlx/nn/layers/base.py | 51 +++++++++++++++++++++++++++++------- python/tests/test_nn.py | 10 +++++++ 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index e841d0d0c..0e5d1142d 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -15,6 +15,7 @@ MLX was developed with contributions from the following individuals: - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays. - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` +- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 9bd336690..996410526 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -151,7 +151,7 @@ class Module(dict): self, file_or_weights: Union[str, List[Tuple[str, mx.array]]], strict: bool = True, - ): + ) -> "Module": """ Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list. @@ -164,6 +164,9 @@ class Module(dict): only the weights actually contained in the model are loaded and shapes are not checked. Default: ``True``. + Returns: + The module instance after updating the weights. + Example: .. code-block:: python @@ -223,6 +226,7 @@ class Module(dict): ) self.update(tree_unflatten(weights)) + return self def save_weights(self, file: str): """ @@ -319,7 +323,7 @@ class Module(dict): return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) - def update(self, parameters: dict): + def update(self, parameters: dict) -> "Module": """Replace the parameters of this Module with the provided ones in the dict of dicts and lists. @@ -334,6 +338,8 @@ class Module(dict): Args: parameters (dict): A complete or partial dictionary of the modules parameters. + Returns: + The module instance after updating the parameters. """ def apply(dst, parameters): @@ -360,12 +366,13 @@ class Module(dict): apply(current_value, new_value) apply(self, parameters) + return self def apply( self, map_fn: Callable[[mx.array], mx.array], filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None, - ): + ) -> "Module": """Map all the parameters using the provided ``map_fn`` and immediately update the module with the mapped parameters. @@ -376,11 +383,15 @@ class Module(dict): map_fn (Callable): Maps an array to another array filter_fn (Callable, optional): Filter to select which arrays to map (default: :meth:`Module.valid_parameter_filter`). + + Returns: + The module instance after updating the parameters. """ filter_fn = filter_fn or Module.valid_parameter_filter self.update(self.filter_and_map(filter_fn, map_fn)) + return self - def update_modules(self, modules: dict): + def update_modules(self, modules: dict) -> "Module": """Replace the child modules of this :class:`Module` instance with the provided ones in the dict of dicts and lists. @@ -395,6 +406,8 @@ class Module(dict): Args: modules (dict): A complete or partial dictionary of the modules submodules. + Returns: + The module instance after updating the submodules. """ def apply(dst, modules): @@ -417,13 +430,19 @@ class Module(dict): apply(current_value, new_value) apply(self, modules) + return self - def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]): + def apply_to_modules( + self, apply_fn: Callable[[str, "mlx.nn.Module"], Any] + ) -> "Module": """Apply a function to all the modules in this instance (including this instance). Args: apply_fn (Callable): The function to apply to the modules. + + Returns: + The module instance after updating submodules. """ module_stack = [("", self)] while module_stack: @@ -433,6 +452,7 @@ class Module(dict): module_stack.extend( tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module) ) + return self def modules(self): """Return a list with all the modules in this instance. @@ -469,7 +489,7 @@ class Module(dict): recurse: bool = True, keys: Optional[Union[str, List[str]]] = None, strict: bool = False, - ): + ) -> "Module": """Freeze the Module's parameters or some of them. Freezing a parameter means not computing gradients for it. @@ -493,6 +513,9 @@ class Module(dict): ``module.freeze(keys="bias")``. strict (bool, optional): If set to ``True`` validate that the passed keys exist. Default: ``False``. + + Returns: + The module instance after freezing the parameters. """ def _freeze_impl(_, m): @@ -513,6 +536,7 @@ class Module(dict): self.apply_to_modules(_freeze_impl) else: _freeze_impl("", self) + return self def unfreeze( self, @@ -520,7 +544,7 @@ class Module(dict): recurse: bool = True, keys: Optional[Union[str, List[str]]] = None, strict: bool = False, - ): + ) -> "Module": """Unfreeze the Module's parameters or some of them. This function is idempotent ie unfreezing a model that is not frozen is @@ -545,6 +569,9 @@ class Module(dict): ``module.unfreeze(keys="bias")``. strict (bool, optional): If set to ``True`` validate that the passed keys exist. Default: ``False``. + + Returns: + The module instance after unfreezing the parameters. """ def _unfreeze_impl(_, m): @@ -559,8 +586,9 @@ class Module(dict): self.apply_to_modules(_unfreeze_impl) else: _unfreeze_impl("", self) + return self - def train(self, mode: bool = True): + def train(self, mode: bool = True) -> "Module": """Set the model in or out of training mode. Training mode only applies to certain layers. For example @@ -570,19 +598,22 @@ class Module(dict): Args: mode (bool): Indicate if the model should be in training or evaluation mode. Default: ``True``. + Returns: + The module instance after updating the training mode. """ def _set_train(_, m): m._training = mode self.apply_to_modules(_set_train) + return self - def eval(self): + def eval(self) -> "Module": """Set the model to evaluation mode. See :func:`train`. """ - self.train(False) + return self.train(False) def set_dtype( self, diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 9225a97d9..f5e8f6d8d 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -162,6 +162,16 @@ class TestBase(mlx_tests.MLXTestCase): m.state["hello"] = "world" self.assertEqual(m.state["hello"], "world") + def test_chaining(self): + m = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)) + pre_freeze_num_params = len(m.parameters()) + m.freeze().unfreeze() + self.assertEqual(len(m.parameters()), pre_freeze_num_params) + params_dict = m.parameters() + + self.assertFalse(m.update(params_dict).eval()._training) + self.assertTrue(m.train()._training) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self):