mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
* add chaining support for some of the functionalities of "nn.Module"
* reformat
* change the return types
* remove return types
* add return type with forward referencing
* add tests for chaining
* add name to contributors
* Update python/mlx/nn/layers/base.py

  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/nn/layers/base.py

  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* update docstring
* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
f30b659291
commit
d611251502
@ -15,6 +15,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||||
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
||||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||||
|
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`.
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
</a>
|
</a>
|
||||||
|
@ -151,7 +151,7 @@ class Module(dict):
|
|||||||
self,
|
self,
|
||||||
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
|
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
|
||||||
strict: bool = True,
|
strict: bool = True,
|
||||||
):
|
) -> "Module":
|
||||||
"""
|
"""
|
||||||
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
|
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
|
||||||
|
|
||||||
@ -164,6 +164,9 @@ class Module(dict):
|
|||||||
only the weights actually contained in the model are loaded and
|
only the weights actually contained in the model are loaded and
|
||||||
shapes are not checked. Default: ``True``.
|
shapes are not checked. Default: ``True``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The module instance after updating the weights.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -223,6 +226,7 @@ class Module(dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.update(tree_unflatten(weights))
|
self.update(tree_unflatten(weights))
|
||||||
|
return self
|
||||||
|
|
||||||
def save_weights(self, file: str):
|
def save_weights(self, file: str):
|
||||||
"""
|
"""
|
||||||
@ -319,7 +323,7 @@ class Module(dict):
|
|||||||
|
|
||||||
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
|
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
|
||||||
|
|
||||||
def update(self, parameters: dict):
|
def update(self, parameters: dict) -> "Module":
|
||||||
"""Replace the parameters of this Module with the provided ones in the
|
"""Replace the parameters of this Module with the provided ones in the
|
||||||
dict of dicts and lists.
|
dict of dicts and lists.
|
||||||
|
|
||||||
@ -334,6 +338,8 @@ class Module(dict):
|
|||||||
Args:
|
Args:
|
||||||
parameters (dict): A complete or partial dictionary of the modules
|
parameters (dict): A complete or partial dictionary of the modules
|
||||||
parameters.
|
parameters.
|
||||||
|
Returns:
|
||||||
|
The module instance after updating the parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def apply(dst, parameters):
|
def apply(dst, parameters):
|
||||||
@ -360,12 +366,13 @@ class Module(dict):
|
|||||||
apply(current_value, new_value)
|
apply(current_value, new_value)
|
||||||
|
|
||||||
apply(self, parameters)
|
apply(self, parameters)
|
||||||
|
return self
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
map_fn: Callable[[mx.array], mx.array],
|
map_fn: Callable[[mx.array], mx.array],
|
||||||
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
|
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
|
||||||
):
|
) -> "Module":
|
||||||
"""Map all the parameters using the provided ``map_fn`` and immediately
|
"""Map all the parameters using the provided ``map_fn`` and immediately
|
||||||
update the module with the mapped parameters.
|
update the module with the mapped parameters.
|
||||||
|
|
||||||
@ -376,11 +383,15 @@ class Module(dict):
|
|||||||
map_fn (Callable): Maps an array to another array
|
map_fn (Callable): Maps an array to another array
|
||||||
filter_fn (Callable, optional): Filter to select which arrays to
|
filter_fn (Callable, optional): Filter to select which arrays to
|
||||||
map (default: :meth:`Module.valid_parameter_filter`).
|
map (default: :meth:`Module.valid_parameter_filter`).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The module instance after updating the parameters.
|
||||||
"""
|
"""
|
||||||
filter_fn = filter_fn or Module.valid_parameter_filter
|
filter_fn = filter_fn or Module.valid_parameter_filter
|
||||||
self.update(self.filter_and_map(filter_fn, map_fn))
|
self.update(self.filter_and_map(filter_fn, map_fn))
|
||||||
|
return self
|
||||||
|
|
||||||
def update_modules(self, modules: dict):
|
def update_modules(self, modules: dict) -> "Module":
|
||||||
"""Replace the child modules of this :class:`Module` instance with the
|
"""Replace the child modules of this :class:`Module` instance with the
|
||||||
provided ones in the dict of dicts and lists.
|
provided ones in the dict of dicts and lists.
|
||||||
|
|
||||||
@ -395,6 +406,8 @@ class Module(dict):
|
|||||||
Args:
|
Args:
|
||||||
modules (dict): A complete or partial dictionary of the modules
|
modules (dict): A complete or partial dictionary of the modules
|
||||||
submodules.
|
submodules.
|
||||||
|
Returns:
|
||||||
|
The module instance after updating the submodules.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def apply(dst, modules):
|
def apply(dst, modules):
|
||||||
@ -417,13 +430,19 @@ class Module(dict):
|
|||||||
apply(current_value, new_value)
|
apply(current_value, new_value)
|
||||||
|
|
||||||
apply(self, modules)
|
apply(self, modules)
|
||||||
|
return self
|
||||||
|
|
||||||
def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
|
def apply_to_modules(
|
||||||
|
self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
|
||||||
|
) -> "Module":
|
||||||
"""Apply a function to all the modules in this instance (including this
|
"""Apply a function to all the modules in this instance (including this
|
||||||
instance).
|
instance).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
apply_fn (Callable): The function to apply to the modules.
|
apply_fn (Callable): The function to apply to the modules.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The module instance after updating submodules.
|
||||||
"""
|
"""
|
||||||
module_stack = [("", self)]
|
module_stack = [("", self)]
|
||||||
while module_stack:
|
while module_stack:
|
||||||
@ -433,6 +452,7 @@ class Module(dict):
|
|||||||
module_stack.extend(
|
module_stack.extend(
|
||||||
tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
|
tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
|
||||||
)
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
def modules(self):
|
def modules(self):
|
||||||
"""Return a list with all the modules in this instance.
|
"""Return a list with all the modules in this instance.
|
||||||
@ -469,7 +489,7 @@ class Module(dict):
|
|||||||
recurse: bool = True,
|
recurse: bool = True,
|
||||||
keys: Optional[Union[str, List[str]]] = None,
|
keys: Optional[Union[str, List[str]]] = None,
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
):
|
) -> "Module":
|
||||||
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
|
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
|
||||||
computing gradients for it.
|
computing gradients for it.
|
||||||
|
|
||||||
@ -493,6 +513,9 @@ class Module(dict):
|
|||||||
``module.freeze(keys="bias")``.
|
``module.freeze(keys="bias")``.
|
||||||
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
|
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
|
||||||
Default: ``False``.
|
Default: ``False``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The module instance after freezing the parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _freeze_impl(_, m):
|
def _freeze_impl(_, m):
|
||||||
@ -513,6 +536,7 @@ class Module(dict):
|
|||||||
self.apply_to_modules(_freeze_impl)
|
self.apply_to_modules(_freeze_impl)
|
||||||
else:
|
else:
|
||||||
_freeze_impl("", self)
|
_freeze_impl("", self)
|
||||||
|
return self
|
||||||
|
|
||||||
def unfreeze(
|
def unfreeze(
|
||||||
self,
|
self,
|
||||||
@ -520,7 +544,7 @@ class Module(dict):
|
|||||||
recurse: bool = True,
|
recurse: bool = True,
|
||||||
keys: Optional[Union[str, List[str]]] = None,
|
keys: Optional[Union[str, List[str]]] = None,
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
):
|
) -> "Module":
|
||||||
"""Unfreeze the Module's parameters or some of them.
|
"""Unfreeze the Module's parameters or some of them.
|
||||||
|
|
||||||
This function is idempotent ie unfreezing a model that is not frozen is
|
This function is idempotent ie unfreezing a model that is not frozen is
|
||||||
@ -545,6 +569,9 @@ class Module(dict):
|
|||||||
``module.unfreeze(keys="bias")``.
|
``module.unfreeze(keys="bias")``.
|
||||||
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
|
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
|
||||||
Default: ``False``.
|
Default: ``False``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The module instance after unfreezing the parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _unfreeze_impl(_, m):
|
def _unfreeze_impl(_, m):
|
||||||
@ -559,8 +586,9 @@ class Module(dict):
|
|||||||
self.apply_to_modules(_unfreeze_impl)
|
self.apply_to_modules(_unfreeze_impl)
|
||||||
else:
|
else:
|
||||||
_unfreeze_impl("", self)
|
_unfreeze_impl("", self)
|
||||||
|
return self
|
||||||
|
|
||||||
def train(self, mode: bool = True):
|
def train(self, mode: bool = True) -> "Module":
|
||||||
"""Set the model in or out of training mode.
|
"""Set the model in or out of training mode.
|
||||||
|
|
||||||
Training mode only applies to certain layers. For example
|
Training mode only applies to certain layers. For example
|
||||||
@ -570,19 +598,22 @@ class Module(dict):
|
|||||||
Args:
|
Args:
|
||||||
mode (bool): Indicate if the model should be in training or
|
mode (bool): Indicate if the model should be in training or
|
||||||
evaluation mode. Default: ``True``.
|
evaluation mode. Default: ``True``.
|
||||||
|
Returns:
|
||||||
|
The module instance after updating the training mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _set_train(_, m):
|
def _set_train(_, m):
|
||||||
m._training = mode
|
m._training = mode
|
||||||
|
|
||||||
self.apply_to_modules(_set_train)
|
self.apply_to_modules(_set_train)
|
||||||
|
return self
|
||||||
|
|
||||||
def eval(self):
|
def eval(self) -> "Module":
|
||||||
"""Set the model to evaluation mode.
|
"""Set the model to evaluation mode.
|
||||||
|
|
||||||
See :func:`train`.
|
See :func:`train`.
|
||||||
"""
|
"""
|
||||||
self.train(False)
|
return self.train(False)
|
||||||
|
|
||||||
def set_dtype(
|
def set_dtype(
|
||||||
self,
|
self,
|
||||||
|
@ -162,6 +162,16 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
m.state["hello"] = "world"
|
m.state["hello"] = "world"
|
||||||
self.assertEqual(m.state["hello"], "world")
|
self.assertEqual(m.state["hello"], "world")
|
||||||
|
|
||||||
|
def test_chaining(self):
|
||||||
|
m = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
|
||||||
|
pre_freeze_num_params = len(m.parameters())
|
||||||
|
m.freeze().unfreeze()
|
||||||
|
self.assertEqual(len(m.parameters()), pre_freeze_num_params)
|
||||||
|
params_dict = m.parameters()
|
||||||
|
|
||||||
|
self.assertFalse(m.update(params_dict).eval()._training)
|
||||||
|
self.assertTrue(m.train()._training)
|
||||||
|
|
||||||
|
|
||||||
class TestLayers(mlx_tests.MLXTestCase):
|
class TestLayers(mlx_tests.MLXTestCase):
|
||||||
def test_identity(self):
|
def test_identity(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user