awni's commit files

Author: Awni Hannun
Date:   2023-11-29 10:30:41 -08:00
Parent: e411fcae68
Commit: 8ca7f9e8e9

130 changed files with 30159 additions and 0 deletions


@@ -0,0 +1,401 @@
import textwrap
from typing import Any, Callable, List, Union, Optional

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


class Module(dict):
    """Base class for building neural networks with MLX.

    All the layers provided in :mod:`mlx.nn.layers` subclass this class and
    your models should do the same.

    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
    instances in arbitrary nesting of python lists or dicts. The ``Module``
    then allows recursively extracting all the :class:`mlx.core.array` instances
    using :meth:`mlx.nn.Module.parameters`.

    In addition, the ``Module`` has the concept of trainable and non-trainable
    parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
    the gradients are returned only with respect to the trainable parameters.
    All arrays in a module are trainable unless they are added to the "frozen"
    set by calling :meth:`freeze`.

    .. code-block:: python

        import mlx.core as mx
        import mlx.nn as nn

        class MyMLP(nn.Module):
            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
                super().__init__()

                self.in_proj = nn.Linear(in_dims, hidden_dims)
                self.out_proj = nn.Linear(hidden_dims, out_dims)

            def __call__(self, x):
                x = self.in_proj(x)
                x = mx.maximum(x, 0)
                return self.out_proj(x)

        model = MyMLP(2, 1)

        # All the model parameters are created but since MLX is lazy by
        # default, they are not evaluated yet. Calling `mx.eval` actually
        # allocates memory and initializes the parameters.
        mx.eval(model.parameters())

        # Setting a parameter to a new value is as simple as accessing that
        # parameter and assigning a new array to it.
        model.in_proj.weight = model.in_proj.weight * 2
        mx.eval(model.parameters())
    """
    def __init__(self):
        """Should be called by the subclasses of ``Module``."""
        self._no_grad = set()
        self._training = True

    @property
    def training(self):
        return self._training

    def _extra_repr(self):
        return ""

    def __repr__(self):
        children = tree_flatten(self.children(), is_leaf=self.is_module)
        value = f"{type(self).__name__}({self._extra_repr()}"
        for k, v in children:
            value += "\n"
            value += textwrap.indent(f"({k}): {repr(v)}", prefix="  ")
        if children:
            value += "\n"
        value += ")"
        return value

    def __getattr__(self, key: str):
        if key in self:
            return self[key]
        else:
            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")

    def __setattr__(self, key: str, val: Any):
        self[key] = val
    def load_weights(self, file: str):
        """
        Load and update the model's weights from a `.npz` file.
        """
        self.update(tree_unflatten(list(mx.load(file).items())))

    def save_weights(self, file: str):
        """
        Save the model's weights to a `.npz` file.
        """
        mx.savez(file, **dict(tree_flatten(self.parameters())))

    @staticmethod
    def is_module(value):
        return isinstance(value, Module)

    @staticmethod
    def valid_child_filter(module, key, value):
        return isinstance(value, (dict, list))

    @staticmethod
    def valid_parameter_filter(module, key, value):
        return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")

    @staticmethod
    def trainable_parameter_filter(module, key, value):
        return (
            Module.valid_parameter_filter(module, key, value)
            and key not in module._no_grad
        )
    def filter_and_map(
        self,
        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
        map_fn: Optional[Callable] = None,
        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Recursively filter the contents of the module using ``filter_fn``,
        namely only select keys and values where ``filter_fn`` returns true.

        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
        but it can also be used to extract any subset of the module's parameters.

        Args:
            filter_fn (Callable): Given a value, the key in which it is found
                and the containing module, decide whether to keep the value or
                drop it.
            map_fn (Callable, optional): Optionally transform the value before
                returning it.
            is_leaf_fn (Callable, optional): Given a value, the key in which it
                is found and the containing module, decide if it is a leaf.

        Returns:
            A dictionary containing the contents of the module recursively filtered.
        """
        map_fn = map_fn or (lambda x: x)
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

        def unwrap(vk, v):
            if is_leaf_fn(self, vk, v):
                return map_fn(v)

            if isinstance(v, Module):
                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

            if isinstance(v, dict):
                nd = {}
                for k, v in v.items():
                    tk = f"{vk}.{k}"
                    nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
                return nd

            if isinstance(v, list):
                nl = []
                for i, vi in enumerate(v):
                    tk = f"{vk}.{i}"
                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
                return nl

            raise RuntimeError("Unexpected leaf found while traversing the module")

        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this Module
        as a dict of dicts and lists."""
        return self.filter_and_map(self.valid_parameter_filter)

    def trainable_parameters(self):
        """Recursively return all the non-frozen :class:`mlx.core.array` members of
        this Module as a dict of dicts and lists."""
        return self.filter_and_map(self.trainable_parameter_filter)

    def children(self):
        """Return the direct descendants of this Module instance."""
        return self.filter_and_map(
            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
        )

    def leaf_modules(self):
        """Return the submodules that do not contain other modules."""

        def _is_leaf_module(m, k, v):
            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0

        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
    def update(self, parameters: dict):
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.

        Commonly used by the optimizer to change the model to the updated
        (optimized) parameters. Also used by :func:`mlx.nn.value_and_grad` to set the
        tracers in the model in order to compute gradients.

        The passed-in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            parameters (dict): A complete or partial dictionary of the module's
                parameters.
        """

        def apply(dst, parameters):
            if isinstance(parameters, dict):
                for k in parameters:
                    if k in dst:
                        current_value = dst[k]
                        new_value = parameters[k]
                        if isinstance(current_value, mx.array):
                            dst[k] = new_value
                        elif isinstance(current_value, Module):
                            current_value.update(new_value)
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(parameters, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = parameters[i]
                    if isinstance(current_value, mx.array):
                        dst[i] = new_value
                    elif isinstance(current_value, Module):
                        current_value.update(new_value)
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, parameters)
    def apply(
        self,
        map_fn: Callable[[mx.array], mx.array],
        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Map all the parameters using the provided ``map_fn`` and immediately
        update the module with the mapped parameters.

        For instance running ``model.apply(lambda x: x.astype(mx.float16))``
        casts all parameters to 16 bit floats.

        Args:
            map_fn (Callable): Maps an array to another array
            filter_fn (Callable, optional): Filter to select which arrays to
                map (default: :meth:`Module.valid_parameter_filter`).
        """
        filter_fn = filter_fn or Module.valid_parameter_filter
        self.update(self.filter_and_map(filter_fn, map_fn))

    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
        """Apply a function to all the modules in this instance (including this
        instance).

        Args:
            apply_fn (Callable): The function to apply to the modules.
        """
        module_stack = [("", self)]
        while module_stack:
            prefix, mod = module_stack.pop()
            apply_fn(prefix, mod)
            prefix = "." + prefix if prefix else ""
            module_stack.extend(
                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
            )
    def modules(self):
        """Return a list with all the modules in this instance.

        Returns:
            A list of :class:`mlx.nn.Module` instances.
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append(m))
        return modulelist

    def named_modules(self):
        """Return a list with all the modules in this instance and their name
        with dot notation.

        Returns:
            A list of tuples (str, :class:`mlx.nn.Module`).
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
        return modulelist

    def _validate_keys(self, keys, strict):
        keys = keys if isinstance(keys, list) else [keys]
        if strict:
            for k in keys:
                if k not in self:
                    raise KeyError(f"Module doesn't contain member {k}.")
        return keys
    def freeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Freeze the Module's parameters or some of them. Freezing a parameter
        means not computing gradients for it.

        This function is idempotent, i.e. freezing a frozen model is a no-op.

        For instance to only train the attention parameters from a transformer:

            model = ...
            model.freeze()
            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

        Args:
            recurse (bool, optional): If True then freeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be frozen otherwise all the parameters of a
                module. For instance freeze all biases by calling
                ``module.freeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _freeze_impl(_, m):
            local_keys = keys
            if local_keys is None:
                local_keys = tree_flatten(
                    m.filter_and_map(
                        lambda m, k, v: (not isinstance(v, Module))
                        and m.valid_parameter_filter(m, k, v)
                    )
                )
                local_keys = [k for (k, v) in local_keys]

            local_keys = m._validate_keys(local_keys, strict)
            m._no_grad.update(local_keys)

        if recurse:
            self.apply_to_modules(_freeze_impl)
        else:
            _freeze_impl("", self)
    def unfreeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Unfreeze the Module's parameters or some of them.

        This function is idempotent, i.e. unfreezing a model that is not frozen
        is a no-op.

        For instance to only train the biases one can do:

            model = ...
            model.freeze()
            model.unfreeze(keys="bias")

        Args:
            recurse (bool, optional): If True then unfreeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be unfrozen otherwise all the parameters of a
                module. For instance unfreeze all biases by calling
                ``module.unfreeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _unfreeze_impl(_, m):
            if keys is None:
                m._no_grad.clear()
            else:
                local_keys = m._validate_keys(keys, strict)
                m._no_grad.difference_update(local_keys)

        if recurse:
            self.apply_to_modules(_unfreeze_impl)
        else:
            _unfreeze_impl("", self)

    def train(self, mode: bool = True):
        """Recursively set the training mode of this Module and its submodules."""

        def _set_train(_, m):
            m._training = mode

        self.apply_to_modules(_set_train)

    def eval(self):
        """Set this Module and its submodules to evaluation mode, i.e. ``train(False)``."""
        self.train(False)
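
The class docstring above already shows how a model is defined; as an extra illustration, here is a small sketch of how the freezing and filtering machinery fits together. It assumes the layers are exposed under `mlx.nn` as in that docstring (`nn.Module`, `nn.Linear`) and that `Linear` carries `weight` and `bias` parameters, neither of which is shown in this diff:

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(4, 8)
        self.l2 = nn.Linear(8, 1)

    def __call__(self, x):
        return self.l2(mx.maximum(self.l1(x), 0))


net = TinyNet()
mx.eval(net.parameters())

# Freeze everything, then make only the biases trainable again.
net.freeze()
net.unfreeze(keys="bias")

# Only the bias arrays remain in trainable_parameters().
print([k for k, _ in tree_flatten(net.trainable_parameters())])
# e.g. ['l1.bias', 'l2.bias']

# apply() maps every parameter and updates the module in place,
# here casting all parameters to 16-bit floats.
net.apply(lambda x: x.astype(mx.float16))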


@@ -0,0 +1,22 @@
from mlx.nn.layers.base import Module


class Sequential(Module):
    """A layer that calls the passed callables in order.

    We can pass either modules or plain callables to the Sequential module. If
    our functions have learnable parameters they should be implemented as
    ``nn.Module`` instances.

    Args:
        modules (tuple of Callables): The modules to call in order
    """

    def __init__(self, *modules):
        super().__init__()
        self.layers = list(modules)

    def __call__(self, x):
        for m in self.layers:
            x = m(x)
        return x
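
A minimal usage sketch (assuming `Sequential` and `Linear` are exported under `mlx.nn`, which this diff does not show): plain callables carry no parameters, so only the two `Linear` layers contribute to `mlp.parameters()`.

import mlx.core as mx
import mlx.nn as nn

mlp = nn.Sequential(
    nn.Linear(8, 32),
    lambda x: mx.maximum(x, 0),  # a parameter-free callable used as ReLU
    nn.Linear(32, 1),
)

x = mx.random.normal((4, 8))
y = mlp(x)  # shape (4, 1)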


@@ -0,0 +1,33 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class Dropout(Module):
    r"""Randomly zero a portion of the elements during training.

    The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
    :math:`p` is the probability of zeroing an element. This is done so the
    expected value of a given element will remain the same.

    Args:
        p (float): The probability to zero an element
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError("The dropout probability should be in [0, 1)")

        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        if self._p_1 == 1 or not self.training:
            return x

        mask = mx.random.bernoulli(self._p_1, x.shape)

        return (1 / self._p_1) * mask.astype(x.dtype) * x
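
A short sketch of the training/inference behaviour (assuming `Dropout` is exported under `mlx.nn`): during training the surviving elements are scaled by 1/(1-p), and after switching to evaluation mode the layer is the identity.

import mlx.core as mx
import mlx.nn as nn

drop = nn.Dropout(p=0.2)
x = mx.ones((2, 4))

y_train = drop(x)  # zeros roughly 20% of the entries, survivors scaled by 1.25

drop.eval()        # Module.eval() sets training mode to False recursively
y_eval = drop(x)   # identity: returns x unchanged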


@@ -0,0 +1,178 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class LayerNorm(Module):
    r"""Applies layer normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.

    [1]: https://arxiv.org/abs/1607.06450

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization
    """

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x):
        means = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x
class RMSNorm(Module):
    r"""Applies Root Mean Square normalization [1] to the inputs.

    Computes

    .. math::

        y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma

    where :math:`\gamma` is a learned per feature dimension parameter initialized at
    1.

    [1]: https://arxiv.org/abs/1910.07467

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
    """

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.eps}"

    def __call__(self, x):
        # S is 1/sqrt(N) where N is the size of the features of x and is used
        # to compute a numerically more stable RMS of x by multiplying with S
        # first and summing.
        #
        # This way we prefer underflow over overflow which is controlled with
        # the parameter epsilon anyway.
        S = 1 / x.shape[-1] ** 0.5

        n = (x * S).square().sum(axis=-1, keepdims=True)
        n = mx.rsqrt(n + self.eps)

        return self.weight * x * n
class GroupNorm(Module):
    r"""Applies Group Normalization [1] to the inputs.

    Computes the same normalization as layer norm, namely

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. However, the mean and
    variance are computed over the spatial dimensions and each group of
    features. In particular, the input is split into ``num_groups`` across the
    feature dimension.

    The feature dimension is assumed to be the last dimension and the dimensions
    that precede it (except the first) are considered the spatial dimensions.

    [1]: https://arxiv.org/abs/1803.08494

    Args:
        num_groups (int): Number of groups to separate the features into
        dims (int): The feature dimensions of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization.
        pytorch_compatible (bool): If True perform the group normalization in
            the same order/grouping as PyTorch.
    """

    def __init__(
        self,
        num_groups: int,
        dims: int,
        eps: float = 1e-5,
        affine: bool = True,
        pytorch_compatible: bool = False,
    ):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.num_groups = num_groups
        self.dims = dims
        self.eps = eps
        self.pytorch_compatible = pytorch_compatible

    def _extra_repr(self):
        return (
            f"{self.num_groups}, {self.dims}, eps={self.eps}, "
            f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
        )

    def _pytorch_compatible_group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups, dims // num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, -1, dims // num_groups, num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)

        return x

    def _group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, *rest, dims)

        return x

    def __call__(self, x):
        group_norm = (
            self._pytorch_compatible_group_norm
            if self.pytorch_compatible
            else self._group_norm
        )
        x = group_norm(x)
        return (self.weight * x + self.bias) if "weight" in self else x
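
A quick sketch of the three layers above on a (batch, spatial, feature) input (assuming they are exported under `mlx.nn`): each normalizes over the feature axis (and, for `GroupNorm`, the spatial axis as well) and preserves the input shape.

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((8, 16, 32))  # (batch, spatial, features)

ln = nn.LayerNorm(32)                      # per-position statistics over the 32 features
rms = nn.RMSNorm(32)                       # scale-only normalization, no mean subtraction
gn = nn.GroupNorm(num_groups=4, dims=32)   # statistics over spatial positions and groups of 8 features

print(ln(x).shape, rms(x).shape, gn(x).shape)  # each is (8, 16, 32)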


@@ -0,0 +1,142 @@
import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module


class RoPE(Module):
    """Implements the rotary positional encoding [1].

    The traditional implementation rotates consecutive pairs of elements in the
    feature dimension while the default implementation rotates pairs with
    stride half the feature dimensions for efficiency.

    [1]: https://arxiv.org/abs/2104.09864

    Args:
        dims (int): The feature dimensions to be rotated. If the input feature
            is larger than dims then the rest is left unchanged.
        traditional (bool): If set to True choose the traditional
            implementation which is slightly less efficient.
    """

    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _extra_repr(self):
        return f"{self.dims}, traditional={self.traditional}"

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        if self.dims < x.shape[-1]:
            rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = mx.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return mx.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
    ):
        D = D // 2
        positions = mx.arange(offset, N, dtype=dtype)
        freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)

        return costheta, sintheta
class SinusoidalPositionalEncoding(Module):
    """Implements sinusoidal positional encoding similar to [1].

    [1]: https://arxiv.org/abs/1706.03762

    Args:
        dims (int): The dimensionality of the resulting positional embeddings.
        min_freq (float): The minimum frequency expected (default: 0.0001)
        max_freq (float): The maximum frequency expected (default: 1)
        scale (float): Scale the embeddings by that number (default: sqrt(2/dims))
        cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
            instead of the other way around (default: False)
        full_turns (bool): If set to True multiply the frequencies
            with ``2 pi`` (default: False)
    """

    def __init__(
        self,
        dims: int,
        min_freq: float = 0.0001,
        max_freq: float = 1,
        scale: Optional[float] = None,
        cos_first: bool = False,
        full_turns: bool = False,
    ):
        super().__init__()

        one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
        min_freq = math.log(min_freq)
        max_freq = math.log(max_freq)

        # Start with underscore so it is not included in the parameters
        self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
        if full_turns:
            self._sigmas = self._sigmas * (2 * math.pi)

        # Save some constants that define the implementation
        self.scale = scale or (2 / dims) ** 0.5
        self.cos_first = cos_first

    def __call__(self, x):
        y = x[..., None] * self._sigmas
        cosy = mx.cos(y)
        siny = mx.sin(y)

        if self.cos_first:
            y = mx.concatenate([cosy, siny], axis=-1)
        else:
            y = mx.concatenate([siny, cosy], axis=-1)

        if self.scale != 1:
            y = y * self.scale

        return y
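
A small sketch of how the two encodings above might be applied (assuming they are exported under `mlx.nn`). `RoPE` acts on the last two axes of a (..., sequence, dims) input, and `offset` shifts the positions, which is how a key/value cache would request the rotation for newly generated tokens:

import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(dims=64)
x = mx.random.normal((1, 8, 10, 64))    # (batch, heads, sequence, head_dim)

y = rope(x)                             # rotations for positions 0..9
y_new = rope(x[..., -1:, :], offset=9)  # rotation for a single token at position 9

pe = nn.SinusoidalPositionalEncoding(32)
emb = pe(mx.arange(10))                 # (10, 32) table of positional embeddings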


@@ -0,0 +1,136 @@
import math
from typing import Optional
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
class MultiHeadAttention(Module):
"""Implements the scaled dot product attention with multiple heads.
Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
new values by aggregating information from the input values according to
the similarities of the input queries and keys.
All inputs as well as the output are lineary projected without biases.
MultiHeadAttention also expects an additive attention mask that should be
broadcastable with (batch, num_heads, # queries, # keys). The mask should
have ``-inf`` or very negative numbers to the positions that should *not* be
attended to.
Args:
dims (int): The model dimensions. If no other dims are provided then
dims is used for queries, keys, values and the output.
num_heads (int): How many attention heads to use
query_input_dims (int, optional): The input dimensions of the queries (default: dims).
key_input_dims (int, optional): The input dimensions of the keys (default: dims).
value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
value_dims (int, optional): The dimensions of the values after the projection (default: dims).
value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
"""
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
f"The input feature dimensions should be divisble by the number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = Linear(query_input_dims, dims, False)
self.key_proj = Linear(key_input_dims, dims, False)
self.value_proj = Linear(value_input_dims, value_dims, False)
self.out_proj = Linear(value_dims, value_output_dims, False)
def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
@staticmethod
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
# TODO: Should replace this with finfo(dtype).min
mask = mask.astype(dtype) * -1e9
return mask
class TransformerEncoderLayer(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = MultiHeadAttention(dims, num_heads)
        self.ln1 = LayerNorm(dims)
        self.ln2 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)

    def __call__(self, x, mask):
        y = self.ln1(x)
        y = self.attention(y, y, y, mask)
        x = x + y

        y = self.ln2(x)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
        x = x + y

        return x


class TransformerEncoder(Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for i in range(num_layers)
        ]
        self.ln = LayerNorm(dims)

    def __call__(self, x, mask):
        for l in self.layers:
            x = l(x, mask)
        x = self.ln(x)

        return x
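
A short end-to-end sketch of the encoder above with a causal mask (assuming the classes are exported under `mlx.nn`; note that `TransformerEncoderLayer.__call__` takes the mask as a required argument):

import mlx.core as mx
import mlx.nn as nn

encoder = nn.TransformerEncoder(num_layers=2, dims=64, num_heads=4)

x = mx.random.normal((1, 10, 64))  # (batch, sequence, dims)
mask = nn.MultiHeadAttention.create_additive_causal_mask(10)

y = encoder(x, mask)  # (1, 10, 64)
mx.eval(y)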