Module checks the weights on load_weights (#337)

* update module to check weights on load, also fix docs and reorganize tests

* nits + rebase

* a few more docs updates for Module

* use manual module file

* comment
Authored by Awni Hannun on 2024-01-02 18:55:42 -08:00; committed by GitHub.
parent 0782a4573a
commit dff4a3833f
6 changed files with 581 additions and 360 deletions
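For context, here is a minimal sketch of the behavior the commit message describes, based on the docstring added in the diff below; the layer and parameter names are illustrative only.

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(10, 10)

# A weight list that is missing the bias: with the strict default this
# now raises a ValueError instead of silently leaving the bias untouched.
weights = [("weight", mx.random.uniform(shape=(10, 10)))]
try:
    model.load_weights(weights)  # strict=True is the default
except ValueError as err:
    print("rejected:", err)

# strict=False skips the check and only updates the provided weight.
model.load_weights(weights, strict=False)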

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
import textwrap
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Tuple, Union
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
@@ -61,6 +61,7 @@ class Module(dict):
@property
def training(self):
"""Boolean indicating if the model is in training mode."""
return self._training
def _extra_repr(self):
@@ -87,15 +88,83 @@ class Module(dict):
def __setattr__(self, key: str, val: Any):
self[key] = val
def load_weights(self, file: str):
def load_weights(
self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
strict: bool = True,
):
"""
Load and update the model's weights from a `.npz` file.
Update the model's weights from a ``.npz`` file or a list.
Args:
file_or_weights (str or list(tuple(str, mx.array))): The path to
the weights ``.npz`` file or a list of pairs of parameter names
and arrays.
strict (bool, optional): If ``True`` then checks that the provided
weights exactly match the parameters of the model. Otherwise,
only the weights actually contained in the model are loaded and
shapes are not checked. Default: ``True``.
Example:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
model = nn.Linear(10, 10)
# Load from file
model.load_weights("weights.npz")
# Load from list
weights = [
("weight", mx.random.uniform(shape=(10, 10))),
("bias", mx.zeros((10,))),
]
model.load_weights(weights)
# Missing weight
weights = [
("weight", mx.random.uniform(shape=(10, 10))),
]
# Raises a ValueError exception
model.load_weights(weights)
# Ok, only updates the weight but not the bias
model.load_weights(weights, strict=False)
"""
self.update(tree_unflatten(list(mx.load(file).items())))
weights = file_or_weights
if isinstance(weights, str):
weights = list(mx.load(weights).items())
if strict:
new_weights = dict(weights)
curr_weights = dict(tree_flatten(self.parameters()))
if extras := (new_weights.keys() - curr_weights.keys()):
extras = " ".join(extras)
raise ValueError(f"Received parameters not in model: {extras}.")
if missing := (curr_weights.keys() - new_weights.keys()):
missing = " ".join(missing)
raise ValueError(f"Missing parameters: {missing}.")
for k, v in curr_weights.items():
v_new = new_weights[k]
if not isinstance(v_new, mx.array):
raise ValueError(
"Expected mx.array but received "
f"{type(v_new)} for parameter {k}"
)
if v_new.shape != v.shape:
raise ValueError(
f"Expected shape {v.shape} but received "
f" shape {v_new.shape} for parameter {k}"
)
self.update(tree_unflatten(weights))
def save_weights(self, file: str):
"""
Save the model's weights to a `.npz` file.
Save the model's weights to a ``.npz`` file.
"""
mx.savez(file, **dict(tree_flatten(self.parameters())))
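As a usage note for the two methods above, a save/load round trip might look like the following sketch; the file name is arbitrary, and the printed keys are the flattened parameter names that the strict check compares.

import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Linear(4, 4)
model.save_weights("linear.npz")  # writes the flattened parameter dict

# These flattened names ("weight" and "bias" for a Linear layer) are exactly
# what the strict check in load_weights validates against.
print([k for k, _ in tree_flatten(model.parameters())])

fresh = nn.Linear(4, 4)
fresh.load_weights("linear.npz")  # names and shapes must match with strict=True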
@@ -351,23 +420,26 @@ class Module(dict):
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it.
This function is idempotent ie freezing a frozen model is a noop.
This function is idempotent i.e. freezing a frozen model is a no-op.
For instance to only train the attention parameters from a transformer:
Example:
For instance to only train the attention parameters from a Transformer:
model = ...
model.freeze()
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
.. code-block:: python
model = nn.Transformer()
model.freeze()
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
Args:
recurse (bool, optional): If True then freeze the parameters of the
submodules as well (default: True).
submodules as well. Default: ``True``.
keys (str or list[str], optional): If provided then only these
parameters will be frozen otherwise all the parameters of a
module. For instance freeze all biases by calling
``module.freeze(keys="bias")``.
strict (bool, optional): If set to True validate that the passed keys exist
(default: False).
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``.
"""
def _freeze_impl(_, m):
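A small usage sketch of the keys argument documented above, using a single Linear layer as a stand-in and trainable_parameters() to inspect what is left unfrozen:

import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Linear(10, 10)

# Freeze only the biases; the weight stays trainable.
model.freeze(keys="bias")

# Frozen parameters no longer appear among the trainable ones.
print([k for k, _ in tree_flatten(model.trainable_parameters())])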
@@ -401,21 +473,25 @@ class Module(dict):
This function is idempotent i.e. unfreezing a model that is not frozen is
a no-op.
For instance to only train the biases one can do:
Example:
model = ...
model.freeze()
model.unfreeze(keys="bias")
For instance to only train the biases of a Transformer one can do:
.. code-block:: python
model = nn.Transformer()
model.freeze()
model.unfreeze(keys="bias")
Args:
recurse (bool, optional): If True then unfreeze the parameters of the
submodules as well (default: True).
submodules as well. Default: ``True``.
keys (str or list[str], optional): If provided then only these
parameters will be unfrozen otherwise all the parameters of a
module. For instance unfreeze all biases by calling
``module.unfreeze(keys="bias")``.
strict (bool, optional): If set to True validate that the passed keys exist
(default: False).
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``.
"""
def _unfreeze_impl(_, m):
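And a sketch of the freeze-then-unfreeze pattern for fine-tuning only part of a model; nn.Sequential (which keeps its sub-modules in a layers list) and the layer sizes are assumptions chosen for illustration:

import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze everything, then unfreeze just the final projection.
model.freeze()
model.layers[-1].unfreeze()

# Only the last layer's weight and bias remain trainable.
print([k for k, _ in tree_flatten(model.trainable_parameters())])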
@@ -432,10 +508,25 @@ class Module(dict):
_unfreeze_impl("", self)
def train(self, mode: bool = True):
"""Set the model in or out of training mode.
Training mode only applies to certain layers. For example
:obj:`Dropout` applies a random mask in training mode, but is the
identity in evaluation mode.
Args:
mode (bool): Indicate if the model should be in training or
evaluation mode. Default: ``True``.
"""
def _set_train(_, m):
m._training = mode
self.apply_to_modules(_set_train)
def eval(self):
"""Set the model to evaluation mode.
See :func:`train`.
"""
self.train(False)
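Finally, since train() and eval() only toggle the training flag that layers such as Dropout consult, here is a minimal sketch of the effect; the dropout probability and input are arbitrary.

import mlx.core as mx
import mlx.nn as nn

drop = nn.Dropout(p=0.5)
x = mx.ones((4,))

drop.train()           # training mode: elements are randomly zeroed and rescaled
print(drop(x))

drop.eval()            # evaluation mode: Dropout is the identity
print(drop(x))         # equal to x
print(drop.training)   # False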