Support chaining for some of the functionalities of nn.Module (#885) (#897)

* add chaining support for some of the functionalities of "nn.Module"

* reformat

* change the return types

* remove return types

* add return type with forward referencing

* add tests for chaining

* add name to contributors

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring

* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
AmirHossein_Razlighi
2024-03-28 06:28:29 +03:30
committed by GitHub
parent f30b659291
commit d611251502
3 changed files with 52 additions and 10 deletions

View File

@@ -151,7 +151,7 @@ class Module(dict):
self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
strict: bool = True,
):
) -> "Module":
"""
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
@@ -164,6 +164,9 @@ class Module(dict):
only the weights actually contained in the model are loaded and
shapes are not checked. Default: ``True``.
Returns:
The module instance after updating the weights.
Example:
.. code-block:: python
@@ -223,6 +226,7 @@ class Module(dict):
)
self.update(tree_unflatten(weights))
return self
def save_weights(self, file: str):
"""
@@ -319,7 +323,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict):
def update(self, parameters: dict) -> "Module":
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.
@@ -334,6 +338,8 @@ class Module(dict):
Args:
parameters (dict): A complete or partial dictionary of the modules
parameters.
Returns:
The module instance after updating the parameters.
"""
def apply(dst, parameters):
@@ -360,12 +366,13 @@ class Module(dict):
apply(current_value, new_value)
apply(self, parameters)
return self
def apply(
self,
map_fn: Callable[[mx.array], mx.array],
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
):
) -> "Module":
"""Map all the parameters using the provided ``map_fn`` and immediately
update the module with the mapped parameters.
@@ -376,11 +383,15 @@ class Module(dict):
map_fn (Callable): Maps an array to another array
filter_fn (Callable, optional): Filter to select which arrays to
map (default: :meth:`Module.valid_parameter_filter`).
Returns:
The module instance after updating the parameters.
"""
filter_fn = filter_fn or Module.valid_parameter_filter
self.update(self.filter_and_map(filter_fn, map_fn))
return self
def update_modules(self, modules: dict):
def update_modules(self, modules: dict) -> "Module":
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.
@@ -395,6 +406,8 @@ class Module(dict):
Args:
modules (dict): A complete or partial dictionary of the modules
submodules.
Returns:
The module instance after updating the submodules.
"""
def apply(dst, modules):
@@ -417,13 +430,19 @@ class Module(dict):
apply(current_value, new_value)
apply(self, modules)
return self
def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
def apply_to_modules(
self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
) -> "Module":
"""Apply a function to all the modules in this instance (including this
instance).
Args:
apply_fn (Callable): The function to apply to the modules.
Returns:
The module instance after updating submodules.
"""
module_stack = [("", self)]
while module_stack:
@@ -433,6 +452,7 @@ class Module(dict):
module_stack.extend(
tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
)
return self
def modules(self):
"""Return a list with all the modules in this instance.
@@ -469,7 +489,7 @@ class Module(dict):
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
):
) -> "Module":
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it.
@@ -493,6 +513,9 @@ class Module(dict):
``module.freeze(keys="bias")``.
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``.
Returns:
The module instance after freezing the parameters.
"""
def _freeze_impl(_, m):
@@ -513,6 +536,7 @@ class Module(dict):
self.apply_to_modules(_freeze_impl)
else:
_freeze_impl("", self)
return self
def unfreeze(
self,
@@ -520,7 +544,7 @@ class Module(dict):
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
):
) -> "Module":
"""Unfreeze the Module's parameters or some of them.
This function is idempotent ie unfreezing a model that is not frozen is
@@ -545,6 +569,9 @@ class Module(dict):
``module.unfreeze(keys="bias")``.
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``.
Returns:
The module instance after unfreezing the parameters.
"""
def _unfreeze_impl(_, m):
@@ -559,8 +586,9 @@ class Module(dict):
self.apply_to_modules(_unfreeze_impl)
else:
_unfreeze_impl("", self)
return self
def train(self, mode: bool = True):
def train(self, mode: bool = True) -> "Module":
"""Set the model in or out of training mode.
Training mode only applies to certain layers. For example
@@ -570,19 +598,22 @@ class Module(dict):
Args:
mode (bool): Indicate if the model should be in training or
evaluation mode. Default: ``True``.
Returns:
The module instance after updating the training mode.
"""
def _set_train(_, m):
m._training = mode
self.apply_to_modules(_set_train)
return self
def eval(self):
def eval(self) -> "Module":
"""Set the model to evaluation mode.
See :func:`train`.
"""
self.train(False)
return self.train(False)
def set_dtype(
self,