# Copyright © 2023 Apple Inc.

import textwrap
from typing import Any, Callable, List, Optional, Tuple, Union

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


class Module(dict):
    """Base class for building neural networks with MLX.

    All the layers provided in :mod:`mlx.nn.layers` subclass this class and
    your models should do the same.

    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
    instances in arbitrary nesting of python lists or dicts. The ``Module``
    then allows recursively extracting all the :class:`mlx.core.array` instances
    using :meth:`mlx.nn.Module.parameters`.

    In addition, the ``Module`` has the concept of trainable and non-trainable
    parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
    the gradients are returned only with respect to the trainable parameters.
    All arrays in a module are trainable unless they are added to the "frozen"
    set by calling :meth:`freeze`.

    .. code-block:: python

        import mlx.core as mx
        import mlx.nn as nn

        class MyMLP(nn.Module):
            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
                super().__init__()

                self.in_proj = nn.Linear(in_dims, hidden_dims)
                self.out_proj = nn.Linear(hidden_dims, out_dims)

            def __call__(self, x):
                x = self.in_proj(x)
                x = mx.maximum(x, 0)
                return self.out_proj(x)

        model = MyMLP(2, 1)

        # All the model parameters are created but since MLX is lazy by
        # default, they are not evaluated yet. Calling `mx.eval` actually
        # allocates memory and initializes the parameters.
        mx.eval(model.parameters())

        # Setting a parameter to a new value is as simple as accessing that
        # parameter and assigning a new array to it.
        model.in_proj.weight = model.in_proj.weight * 2
        mx.eval(model.parameters())
    """

    def __init__(self):
        """Should be called by the subclasses of ``Module``."""
        self._no_grad = set()
        self._training = True

    @property
    def training(self):
        """Boolean indicating if the model is in training mode."""
        return self._training

    def _extra_repr(self):
        return ""

    def __repr__(self):
        children = tree_flatten(self.children(), is_leaf=self.is_module)
        value = f"{type(self).__name__}({self._extra_repr()}"
        for k, v in children:
            value += "\n"
            value += textwrap.indent(f"({k}): {repr(v)}", prefix="  ")
        if children:
            value += "\n"
        value += ")"

        return value

    def __getattr__(self, key: str):
        if key in self:
            return self[key]
        else:
            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")

    def __setattr__(self, key: str, val: Any):
        self[key] = val

    def load_weights(
        self,
        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
        strict: bool = True,
    ):
        """
        Update the model's weights from a ``.npz`` file or a list.

        Args:
            file_or_weights (str or list(tuple(str, mx.array))): The path to
                the weights ``.npz`` file or a list of pairs of parameter names
                and arrays.
            strict (bool, optional): If ``True`` then checks that the provided
                weights exactly match the parameters of the model. Otherwise,
                only the weights actually contained in the model are loaded and
                shapes are not checked. Default: ``True``.

        Example:

            .. code-block:: python

                import mlx.core as mx
                import mlx.nn as nn
                model = nn.Linear(10, 10)

                # Load from file
                model.load_weights("weights.npz")

                # Load from list
                weights = [
                    ("weight", mx.random.uniform(shape=(10, 10))),
                    ("bias", mx.zeros((10,))),
                ]
                model.load_weights(weights)

                # Missing weight
                weights = [
                    ("weight", mx.random.uniform(shape=(10, 10))),
                ]

                # Raises a ValueError exception
                model.load_weights(weights)

                # Ok, only updates the weight but not the bias
                model.load_weights(weights, strict=False)
        """
        weights = file_or_weights
        if isinstance(weights, str):
            weights = list(mx.load(weights).items())

        if strict:
            new_weights = dict(weights)
            curr_weights = dict(tree_flatten(self.parameters()))
            if extras := (new_weights.keys() - curr_weights.keys()):
                extras = " ".join(extras)
                raise ValueError(f"Received parameters not in model: {extras}.")
            if missing := (curr_weights.keys() - new_weights.keys()):
                missing = " ".join(missing)
                raise ValueError(f"Missing parameters: {missing}.")
            for k, v in curr_weights.items():
                v_new = new_weights[k]
                if not isinstance(v_new, mx.array):
                    raise ValueError(
                        "Expected mx.array but received "
                        f"{type(v_new)} for parameter {k}"
                    )
                if v_new.shape != v.shape:
                    raise ValueError(
                        f"Expected shape {v.shape} but received "
                        f"shape {v_new.shape} for parameter {k}"
                    )

        self.update(tree_unflatten(weights))

    def save_weights(self, file: str):
        """
        Save the model's weights to a ``.npz`` file.
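
        Example (``model`` being any ``Module``):

            .. code-block:: python

                model.save_weights("weights.npz")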
        """
        mx.savez(file, **dict(tree_flatten(self.parameters())))

    @staticmethod
    def is_module(value):
        return isinstance(value, Module)

    @staticmethod
    def valid_child_filter(module, key, value):
        return isinstance(value, (dict, list))

    @staticmethod
    def valid_parameter_filter(module, key, value):
        return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")

    @staticmethod
    def trainable_parameter_filter(module, key, value):
        return (
            Module.valid_parameter_filter(module, key, value)
            and key not in module._no_grad
        )

    def filter_and_map(
        self,
        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
        map_fn: Optional[Callable] = None,
        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Recursively filter the contents of the module using ``filter_fn``,
        namely only select keys and values where ``filter_fn`` returns ``True``.

        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
        but it can also be used to extract any subset of the module's parameters.

        Args:
            filter_fn (Callable): Given a value, the key in which it is found
                and the containing module, decide whether to keep the value or
                drop it.
            map_fn (Callable, optional): Optionally transform the value before
                returning it.
            is_leaf_fn (Callable, optional): Given a value, the key in which it
                is found and the containing module, decide if it is a leaf.

        Returns:
            A dictionary containing the contents of the module recursively filtered.
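
        Example:

            A minimal sketch that maps every parameter to its shape
            (``model`` being any ``Module``):

            .. code-block:: python

                shapes = model.filter_and_map(
                    Module.valid_parameter_filter,
                    map_fn=lambda x: x.shape,
                )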
        """

        map_fn = map_fn or (lambda x: x)
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

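        # Recursively unwrap one value: ``vk`` is the dot-separated key path
        # accumulated so far and ``v`` is the value found at that path.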
        def unwrap(vk, v):
            if is_leaf_fn(self, vk, v):
                return map_fn(v)

            if isinstance(v, Module):
                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

            if isinstance(v, dict):
                nd = {}
                for k, v in v.items():
                    tk = f"{vk}.{k}"
                    nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
                return nd

            if isinstance(v, list):
                nl = []
                for i, vi in enumerate(v):
                    tk = f"{vk}.{i}"
                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
                return nl

            raise RuntimeError("Unexpected leaf found while traversing the module")

        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}

    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this
        Module as a dict of dicts and lists."""
        return self.filter_and_map(self.valid_parameter_filter)

    def trainable_parameters(self):
        """Recursively return all the non-frozen :class:`mlx.core.array` members
        of this Module as a dict of dicts and lists."""
        return self.filter_and_map(self.trainable_parameter_filter)

    def children(self):
        """Return the direct descendants of this Module instance."""
        return self.filter_and_map(
            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
        )

    def leaf_modules(self):
        """Return the submodules that do not contain other modules."""

        def _is_leaf_module(m, k, v):
            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0

        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

    def update(self, parameters: dict):
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.

        Commonly used by the optimizer to change the model to the updated
        (optimized) parameters. Also used by :func:`mlx.nn.value_and_grad` to
        set the tracers in the model in order to compute gradients.

        The passed-in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            parameters (dict): A complete or partial dictionary of the module's
                parameters.
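
        Example:

            A minimal sketch of a partial update (``model`` being the ``MyMLP``
            from the class docstring): only the first projection's weight is
            replaced, everything else is left untouched.

            .. code-block:: python

                model.update(
                    {"in_proj": {"weight": model.in_proj.weight * 2}}
                )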
        """

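        # Walk ``parameters`` and ``dst`` (the module itself, or a nested
        # dict/list inside it) in lock step: assign arrays in place and
        # recurse into submodules and containers.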
        def apply(dst, parameters):
            if isinstance(parameters, dict):
                for k in parameters:
                    if k in dst:
                        current_value = dst[k]
                        new_value = parameters[k]
                        if isinstance(current_value, mx.array):
                            dst[k] = new_value
                        elif isinstance(current_value, Module):
                            current_value.update(new_value)
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(parameters, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = parameters[i]
                    if isinstance(current_value, mx.array):
                        dst[i] = new_value
                    elif isinstance(current_value, Module):
                        current_value.update(new_value)
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, parameters)

    def apply(
        self,
        map_fn: Callable[[mx.array], mx.array],
        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Map all the parameters using the provided ``map_fn`` and immediately
        update the module with the mapped parameters.

        For instance running ``model.apply(lambda x: x.astype(mx.float16))``
        casts all parameters to 16-bit floats.

        Args:
            map_fn (Callable): Maps an array to another array.
            filter_fn (Callable, optional): Filter to select which arrays to
                map. Default: :meth:`Module.valid_parameter_filter`.
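
        Example:

            A minimal sketch (``model`` being any ``Module``): cast only the
            trainable parameters to 16-bit floats.

            .. code-block:: python

                model.apply(
                    lambda x: x.astype(mx.float16),
                    filter_fn=Module.trainable_parameter_filter,
                )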
        """
        filter_fn = filter_fn or Module.valid_parameter_filter
        self.update(self.filter_and_map(filter_fn, map_fn))

    def update_modules(self, modules: dict):
        """Replace the child modules of this :class:`Module` instance with the
        provided ones in the dict of dicts and lists.

        It is the equivalent of :meth:`Module.update` but for modules instead
        of parameters and allows us to flexibly edit complex architectures by
        programmatically swapping layers.

        The passed-in modules dictionary need not be a full dictionary
        similar to :meth:`children`. Only the provided locations will be
        updated.

        Args:
            modules (dict): A complete or partial dictionary of the module's
                submodules.
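
        Example:

            A minimal sketch (``model`` being the ``MyMLP`` from the class
            docstring): swap the first projection for a bias-free one of the
            same shape.

            .. code-block:: python

                model.update_modules(
                    {"in_proj": nn.Linear(2, 16, bias=False)}
                )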
        """

        def apply(dst, modules):
            if isinstance(modules, dict):
                for k in modules:
                    if k in dst:
                        current_value = dst[k]
                        new_value = modules[k]
                        if self.is_module(current_value) and self.is_module(new_value):
                            dst[k] = new_value
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(modules, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = modules[i]
                    if self.is_module(current_value) and self.is_module(new_value):
                        dst[i] = new_value
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, modules)

    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
        """Apply a function to all the modules in this instance (including this
        instance).

        Args:
            apply_fn (Callable): The function to apply to the modules.
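
        Example:

            A minimal sketch (``model`` being any ``Module``): print the dotted
            name and type of every module in the instance.

            .. code-block:: python

                model.apply_to_modules(
                    lambda k, m: print(k, type(m).__name__)
                )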
        """
        module_stack = [("", self)]
        while module_stack:
            prefix, mod = module_stack.pop()
            apply_fn(prefix, mod)
            prefix = "." + prefix if prefix else ""
            module_stack.extend(
                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
            )

    def modules(self):
        """Return a list with all the modules in this instance.

        Returns:
            A list of :class:`mlx.nn.Module` instances.
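
        Example (``model`` being any ``Module``):

            .. code-block:: python

                # The first entry is the module itself
                all_modules = model.modules()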
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append(m))
        return modulelist

    def named_modules(self):
        """Return a list with all the modules in this instance and their name
        with dot notation.

        Returns:
            A list of tuples (str, :class:`mlx.nn.Module`).
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
        return modulelist

    def _validate_keys(self, keys, strict):
        keys = keys if isinstance(keys, list) else [keys]
        if strict:
            for k in keys:
                if k not in self:
                    raise KeyError(f"Module doesn't contain member {k}.")
        return keys

    def freeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Freeze the Module's parameters or some of them. Freezing a parameter
        means not computing gradients for it.

        This function is idempotent, i.e. freezing a frozen model is a no-op.

        Example:
            For instance to only train the attention parameters from a Transformer:

            .. code-block:: python

                model = nn.Transformer()
                model.freeze()
                model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

        Args:
            recurse (bool, optional): If ``True`` then freeze the parameters of
                the submodules as well. Default: ``True``.
            keys (str or list[str], optional): If provided then only these
                parameters will be frozen, otherwise all the parameters of a
                module. For instance freeze all biases by calling
                ``module.freeze(keys="bias")``.
            strict (bool, optional): If set to ``True`` validate that the passed
                keys exist. Default: ``False``.
        """

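        # When no explicit keys are given, freeze every parameter of the
        # visited module that is not inside a sub-``Module``.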
        def _freeze_impl(_, m):
            local_keys = keys
            if local_keys is None:
                local_keys = tree_flatten(
                    m.filter_and_map(
                        lambda m, k, v: (not isinstance(v, Module))
                        and m.valid_parameter_filter(m, k, v)
                    )
                )
                local_keys = [k for (k, v) in local_keys]

            local_keys = m._validate_keys(local_keys, strict)
            m._no_grad.update(local_keys)

        if recurse:
            self.apply_to_modules(_freeze_impl)
        else:
            _freeze_impl("", self)

    def unfreeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Unfreeze the Module's parameters or some of them.

        This function is idempotent, i.e. unfreezing a model that is not frozen
        is a no-op.

        Example:

            For instance to only train the biases of a Transformer one can do:

            .. code-block:: python

                model = nn.Transformer()
                model.freeze()
                model.unfreeze(keys="bias")

        Args:
            recurse (bool, optional): If ``True`` then unfreeze the parameters
                of the submodules as well. Default: ``True``.
            keys (str or list[str], optional): If provided then only these
                parameters will be unfrozen, otherwise all the parameters of a
                module. For instance unfreeze all biases by calling
                ``module.unfreeze(keys="bias")``.
            strict (bool, optional): If set to ``True`` validate that the passed
                keys exist. Default: ``False``.
        """

        def _unfreeze_impl(_, m):
            if keys is None:
                m._no_grad.clear()
            else:
                local_keys = m._validate_keys(keys, strict)
                m._no_grad.difference_update(local_keys)

        if recurse:
            self.apply_to_modules(_unfreeze_impl)
        else:
            _unfreeze_impl("", self)

    def train(self, mode: bool = True):
        """Set the model in or out of training mode.

        Training mode only applies to certain layers. For example
        :obj:`Dropout` applies a random mask in training mode, but is the
        identity in evaluation mode.

        Args:
            mode (bool): Indicate if the model should be in training or
                evaluation mode. Default: ``True``.
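
        Example:

            A minimal sketch (assuming ``model`` contains e.g. :obj:`Dropout`
            layers):

            .. code-block:: python

                model.train()   # dropout is active
                model.eval()    # dropout becomes the identity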
        """

        def _set_train(_, m):
            m._training = mode

        self.apply_to_modules(_set_train)

    def eval(self):
        """Set the model to evaluation mode.

        See :func:`train`.
        """
        self.train(False)

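    # ``__getstate__``/``__setstate__`` let ``copy`` and ``pickle`` preserve
    # the instance attributes (``_no_grad``, ``_training``) alongside the
    # dict contents.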
    def __getstate__(self):
        return self.__dict__.copy()

    def __setstate__(self, state):
        self.__dict__.update(state)