mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
awni's commit files
401
python/mlx/nn/layers/base.py
Normal file
@@ -0,0 +1,401 @@
import textwrap
from typing import Any, Callable, List, Union, Optional

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


class Module(dict):
    """Base class for building neural networks with MLX.

    All the layers provided in :mod:`mlx.nn.layers` subclass this class and
    your models should do the same.

    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
    instances in arbitrary nesting of python lists or dicts. The ``Module``
    then allows recursively extracting all the :class:`mlx.core.array` instances
    using :meth:`mlx.nn.Module.parameters`.

    In addition, the ``Module`` has the concept of trainable and non-trainable
    parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
    the gradients are returned only with respect to the trainable parameters.
    All arrays in a module are trainable unless they are added in the "frozen"
    set by calling :meth:`freeze`.

    .. code-block:: python

        import mlx.core as mx
        import mlx.nn as nn

        class MyMLP(nn.Module):
            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
                super().__init__()

                self.in_proj = nn.Linear(in_dims, hidden_dims)
                self.out_proj = nn.Linear(hidden_dims, out_dims)

            def __call__(self, x):
                x = self.in_proj(x)
                x = mx.maximum(x, 0)
                return self.out_proj(x)

        model = MyMLP(2, 1)

        # All the model parameters are created but since MLX is lazy by
        # default, they are not evaluated yet. Calling `mx.eval` actually
        # allocates memory and initializes the parameters.
        mx.eval(model.parameters())

        # Setting a parameter to a new value is as simple as accessing that
        # parameter and assigning a new array to it.
        model.in_proj.weight = model.in_proj.weight * 2
        mx.eval(model.parameters())
    """

    def __init__(self):
        """Should be called by the subclasses of ``Module``."""
        self._no_grad = set()
        self._training = True

    @property
    def training(self):
        return self._training

    def _extra_repr(self):
        return ""

    def __repr__(self):
        children = tree_flatten(self.children(), is_leaf=self.is_module)
        value = f"{type(self).__name__}({self._extra_repr()}"
        for k, v in children:
            value += "\n"
            value += textwrap.indent(f"({k}): {repr(v)}", prefix="  ")
        if children:
            value += "\n"
        value += ")"

        return value

    def __getattr__(self, key: str):
        if key in self:
            return self[key]
        else:
            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")

    def __setattr__(self, key: str, val: Any):
        self[key] = val

    def load_weights(self, file: str):
        """
        Load and update the model's weights from a `.npz` file.
        """
        self.update(tree_unflatten(list(mx.load(file).items())))

    def save_weights(self, file: str):
        """
        Save the model's weights to a `.npz` file.
        """
        mx.savez(file, **dict(tree_flatten(self.parameters())))
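A round-trip sketch (editor's example, not part of this commit): ``save_weights`` flattens the parameter tree into dot-separated keys for ``mx.savez`` and ``load_weights`` unflattens them back into the module, assuming ``nn.Linear`` is exported from ``mlx.nn``:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 4)
    mx.eval(model.parameters())
    model.save_weights("weights.npz")  # stores flat keys such as "weight" and "bias"
    model.load_weights("weights.npz")  # restores them in place via update()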
    @staticmethod
    def is_module(value):
        return isinstance(value, Module)

    @staticmethod
    def valid_child_filter(module, key, value):
        return isinstance(value, (dict, list))

    @staticmethod
    def valid_parameter_filter(module, key, value):
        return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")

    @staticmethod
    def trainable_parameter_filter(module, key, value):
        return (
            Module.valid_parameter_filter(module, key, value)
            and key not in module._no_grad
        )

    def filter_and_map(
        self,
        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
        map_fn: Optional[Callable] = None,
        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Recursively filter the contents of the module using ``filter_fn``,
        namely only select keys and values where ``filter_fn`` returns True.

        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
        but it can also be used to extract any subset of the module's parameters.

        Args:
            filter_fn (Callable): Given the containing module, the key in
                which a value is found and the value itself, decide whether
                to keep the value or drop it.
            map_fn (Callable, optional): Optionally transform the value before
                returning it.
            is_leaf_fn (Callable, optional): Given the containing module, the
                key in which a value is found and the value itself, decide if
                it is a leaf.

        Returns:
            A dictionary containing the contents of the module recursively filtered.
        """

        map_fn = map_fn or (lambda x: x)
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

        def unwrap(vk, v):
            if is_leaf_fn(self, vk, v):
                return map_fn(v)

            if isinstance(v, Module):
                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

            if isinstance(v, dict):
                nd = {}
                for k, v in v.items():
                    tk = f"{vk}.{k}"
                    nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
                return nd

            if isinstance(v, list):
                nl = []
                for i, vi in enumerate(v):
                    tk = f"{vk}.{i}"
                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
                return nl

            raise RuntimeError("Unexpected leaf found while traversing the module")

        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
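For instance, a sketch (editor's example) that extracts just the bias of a single layer; the filter receives ``(module, key, value)`` where nested keys use dot notation, and filtered-out entries inside containers become empty dicts:

    import mlx.nn as nn

    layer = nn.Linear(4, 4)
    biases = layer.filter_and_map(lambda m, k, v: k.endswith("bias"))
    # {'bias': array(...)}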
    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this
        Module as a dict of dicts and lists."""
        return self.filter_and_map(self.valid_parameter_filter)

    def trainable_parameters(self):
        """Recursively return all the non-frozen :class:`mlx.core.array` members
        of this Module as a dict of dicts and lists."""
        return self.filter_and_map(self.trainable_parameter_filter)

    def children(self):
        """Return the direct descendants of this Module instance."""
        return self.filter_and_map(
            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
        )

    def leaf_modules(self):
        """Return the submodules that do not contain other modules."""

        def _is_leaf_module(m, k, v):
            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0

        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

    def update(self, parameters: dict):
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.

        Commonly used by the optimizer to change the model to the updated
        (optimized) parameters. Also used by :meth:`mlx.nn.value_and_grad` to set
        the tracers in the model in order to compute gradients.

        The passed in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            parameters (dict): A complete or partial dictionary of the module's
                parameters.
        """

        def apply(dst, parameters):
            if isinstance(parameters, dict):
                for k in parameters:
                    if k in dst:
                        current_value = dst[k]
                        new_value = parameters[k]
                        if isinstance(current_value, mx.array):
                            dst[k] = new_value
                        elif isinstance(current_value, Module):
                            current_value.update(new_value)
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(parameters, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = parameters[i]
                    if isinstance(current_value, mx.array):
                        dst[i] = new_value
                    elif isinstance(current_value, Module):
                        current_value.update(new_value)
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, parameters)
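A partial-update sketch (editor's example): only the provided locations change, everything else is left alone:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 4)
    model.update({"bias": mx.zeros((4,))})  # weight is untouched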
    def apply(
        self,
        map_fn: Callable[[mx.array], mx.array],
        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Map all the parameters using the provided ``map_fn`` and immediately
        update the module with the mapped parameters.

        For instance running ``model.apply(lambda x: x.astype(mx.float16))``
        casts all parameters to 16 bit floats.

        Args:
            map_fn (Callable): Maps an array to another array.
            filter_fn (Callable, optional): Filter to select which arrays to
                map (default: :meth:`Module.valid_parameter_filter`).
        """
        filter_fn = filter_fn or Module.valid_parameter_filter
        self.update(self.filter_and_map(filter_fn, map_fn))

    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
        """Apply a function to all the modules in this instance (including this
        instance).

        Args:
            apply_fn (Callable): The function to apply to the modules.
        """
        module_stack = [("", self)]
        while module_stack:
            prefix, mod = module_stack.pop()
            apply_fn(prefix, mod)
            prefix = "." + prefix if prefix else ""
            module_stack.extend(
                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
            )

    def modules(self):
        """Return a list with all the modules in this instance.

        Returns:
            A list of :class:`mlx.nn.Module` instances.
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append(m))
        return modulelist

    def named_modules(self):
        """Return a list with all the modules in this instance and their name
        with dot notation.

        Returns:
            A list of tuples (str, :class:`mlx.nn.Module`).
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
        return modulelist

    def _validate_keys(self, keys, strict):
        keys = keys if isinstance(keys, list) else [keys]
        if strict:
            for k in keys:
                if k not in self:
                    raise KeyError(f"Module doesn't contain member {k}.")
        return keys

    def freeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Freeze the Module's parameters or some of them. Freezing a parameter
        means not computing gradients for it.

        This function is idempotent, i.e., freezing a frozen model is a no-op.

        For instance to only train the attention parameters from a transformer:

        .. code-block:: python

            model = ...
            model.freeze()
            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

        Args:
            recurse (bool, optional): If True then freeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be frozen, otherwise all the parameters of a
                module. For instance freeze all biases by calling
                ``module.freeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _freeze_impl(_, m):
            local_keys = keys
            if local_keys is None:
                local_keys = tree_flatten(
                    m.filter_and_map(
                        lambda m, k, v: (not isinstance(v, Module))
                        and m.valid_parameter_filter(m, k, v)
                    )
                )
                local_keys = [k for (k, v) in local_keys]

            local_keys = m._validate_keys(local_keys, strict)
            m._no_grad.update(local_keys)

        if recurse:
            self.apply_to_modules(_freeze_impl)
        else:
            _freeze_impl("", self)

    def unfreeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Unfreeze the Module's parameters or some of them.

        This function is idempotent, i.e., unfreezing a model that is not
        frozen is a no-op.

        For instance to only train the biases one can do:

        .. code-block:: python

            model = ...
            model.freeze()
            model.unfreeze(keys="bias")

        Args:
            recurse (bool, optional): If True then unfreeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be unfrozen, otherwise all the parameters of a
                module. For instance unfreeze all biases by calling
                ``module.unfreeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _unfreeze_impl(_, m):
            if keys is None:
                m._no_grad.clear()
            else:
                local_keys = m._validate_keys(keys, strict)
                m._no_grad.difference_update(local_keys)

        if recurse:
            self.apply_to_modules(_unfreeze_impl)
        else:
            _unfreeze_impl("", self)

    def train(self, mode: bool = True):
        """Set this module and all its submodules to training mode ``mode``."""

        def _set_train(_, m):
            m._training = mode

        self.apply_to_modules(_set_train)

    def eval(self):
        """Set this module to evaluation mode, i.e. ``train(False)``."""
        self.train(False)
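Putting the pieces together, a hypothetical fine-tuning sketch (editor's example, assuming ``nn.Sequential`` and ``nn.Linear`` are exported from ``mlx.nn``): freeze everything, unfreeze only the biases, and confirm that just the biases remain trainable:

    import mlx.nn as nn
    from mlx.utils import tree_flatten

    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    model.freeze()
    model.unfreeze(keys="bias")
    print([k for k, _ in tree_flatten(model.trainable_parameters())])
    # expected: ['layers.0.bias', 'layers.1.bias']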
22
python/mlx/nn/layers/containers.py
Normal file
@@ -0,0 +1,22 @@
from mlx.nn.layers.base import Module


class Sequential(Module):
    """A layer that calls the passed callables in order.

    We can pass either modules or plain callables to the Sequential module. If
    our functions have learnable parameters they should be implemented as
    ``nn.Module`` instances.

    Args:
        modules (tuple of Callables): The modules to call in order.
    """

    def __init__(self, *modules):
        super().__init__()
        self.layers = list(modules)

    def __call__(self, x):
        for m in self.layers:
            x = m(x)
        return x
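A usage sketch (editor's example, not part of this commit): modules and plain callables can be mixed freely:

    import mlx.core as mx
    import mlx.nn as nn

    mlp = nn.Sequential(
        nn.Linear(2, 16),
        lambda x: mx.maximum(x, 0),  # a plain callable works too
        nn.Linear(16, 1),
    )
    y = mlp(mx.zeros((8, 2)))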
33
python/mlx/nn/layers/dropout.py
Normal file
@@ -0,0 +1,33 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class Dropout(Module):
    r"""Randomly zero a portion of the elements during training.

    The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
    :math:`p` is the probability of zeroing an element. This is done so the
    expected value of a given element will remain the same.

    Args:
        p (float): The probability to zero an element.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError("The dropout probability should be in [0, 1)")

        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        if self._p_1 == 1 or not self.training:
            return x

        mask = mx.random.bernoulli(self._p_1, x.shape)

        return (1 / self._p_1) * mask.astype(x.dtype) * x
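A behavior sketch (editor's example): dropout is active only in training mode; after ``eval()`` it is the identity:

    import mlx.core as mx
    import mlx.nn as nn

    drop = nn.Dropout(p=0.5)
    x = mx.ones((4,))
    y_train = drop(x)  # random mask, survivors scaled by 1/(1-p) = 2
    drop.eval()
    y_eval = drop(x)   # returns x unchanged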
178
python/mlx/nn/layers/normalization.py
Normal file
@@ -0,0 +1,178 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class LayerNorm(Module):
    r"""Applies layer normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.

    [1]: https://arxiv.org/abs/1607.06450

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization
    """

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x):
        means = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x


class RMSNorm(Module):
    r"""Applies Root Mean Square normalization [1] to the inputs.

    Computes

    .. math::

        y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma

    where :math:`\gamma` is a learned per feature dimension parameter
    initialized at 1.

    [1]: https://arxiv.org/abs/1910.07467

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
    """

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.eps}"

    def __call__(self, x):
        # S is 1/sqrt(N) where N is the size of the features of x and is used
        # to compute a numerically more stable RMS of x by multiplying with S
        # first and summing.
        #
        # This way we prefer underflow over overflow which is controlled with
        # the parameter epsilon anyway.
        S = 1 / x.shape[-1] ** 0.5

        n = (x * S).square().sum(axis=-1, keepdims=True)
        n = mx.rsqrt(n + self.eps)

        return self.weight * x * n


class GroupNorm(Module):
    r"""Applies Group Normalization [1] to the inputs.

    Computes the same normalization as layer norm, namely

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. However, the mean and
    variance are computed over the spatial dimensions and each group of
    features. In particular, the input is split into ``num_groups`` groups
    across the feature dimension.

    The feature dimension is assumed to be the last dimension and the dimensions
    that precede it (except the first) are considered the spatial dimensions.

    [1]: https://arxiv.org/abs/1803.08494

    Args:
        num_groups (int): Number of groups to separate the features into
        dims (int): The feature dimensions of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization.
        pytorch_compatible (bool): If True perform the group normalization in
            the same order/grouping as PyTorch.
    """

    def __init__(
        self,
        num_groups: int,
        dims: int,
        eps: float = 1e-5,
        affine: bool = True,
        pytorch_compatible: bool = False,
    ):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.num_groups = num_groups
        self.dims = dims
        self.eps = eps
        self.pytorch_compatible = pytorch_compatible

    def _extra_repr(self):
        return (
            f"{self.num_groups}, {self.dims}, eps={self.eps}, "
            f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
        )

    def _pytorch_compatible_group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups, dims // num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, -1, dims // num_groups, num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)

        return x

    def _group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, *rest, dims)

        return x

    def __call__(self, x):
        group_norm = (
            self._pytorch_compatible_group_norm
            if self.pytorch_compatible
            else self._group_norm
        )
        x = group_norm(x)
        return (self.weight * x + self.bias) if "weight" in self else x
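A shape-oriented sketch (editor's example): all three layers normalize over the trailing feature dimension; GroupNorm additionally pools statistics over the spatial dimensions:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((8, 32, 32, 16))  # (batch, *spatial, features)
    y1 = nn.LayerNorm(dims=16)(x)
    y2 = nn.RMSNorm(dims=16)(x)
    y3 = nn.GroupNorm(num_groups=4, dims=16)(x)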
142
python/mlx/nn/layers/positional_encoding.py
Normal file
@@ -0,0 +1,142 @@
import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module


class RoPE(Module):
    """Implements the rotary positional encoding [1].

    The traditional implementation rotates consecutive pairs of elements in the
    feature dimension while the default implementation rotates pairs with
    stride half the feature dimensions for efficiency.

    [1]: https://arxiv.org/abs/2104.09864

    Args:
        dims (int): The feature dimensions to be rotated. If the input feature
            is larger than dims then the rest is left unchanged.
        traditional (bool): If set to True choose the traditional
            implementation which is slightly less efficient.
    """

    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _extra_repr(self):
        return f"{self.dims}, traditional={self.traditional}"

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = mx.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return mx.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
    ):
        D = D // 2
        positions = mx.arange(offset, N, dtype=dtype)
        freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)

        return costheta, sintheta
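A usage sketch (editor's example, not part of this commit): RoPE treats the last two axes as (sequence, features); ``offset`` lets cached autoregressive decoding continue the rotation from a given position:

    import mlx.core as mx

    rope = RoPE(dims=64)
    x = mx.random.normal((1, 8, 100, 64))      # (batch, heads, sequence, features)
    y = rope(x)
    y_step = rope(x[:, :, -1:, :], offset=99)  # rotate a single new position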

class SinusoidalPositionalEncoding(Module):
    """Implements sinusoidal positional encoding similar to [1].

    [1]: https://arxiv.org/abs/1706.03762

    Args:
        dims (int): The dimensionality of the resulting positional embeddings.
        min_freq (float): The minimum frequency expected (default: 0.0001)
        max_freq (float): The maximum frequency expected (default: 1)
        scale (float): Scale the embeddings by this number (default: sqrt(2/dims))
        cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
            instead of the other way around (default: False)
        full_turns (bool): If set to True multiply the frequencies
            with ``2 pi`` (default: False)
    """

    def __init__(
        self,
        dims: int,
        min_freq: float = 0.0001,
        max_freq: float = 1,
        scale: Optional[float] = None,
        cos_first: bool = False,
        full_turns: bool = False,
    ):
        super().__init__()

        one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
        min_freq = math.log(min_freq)
        max_freq = math.log(max_freq)

        # Start with underscore so it is not included in the parameters
        self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
        if full_turns:
            self._sigmas = self._sigmas * (2 * math.pi)

        # Save some constants that define the implementation
        self.scale = scale or (2 / dims) ** 0.5
        self.cos_first = cos_first

    def __call__(self, x):
        y = x[..., None] * self._sigmas
        cosy = mx.cos(y)
        siny = mx.sin(y)

        if self.cos_first:
            y = mx.concatenate([cosy, siny], axis=-1)
        else:
            y = mx.concatenate([siny, cosy], axis=-1)

        if self.scale != 1:
            y = y * self.scale

        return y
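A usage sketch (editor's example): the input holds positions and each scalar position is expanded into ``dims`` features:

    import mlx.core as mx

    pe = SinusoidalPositionalEncoding(dims=32)
    positions = mx.arange(0, 100)
    emb = pe(positions)  # shape (100, 32): [sin(x * sigma); cos(x * sigma)]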
136
python/mlx/nn/layers/transformer.py
Normal file
@@ -0,0 +1,136 @@
import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm


class MultiHeadAttention(Module):
    """Implements the scaled dot product attention with multiple heads.

    Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
    new values by aggregating information from the input values according to
    the similarities of the input queries and keys.

    All inputs as well as the output are linearly projected without biases.

    MultiHeadAttention also expects an additive attention mask that should be
    broadcastable with (batch, num_heads, # queries, # keys). The mask should
    have ``-inf`` or very negative numbers at the positions that should *not* be
    attended to.

    Args:
        dims (int): The model dimensions. If no other dims are provided then
            dims is used for queries, keys, values and the output.
        num_heads (int): How many attention heads to use
        query_input_dims (int, optional): The input dimensions of the queries (default: dims).
        key_input_dims (int, optional): The input dimensions of the keys (default: dims).
        value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
        value_dims (int, optional): The dimensions of the values after the projection (default: dims).
        value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
    ):
        super().__init__()

        if (dims % num_heads) != 0:
            raise ValueError(
                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
            )

        query_input_dims = query_input_dims or dims
        key_input_dims = key_input_dims or dims
        value_input_dims = value_input_dims or key_input_dims
        value_dims = value_dims or dims
        value_output_dims = value_output_dims or dims

        self.num_heads = num_heads
        self.query_proj = Linear(query_input_dims, dims, False)
        self.key_proj = Linear(key_input_dims, dims, False)
        self.value_proj = Linear(value_input_dims, value_dims, False)
        self.out_proj = Linear(value_dims, value_output_dims, False)

    def __call__(self, queries, keys, values, mask=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat)

    @staticmethod
    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        # usually -inf but -1e9 is as good and softmax(full(-1e9)) != nan
        # TODO: Should replace this with finfo(dtype).min
        mask = mask.astype(dtype) * -1e9
        return mask
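A self-attention sketch (editor's example) combining the layer with its causal-mask helper:

    import mlx.core as mx

    attn = MultiHeadAttention(dims=64, num_heads=8)
    x = mx.random.normal((2, 10, 64))  # (batch, sequence, dims)
    mask = MultiHeadAttention.create_additive_causal_mask(10)
    y = attn(x, x, x, mask)            # (2, 10, 64)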

class TransformerEncoderLayer(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = MultiHeadAttention(dims, num_heads)
        self.ln1 = LayerNorm(dims)
        self.ln2 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)

    def __call__(self, x, mask):
        y = self.ln1(x)
        y = self.attention(y, y, y, mask)
        x = x + y

        y = self.ln2(x)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
        x = x + y

        return x


class TransformerEncoder(Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for _ in range(num_layers)
        ]
        self.ln = LayerNorm(dims)

    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        x = self.ln(x)

        return x
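An end-to-end sketch (editor's example, not part of this commit): stack a small encoder and run a masked forward pass:

    import mlx.core as mx

    encoder = TransformerEncoder(num_layers=2, dims=64, num_heads=8)
    x = mx.random.normal((2, 10, 64))
    mask = MultiHeadAttention.create_additive_causal_mask(10)
    y = encoder(x, mask)  # (2, 10, 64)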