Mirror of https://github.com/ml-explore/mlx.git
* add chaining support for some of the functionalities of "nn.Module"
* reformat
* change the return types
* remove return types
* add return type with forward referencing
* add tests for chaining
* add name to contributors
* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring
* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent f30b659291
commit d611251502
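The net effect of this change is that several `nn.Module` methods now return the module itself, so calls can be chained. A minimal usage sketch, assuming `mlx` is installed (the toy model below is hypothetical and only for illustration):

    import mlx.nn as nn

    # Hypothetical toy model used only to illustrate chaining.
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

    # freeze(), unfreeze(), and eval() now return the module, so the calls
    # compose into a single expression.
    model.freeze().unfreeze(keys="bias").eval()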
@@ -15,6 +15,7 @@ MLX was developed with contributions from the following individuals:
 - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
 - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
 - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
+- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`.
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 </a>
@@ -151,7 +151,7 @@ class Module(dict):
         self,
         file_or_weights: Union[str, List[Tuple[str, mx.array]]],
         strict: bool = True,
-    ):
+    ) -> "Module":
         """
         Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.

@@ -164,6 +164,9 @@ class Module(dict):
                 only the weights actually contained in the model are loaded and
                 shapes are not checked. Default: ``True``.

+        Returns:
+            The module instance after updating the weights.
+
         Example:

             .. code-block:: python
@@ -223,6 +226,7 @@ class Module(dict):
                 )

         self.update(tree_unflatten(weights))
+        return self

     def save_weights(self, file: str):
         """
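A hedged sketch of how the chained `load_weights` can be used now that it returns the module; the file name below is a placeholder, not a file from this change:

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
    # load_weights() returns the module, so loading weights and switching to
    # evaluation mode can be written as one chained call.
    model = model.load_weights("model.safetensors").eval()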
@@ -319,7 +323,7 @@ class Module(dict):

         return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

-    def update(self, parameters: dict):
+    def update(self, parameters: dict) -> "Module":
         """Replace the parameters of this Module with the provided ones in the
         dict of dicts and lists.

@@ -334,6 +338,8 @@ class Module(dict):
         Args:
             parameters (dict): A complete or partial dictionary of the modules
                 parameters.
+        Returns:
+            The module instance after updating the parameters.
         """

         def apply(dst, parameters):
@@ -360,12 +366,13 @@ class Module(dict):
                 apply(current_value, new_value)

         apply(self, parameters)
+        return self

     def apply(
         self,
         map_fn: Callable[[mx.array], mx.array],
         filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
-    ):
+    ) -> "Module":
         """Map all the parameters using the provided ``map_fn`` and immediately
         update the module with the mapped parameters.

@@ -376,11 +383,15 @@ class Module(dict):
             map_fn (Callable): Maps an array to another array
             filter_fn (Callable, optional): Filter to select which arrays to
                 map (default: :meth:`Module.valid_parameter_filter`).
+
+        Returns:
+            The module instance after updating the parameters.
         """
         filter_fn = filter_fn or Module.valid_parameter_filter
         self.update(self.filter_and_map(filter_fn, map_fn))
+        return self

-    def update_modules(self, modules: dict):
+    def update_modules(self, modules: dict) -> "Module":
         """Replace the child modules of this :class:`Module` instance with the
         provided ones in the dict of dicts and lists.

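Since `update` and `apply` now also return the module, a parameter update and a dtype cast compose into one statement. A small sketch, assuming half precision is acceptable for the layer:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 4)
    params = model.parameters()

    # update() merges the (possibly partial) parameter dict and returns the
    # module; apply() then maps every parameter, here casting to float16.
    model.update(params).apply(lambda x: x.astype(mx.float16))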
@@ -395,6 +406,8 @@ class Module(dict):
         Args:
             modules (dict): A complete or partial dictionary of the modules
                 submodules.
+        Returns:
+            The module instance after updating the submodules.
         """

         def apply(dst, modules):
@@ -417,13 +430,19 @@ class Module(dict):
                 apply(current_value, new_value)

         apply(self, modules)
+        return self

-    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
+    def apply_to_modules(
+        self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
+    ) -> "Module":
         """Apply a function to all the modules in this instance (including this
         instance).

         Args:
             apply_fn (Callable): The function to apply to the modules.
+
+        Returns:
+            The module instance after updating submodules.
         """
         module_stack = [("", self)]
         while module_stack:
@@ -433,6 +452,7 @@ class Module(dict):
             module_stack.extend(
                 tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
             )
+        return self

     def modules(self):
         """Return a list with all the modules in this instance.
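A sketch of chaining `apply_to_modules`, which now returns the module as well; the printing callback is only illustrative:

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))
    # apply_to_modules() visits the root module and every submodule and now
    # returns the module, so it chains with eval().
    model.apply_to_modules(lambda name, m: print(name, type(m).__name__)).eval()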
@@ -469,7 +489,7 @@ class Module(dict):
         recurse: bool = True,
         keys: Optional[Union[str, List[str]]] = None,
         strict: bool = False,
-    ):
+    ) -> "Module":
         """Freeze the Module's parameters or some of them. Freezing a parameter means not
         computing gradients for it.

@@ -493,6 +513,9 @@ class Module(dict):
                 ``module.freeze(keys="bias")``.
             strict (bool, optional): If set to ``True`` validate that the passed keys exist.
                 Default: ``False``.
+
+        Returns:
+            The module instance after freezing the parameters.
         """

         def _freeze_impl(_, m):
@@ -513,6 +536,7 @@ class Module(dict):
             self.apply_to_modules(_freeze_impl)
         else:
             _freeze_impl("", self)
+        return self

     def unfreeze(
         self,
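With `freeze` returning the module, a common fine-tuning setup becomes a one-liner; a minimal sketch:

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    # Freeze everything, then selectively unfreeze only the bias parameters,
    # mirroring the keys="bias" example from the docstring above.
    model.freeze().unfreeze(keys="bias")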
@@ -520,7 +544,7 @@ class Module(dict):
         recurse: bool = True,
         keys: Optional[Union[str, List[str]]] = None,
         strict: bool = False,
-    ):
+    ) -> "Module":
         """Unfreeze the Module's parameters or some of them.

         This function is idempotent ie unfreezing a model that is not frozen is
@@ -545,6 +569,9 @@ class Module(dict):
                 ``module.unfreeze(keys="bias")``.
             strict (bool, optional): If set to ``True`` validate that the passed keys exist.
                 Default: ``False``.
+
+        Returns:
+            The module instance after unfreezing the parameters.
         """

         def _unfreeze_impl(_, m):
@@ -559,8 +586,9 @@ class Module(dict):
             self.apply_to_modules(_unfreeze_impl)
         else:
             _unfreeze_impl("", self)
+        return self

-    def train(self, mode: bool = True):
+    def train(self, mode: bool = True) -> "Module":
         """Set the model in or out of training mode.

         Training mode only applies to certain layers. For example
@@ -570,19 +598,22 @@ class Module(dict):
         Args:
             mode (bool): Indicate if the model should be in training or
                 evaluation mode. Default: ``True``.
+        Returns:
+            The module instance after updating the training mode.
         """

         def _set_train(_, m):
             m._training = mode

         self.apply_to_modules(_set_train)
+        return self

-    def eval(self):
+    def eval(self) -> "Module":
         """Set the model to evaluation mode.

         See :func:`train`.
         """
-        self.train(False)
+        return self.train(False)

     def set_dtype(
         self,
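`train` and `eval` return the module too; a small sketch using the same private `_training` flag that the new test below checks:

    import mlx.nn as nn

    model = nn.Linear(2, 2)
    # eval() is equivalent to train(False) and now returns the module.
    assert model.eval()._training is False
    assert model.train()._training is True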
@@ -162,6 +162,16 @@ class TestBase(mlx_tests.MLXTestCase):
         m.state["hello"] = "world"
         self.assertEqual(m.state["hello"], "world")

+    def test_chaining(self):
+        m = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
+        pre_freeze_num_params = len(m.parameters())
+        m.freeze().unfreeze()
+        self.assertEqual(len(m.parameters()), pre_freeze_num_params)
+        params_dict = m.parameters()
+
+        self.assertFalse(m.update(params_dict).eval()._training)
+        self.assertTrue(m.train()._training)
+

 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):