Support Chaining for some of functionalities of nn.Module (#885) (#897)

* add chaining support for some of the functionalities of "nn.Module"

* reformat

* change the return types

* remove return types

* add return type with forward referencing

* add tests for chaining

* add name to contributors

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring

* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
AmirHossein_Razlighi 2024-03-28 06:28:29 +03:30 committed by GitHub
parent f30b659291
commit d611251502
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 10 deletions

View File

@ -15,6 +15,7 @@ MLX was developed with contributions from the following individuals:
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>

View File

@ -151,7 +151,7 @@ class Module(dict):
self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
strict: bool = True,
):
) -> "Module":
"""
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
@ -164,6 +164,9 @@ class Module(dict):
only the weights actually contained in the model are loaded and
shapes are not checked. Default: ``True``.
Returns:
The module instance after updating the weights.
Example:
.. code-block:: python
@ -223,6 +226,7 @@ class Module(dict):
)
self.update(tree_unflatten(weights))
return self
def save_weights(self, file: str):
"""
@ -319,7 +323,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict):
def update(self, parameters: dict) -> "Module":
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.
@ -334,6 +338,8 @@ class Module(dict):
Args:
parameters (dict): A complete or partial dictionary of the modules
parameters.
Returns:
The module instance after updating the parameters.
"""
def apply(dst, parameters):
@ -360,12 +366,13 @@ class Module(dict):
apply(current_value, new_value)
apply(self, parameters)
return self
def apply(
self,
map_fn: Callable[[mx.array], mx.array],
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
):
) -> "Module":
"""Map all the parameters using the provided ``map_fn`` and immediately
update the module with the mapped parameters.
@ -376,11 +383,15 @@ class Module(dict):
map_fn (Callable): Maps an array to another array
filter_fn (Callable, optional): Filter to select which arrays to
map (default: :meth:`Module.valid_parameter_filter`).
Returns:
The module instance after updating the parameters.
"""
filter_fn = filter_fn or Module.valid_parameter_filter
self.update(self.filter_and_map(filter_fn, map_fn))
return self
def update_modules(self, modules: dict):
def update_modules(self, modules: dict) -> "Module":
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.
@ -395,6 +406,8 @@ class Module(dict):
Args:
modules (dict): A complete or partial dictionary of the modules
submodules.
Returns:
The module instance after updating the submodules.
"""
def apply(dst, modules):
@ -417,13 +430,19 @@ class Module(dict):
apply(current_value, new_value)
apply(self, modules)
return self
def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
def apply_to_modules(
self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
) -> "Module":
"""Apply a function to all the modules in this instance (including this
instance).
Args:
apply_fn (Callable): The function to apply to the modules.
Returns:
The module instance after updating submodules.
"""
module_stack = [("", self)]
while module_stack:
@ -433,6 +452,7 @@ class Module(dict):
module_stack.extend(
tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
)
return self
def modules(self):
"""Return a list with all the modules in this instance.
@ -469,7 +489,7 @@ class Module(dict):
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
):
) -> "Module":
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it.
@ -493,6 +513,9 @@ class Module(dict):
``module.freeze(keys="bias")``.
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``.
Returns:
The module instance after freezing the parameters.
"""
def _freeze_impl(_, m):
@ -513,6 +536,7 @@ class Module(dict):
self.apply_to_modules(_freeze_impl)
else:
_freeze_impl("", self)
return self
def unfreeze(
self,
@ -520,7 +544,7 @@ class Module(dict):
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
):
) -> "Module":
"""Unfreeze the Module's parameters or some of them.
This function is idempotent ie unfreezing a model that is not frozen is
@ -545,6 +569,9 @@ class Module(dict):
``module.unfreeze(keys="bias")``.
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``.
Returns:
The module instance after unfreezing the parameters.
"""
def _unfreeze_impl(_, m):
@ -559,8 +586,9 @@ class Module(dict):
self.apply_to_modules(_unfreeze_impl)
else:
_unfreeze_impl("", self)
return self
def train(self, mode: bool = True):
def train(self, mode: bool = True) -> "Module":
"""Set the model in or out of training mode.
Training mode only applies to certain layers. For example
@ -570,19 +598,22 @@ class Module(dict):
Args:
mode (bool): Indicate if the model should be in training or
evaluation mode. Default: ``True``.
Returns:
The module instance after updating the training mode.
"""
def _set_train(_, m):
m._training = mode
self.apply_to_modules(_set_train)
return self
def eval(self):
def eval(self) -> "Module":
"""Set the model to evaluation mode.
See :func:`train`.
"""
self.train(False)
return self.train(False)
def set_dtype(
self,

View File

@ -162,6 +162,16 @@ class TestBase(mlx_tests.MLXTestCase):
m.state["hello"] = "world"
self.assertEqual(m.state["hello"], "world")
def test_chaining(self):
m = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
pre_freeze_num_params = len(m.parameters())
m.freeze().unfreeze()
self.assertEqual(len(m.parameters()), pre_freeze_num_params)
params_dict = m.parameters()
self.assertFalse(m.update(params_dict).eval()._training)
self.assertTrue(m.train()._training)
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):