mirror of https://github.com/ml-explore/mlx.git
Module checks the weight on load_weights (#337)

* update module to check weights on load, also fix docs and reorganize tests
* nits + rebase
* a few more docs updates for Module
* use manual module file
* comment
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import textwrap
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
 from mlx.utils import tree_flatten, tree_unflatten
@@ -61,6 +61,7 @@ class Module(dict):
 
     @property
     def training(self):
+        """Boolean indicating if the model is in training mode."""
        return self._training
 
     def _extra_repr(self):
@@ -87,15 +88,83 @@ class Module(dict):
     def __setattr__(self, key: str, val: Any):
         self[key] = val
 
-    def load_weights(self, file: str):
+    def load_weights(
+        self,
+        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
+        strict: bool = True,
+    ):
         """
-        Load and update the model's weights from a `.npz` file.
+        Update the model's weights from a ``.npz`` file or a list.
+
+        Args:
+            file_or_weights (str or list(tuple(str, mx.array))): The path to
+                the weights ``.npz`` file or a list of pairs of parameter names
+                and arrays.
+            strict (bool, optional): If ``True`` then checks that the provided
+                weights exactly match the parameters of the model. Otherwise,
+                only the weights actually contained in the model are loaded and
+                shapes are not checked. Default: ``True``.
+
+        Example:
+
+            .. code-block:: python
+
+                import mlx.core as mx
+                import mlx.nn as nn
+                model = nn.Linear(10, 10)
+
+                # Load from file
+                model.load_weights("weights.npz")
+
+                # Load from list
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                    ("bias", mx.zeros((10,))),
+                ]
+                model.load_weights(weights)
+
+                # Missing weight
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                ]
+
+                # Raises a ValueError exception
+                model.load_weights(weights)
+
+                # Ok, only updates the weight but not the bias
+                model.load_weights(weights, strict=False)
         """
-        self.update(tree_unflatten(list(mx.load(file).items())))
+        weights = file_or_weights
+        if isinstance(weights, str):
+            weights = list(mx.load(weights).items())
+
+        if strict:
+            new_weights = dict(weights)
+            curr_weights = dict(tree_flatten(self.parameters()))
+            if extras := (new_weights.keys() - curr_weights.keys()):
+                extras = " ".join(extras)
+                raise ValueError(f"Received parameters not in model: {extras}.")
+            if missing := (curr_weights.keys() - new_weights.keys()):
+                missing = " ".join(missing)
+                raise ValueError(f"Missing parameters: {missing}.")
+            for k, v in curr_weights.items():
+                v_new = new_weights[k]
+                if not isinstance(v_new, mx.array):
+                    raise ValueError(
+                        "Expected mx.array but received "
+                        f"{type(v_new)} for parameter {k}"
+                    )
+                if v_new.shape != v.shape:
+                    raise ValueError(
+                        f"Expected shape {v.shape} but received "
+                        f"shape {v_new.shape} for parameter {k}"
+                    )
+
+        self.update(tree_unflatten(weights))
 
     def save_weights(self, file: str):
         """
-        Save the model's weights to a `.npz` file.
+        Save the model's weights to a ``.npz`` file.
         """
         mx.savez(file, **dict(tree_flatten(self.parameters())))
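To make the new ``strict`` checking concrete, here is a minimal sketch of the shape validation, the one branch the docstring example does not exercise (the ``bad`` weights list is hypothetical and the layer size illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(10, 10)

    # A wrong shape for "weight": with strict=True (the default) the new
    # shape comparison raises a ValueError before any update happens.
    bad = [("weight", mx.zeros((5, 5))), ("bias", mx.zeros((10,)))]
    try:
        model.load_weights(bad)
    except ValueError as e:
        print(e)  # e.g. "Expected shape ... but received shape ... for parameter weight"

    # With strict=False shapes are not checked and the list is applied as-is.
    model.load_weights(bad, strict=False)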
@@ -351,23 +420,26 @@ class Module(dict):
         """Freeze the Module's parameters or some of them. Freezing a parameter means not
         computing gradients for it.
 
-        This function is idempotent ie freezing a frozen model is a noop.
+        This function is idempotent i.e. freezing a frozen model is a no-op.
 
-        For instance to only train the attention parameters from a transformer:
-
-            model = ...
-            model.freeze()
-            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
+        Example:
+            For instance to only train the attention parameters from a Transformer:
+
+            .. code-block:: python
+
+                model = nn.Transformer()
+                model.freeze()
+                model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
 
         Args:
             recurse (bool, optional): If True then freeze the parameters of the
-                submodules as well (default: True).
+                submodules as well. Default: ``True``.
             keys (str or list[str], optional): If provided then only these
                 parameters will be frozen otherwise all the parameters of a
                 module. For instance freeze all biases by calling
                 ``module.freeze(keys="bias")``.
-            strict (bool, optional): If set to True validate that the passed keys exist
-                (default: False).
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
         """
 
         def _freeze_impl(_, m):
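The ``keys``/``strict`` combination documented above composes as in the following sketch (the layer sizes and the bogus key name are illustrative):

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

    # Freeze only the biases; the weights keep receiving gradients.
    model.freeze(keys="bias")

    # With strict=True a key that matches no parameter is an error rather
    # than a silent no-op.
    model.freeze(keys="gamma", strict=True)  # raises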
@@ -401,21 +473,25 @@ class Module(dict):
         This function is idempotent i.e. unfreezing a model that is not frozen is
         a no-op.
 
-        For instance to only train the biases one can do:
-
-            model = ...
-            model.freeze()
-            model.unfreeze(keys="bias")
+        Example:
+            For instance to only train the biases of a Transformer one can do:
+
+            .. code-block:: python
+
+                model = nn.Transformer()
+                model.freeze()
+                model.unfreeze(keys="bias")
 
         Args:
             recurse (bool, optional): If True then unfreeze the parameters of the
-                submodules as well (default: True).
+                submodules as well. Default: ``True``.
             keys (str or list[str], optional): If provided then only these
                 parameters will be unfrozen otherwise all the parameters of a
                 module. For instance unfreeze all biases by calling
                 ``module.unfreeze(keys="bias")``.
-            strict (bool, optional): If set to True validate that the passed keys exist
-                (default: False).
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
         """
 
         def _unfreeze_impl(_, m):
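Together with ``freeze`` this gives the usual bias-only fine-tuning recipe; a sketch using the existing ``trainable_parameters`` accessor (the layer size is illustrative):

    import mlx.nn as nn
    from mlx.utils import tree_flatten

    model = nn.Linear(8, 8)
    model.freeze()
    model.unfreeze(keys="bias")

    # Only the bias should remain trainable after the freeze/unfreeze pair.
    print([k for k, _ in tree_flatten(model.trainable_parameters())])  # ['bias']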
@@ -432,10 +508,25 @@ class Module(dict):
         _unfreeze_impl("", self)
 
     def train(self, mode: bool = True):
+        """Set the model in or out of training mode.
+
+        Training mode only applies to certain layers. For example
+        :obj:`Dropout` applies a random mask in training mode, but is the
+        identity in evaluation mode.
+
+        Args:
+            mode (bool): Indicate if the model should be in training or
+                evaluation mode. Default: ``True``.
+        """
+
         def _set_train(_, m):
             m._training = mode
 
         self.apply_to_modules(_set_train)
 
     def eval(self):
+        """Set the model to evaluation mode.
+
+        See :func:`train`.
+        """
         self.train(False)
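A sketch of the mode switch the new docstrings describe, using :obj:`Dropout` as the example layer (the dropout probability is illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Dropout(p=0.5)
    x = mx.ones((4,))

    print(layer.training)  # True: modules start in training mode
    y = layer(x)           # randomly masked (and rescaled) in training mode

    layer.eval()           # sets _training to False on every submodule
    print(layer(x))        # identity: returns x unchanged in evaluation mode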