Support chaining for some of the functionalities of nn.Module (#885) (#897)

* add chaining support for some of the functionalities of "nn.Module"

* reformat

* change the return types

* remove return types

* add return type with forward referencing

* add tests for chaining

* add name to contributors

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring

* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
AmirHossein_Razlighi 2024-03-28 06:28:29 +03:30 committed by GitHub
parent f30b659291
commit d611251502
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 10 deletions

View File

@ -15,6 +15,7 @@ MLX was developed with contributions from the following individuals:
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays. - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`.
<a href="https://github.com/ml-explore/mlx/graphs/contributors"> <a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" /> <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a> </a>

View File

@ -151,7 +151,7 @@ class Module(dict):
self, self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]], file_or_weights: Union[str, List[Tuple[str, mx.array]]],
strict: bool = True, strict: bool = True,
): ) -> "Module":
""" """
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list. Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
@ -164,6 +164,9 @@ class Module(dict):
only the weights actually contained in the model are loaded and only the weights actually contained in the model are loaded and
shapes are not checked. Default: ``True``. shapes are not checked. Default: ``True``.
Returns:
The module instance after updating the weights.
Example: Example:
.. code-block:: python .. code-block:: python
@ -223,6 +226,7 @@ class Module(dict):
) )
self.update(tree_unflatten(weights)) self.update(tree_unflatten(weights))
return self
def save_weights(self, file: str): def save_weights(self, file: str):
""" """
@ -319,7 +323,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict): def update(self, parameters: dict) -> "Module":
"""Replace the parameters of this Module with the provided ones in the """Replace the parameters of this Module with the provided ones in the
dict of dicts and lists. dict of dicts and lists.
@ -334,6 +338,8 @@ class Module(dict):
Args: Args:
parameters (dict): A complete or partial dictionary of the modules parameters (dict): A complete or partial dictionary of the modules
parameters. parameters.
Returns:
The module instance after updating the parameters.
""" """
def apply(dst, parameters): def apply(dst, parameters):
@ -360,12 +366,13 @@ class Module(dict):
apply(current_value, new_value) apply(current_value, new_value)
apply(self, parameters) apply(self, parameters)
return self
def apply( def apply(
self, self,
map_fn: Callable[[mx.array], mx.array], map_fn: Callable[[mx.array], mx.array],
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None, filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
): ) -> "Module":
"""Map all the parameters using the provided ``map_fn`` and immediately """Map all the parameters using the provided ``map_fn`` and immediately
update the module with the mapped parameters. update the module with the mapped parameters.
@ -376,11 +383,15 @@ class Module(dict):
map_fn (Callable): Maps an array to another array map_fn (Callable): Maps an array to another array
filter_fn (Callable, optional): Filter to select which arrays to filter_fn (Callable, optional): Filter to select which arrays to
map (default: :meth:`Module.valid_parameter_filter`). map (default: :meth:`Module.valid_parameter_filter`).
Returns:
The module instance after updating the parameters.
""" """
filter_fn = filter_fn or Module.valid_parameter_filter filter_fn = filter_fn or Module.valid_parameter_filter
self.update(self.filter_and_map(filter_fn, map_fn)) self.update(self.filter_and_map(filter_fn, map_fn))
return self
def update_modules(self, modules: dict): def update_modules(self, modules: dict) -> "Module":
"""Replace the child modules of this :class:`Module` instance with the """Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists. provided ones in the dict of dicts and lists.
@ -395,6 +406,8 @@ class Module(dict):
Args: Args:
modules (dict): A complete or partial dictionary of the modules modules (dict): A complete or partial dictionary of the modules
submodules. submodules.
Returns:
The module instance after updating the submodules.
""" """
def apply(dst, modules): def apply(dst, modules):
@ -417,13 +430,19 @@ class Module(dict):
apply(current_value, new_value) apply(current_value, new_value)
apply(self, modules) apply(self, modules)
return self
def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]): def apply_to_modules(
self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
) -> "Module":
"""Apply a function to all the modules in this instance (including this """Apply a function to all the modules in this instance (including this
instance). instance).
Args: Args:
apply_fn (Callable): The function to apply to the modules. apply_fn (Callable): The function to apply to the modules.
Returns:
The module instance after updating submodules.
""" """
module_stack = [("", self)] module_stack = [("", self)]
while module_stack: while module_stack:
@ -433,6 +452,7 @@ class Module(dict):
module_stack.extend( module_stack.extend(
tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module) tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
) )
return self
def modules(self): def modules(self):
"""Return a list with all the modules in this instance. """Return a list with all the modules in this instance.
@ -469,7 +489,7 @@ class Module(dict):
recurse: bool = True, recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None, keys: Optional[Union[str, List[str]]] = None,
strict: bool = False, strict: bool = False,
): ) -> "Module":
"""Freeze the Module's parameters or some of them. Freezing a parameter means not """Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it. computing gradients for it.
@ -493,6 +513,9 @@ class Module(dict):
``module.freeze(keys="bias")``. ``module.freeze(keys="bias")``.
strict (bool, optional): If set to ``True`` validate that the passed keys exist. strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``. Default: ``False``.
Returns:
The module instance after freezing the parameters.
""" """
def _freeze_impl(_, m): def _freeze_impl(_, m):
@ -513,6 +536,7 @@ class Module(dict):
self.apply_to_modules(_freeze_impl) self.apply_to_modules(_freeze_impl)
else: else:
_freeze_impl("", self) _freeze_impl("", self)
return self
def unfreeze( def unfreeze(
self, self,
@ -520,7 +544,7 @@ class Module(dict):
recurse: bool = True, recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None, keys: Optional[Union[str, List[str]]] = None,
strict: bool = False, strict: bool = False,
): ) -> "Module":
"""Unfreeze the Module's parameters or some of them. """Unfreeze the Module's parameters or some of them.
This function is idempotent ie unfreezing a model that is not frozen is This function is idempotent ie unfreezing a model that is not frozen is
@ -545,6 +569,9 @@ class Module(dict):
``module.unfreeze(keys="bias")``. ``module.unfreeze(keys="bias")``.
strict (bool, optional): If set to ``True`` validate that the passed keys exist. strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``. Default: ``False``.
Returns:
The module instance after unfreezing the parameters.
""" """
def _unfreeze_impl(_, m): def _unfreeze_impl(_, m):
@ -559,8 +586,9 @@ class Module(dict):
self.apply_to_modules(_unfreeze_impl) self.apply_to_modules(_unfreeze_impl)
else: else:
_unfreeze_impl("", self) _unfreeze_impl("", self)
return self
def train(self, mode: bool = True): def train(self, mode: bool = True) -> "Module":
"""Set the model in or out of training mode. """Set the model in or out of training mode.
Training mode only applies to certain layers. For example Training mode only applies to certain layers. For example
@ -570,19 +598,22 @@ class Module(dict):
Args: Args:
mode (bool): Indicate if the model should be in training or mode (bool): Indicate if the model should be in training or
evaluation mode. Default: ``True``. evaluation mode. Default: ``True``.
Returns:
The module instance after updating the training mode.
""" """
def _set_train(_, m): def _set_train(_, m):
m._training = mode m._training = mode
self.apply_to_modules(_set_train) self.apply_to_modules(_set_train)
return self
def eval(self): def eval(self) -> "Module":
"""Set the model to evaluation mode. """Set the model to evaluation mode.
See :func:`train`. See :func:`train`.
""" """
self.train(False) return self.train(False)
def set_dtype( def set_dtype(
self, self,

View File

@ -162,6 +162,16 @@ class TestBase(mlx_tests.MLXTestCase):
m.state["hello"] = "world" m.state["hello"] = "world"
self.assertEqual(m.state["hello"], "world") self.assertEqual(m.state["hello"], "world")
def test_chaining(self):
m = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
pre_freeze_num_params = len(m.parameters())
m.freeze().unfreeze()
self.assertEqual(len(m.parameters()), pre_freeze_num_params)
params_dict = m.parameters()
self.assertFalse(m.update(params_dict).eval()._training)
self.assertTrue(m.train()._training)
class TestLayers(mlx_tests.MLXTestCase): class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):