Transformer fix (#167)

* add transformer with dropout, fix transformer ffm, layernorm order

* precommit changes

* precommit changes

* add docstring, activation, norm_first

* run precommit

* run precommit

* add doctstring

* precommit

* style nits in docs

---------

Co-authored-by: junwoo-yun <junwoo.yun@bagelcode.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
YUN, Junwoo 2023-12-28 00:48:36 +08:00 committed by GitHub
parent 79c95b6919
commit 4417e37ede
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 220 additions and 77 deletions

View File

@ -9,7 +9,7 @@ Layers
:toctree: _autosummary
:template: nn-module-template.rst
Embedding
Sequential
ReLU
PReLU
GELU
@ -17,17 +17,19 @@ Layers
Step
SELU
Mish
Embedding
Linear
QuantizedLinear
Conv1d
Conv2d
BatchNorm
LayerNorm
RMSNorm
GroupNorm
RoPE
MultiHeadAttention
Sequential
QuantizedLinear
Dropout
Dropout2d
Transformer
MultiHeadAttention
ALiBi
RoPE
SinusoidalPositionalEncoding

View File

@ -14,6 +14,7 @@ from mlx.nn.layers.activations import (
SiLU,
Softplus,
Step,
Tanh,
celu,
elu,
gelu,
@ -29,6 +30,7 @@ from mlx.nn.layers.activations import (
silu,
softplus,
step,
tanh,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
@ -41,6 +43,7 @@ from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalE
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (
MultiHeadAttention,
Transformer,
TransformerEncoder,
TransformerEncoderLayer,
)

View File

@ -179,12 +179,12 @@ def selu(x):
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
r"""Applies the element-wise function:
r"""Applies the element-wise parametric ReLU.
.. math::
\text{PReLU}(x) = \max(0,x) + a * \min(0,x)
Here :math:`a` is an array.
where :math:`a` is an array.
"""
return mx.maximum(0, x) + alpha * mx.minimum(0, x)

View File

@ -8,21 +8,22 @@ from mlx.nn.layers.base import Module
class RoPE(Module):
"""Implements the rotary positional encoding [1].
"""Implements the rotary positional encoding.
The traditional implementation rotates consecutive pairs of elements in the
feature dimension while the default implementation rotates pairs with
stride half the feature dimensions for efficiency.
[1]: https://arxiv.org/abs/2104.09864
For more details see `RoFormer: Enhanced Transformer with Rotary Position
Embedding <https://arxiv.org/abs/2104.09864>`_.
Args:
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool, optional): If set to True choose the traditional
implementation which is slightly less efficient. Default: ``False``
implementation which is slightly less efficient. Default: ``False``.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Default: ``10000``
each dimension in the positional encodings. Default: ``10000``.
"""
def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
@ -89,19 +90,23 @@ class RoPE(Module):
class SinusoidalPositionalEncoding(Module):
"""Implements sinusoidal positional encoding similar to [1].
r"""Implements sinusoidal positional encoding.
[1]: https://arxiv.org/abs/1706.03762
For more details see the paper `Attention Is All You Need
<https://arxiv.org/abs/1706.03762>`_.
Args:
dims (int): The dimensionality of the resulting positional embeddings.
min_freq (float): The minimum frequency expected (default: 0.0001)
max_freq (float): The maximum frequency expected (default: 1)
scale (float): Scale the embeddings by that number (default: sqrt(dims//2))
cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
instead of the other way around (default: False)
full_turns (bool): If set to True multiply the frequencies
with ``2 pi`` (default: False)
min_freq (float, optional): The minimum frequency expected. Default:
``0.0001``.
max_freq (float, optional): The maximum frequency expected. Default:
``1``.
scale (float, optional): A multiplicative scale for the embeddings.
Default: ``sqrt(dims//2)``.
cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``
instead of the reverse. Default: ``False``.
full_turns (bool, optional): If ``True`` multiply the frequencies with
:math:`2\pi`. Default: ``False``.
"""
def __init__(

View File

@ -1,10 +1,12 @@
# Copyright © 2023 Apple Inc.
import math
from typing import Any, Optional
from typing import Any, Callable, Optional
import mlx.core as mx
from mlx.nn.layers.activations import relu
from mlx.nn.layers.base import Module
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
@ -12,26 +14,34 @@ from mlx.nn.layers.normalization import LayerNorm
class MultiHeadAttention(Module):
"""Implements the scaled dot product attention with multiple heads.
Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
new values by aggregating information from the input values according to
the similarities of the input queries and keys.
Given inputs for queries, keys and values the ``MultiHeadAttention``
produces new values by aggregating information from the input values
according to the similarities of the input queries and keys.
All inputs as well as the output are linearly projected without biases.
All inputs as well as the output are linearly projected without biases by
default.
MultiHeadAttention also expects an additive attention mask that should be
broadcastable with (batch, num_heads, # queries, # keys). The mask should
have ``-inf`` or very negative numbers to the positions that should *not* be
attended to.
``MultiHeadAttention`` also takes an optional additive attention mask that
should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The
mask should have ``-inf`` or very large negative numbers at the positions
that should *not* be attended to.
Args:
dims (int): The model dimensions. If no other dims are provided then
dims is used for queries, keys, values and the output.
num_heads (int): How many attention heads to use
query_input_dims (int, optional): The input dimensions of the queries (default: dims).
key_input_dims (int, optional): The input dimensions of the keys (default: dims).
value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
value_dims (int, optional): The dimensions of the values after the projection (default: dims).
value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
dims (int): The model dimensions. This is also the default
value for the queries, keys, values, and the output.
num_heads (int): The number of attention heads to use.
query_input_dims (int, optional): The input dimensions of the queries.
Default: ``dims``.
key_input_dims (int, optional): The input dimensions of the keys.
Default: ``dims``.
value_input_dims (int, optional): The input dimensions of the values.
Default: ``key_input_dims``.
value_dims (int, optional): The dimensions of the values after the
projection. Default: ``dims``.
value_output_dims (int, optional): The dimensions the new values will
be projected to. Default: ``dims``.
bias (bool, optional): Whether or not to use a bias in the projections.
Default: ``False``.
"""
def __init__(
@ -49,7 +59,8 @@ class MultiHeadAttention(Module):
if (dims % num_heads) != 0:
raise ValueError(
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
"The input feature dimensions should be divisible by the "
f"number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
@ -97,7 +108,15 @@ class MultiHeadAttention(Module):
class TransformerEncoderLayer(Module):
def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
def __init__(
self,
dims: int,
num_heads: int,
mlp_dims: Optional[int] = None,
dropout: float = 0.0,
activation: Callable[[Any], Any] = relu,
norm_first: bool = False,
):
super().__init__()
mlp_dims = mlp_dims or dims * 4
self.attention = MultiHeadAttention(dims, num_heads)
@ -105,28 +124,55 @@ class TransformerEncoderLayer(Module):
self.ln2 = LayerNorm(dims)
self.linear1 = Linear(dims, mlp_dims)
self.linear2 = Linear(mlp_dims, dims)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.activation = activation
self.norm_first = norm_first
def __call__(self, x, mask):
if self.norm_first:
y = self.ln1(x)
y = self.attention(y, y, y, mask)
y = self.dropout1(y)
x = x + y
y = self.ln2(x)
y = self.linear1(y)
y = mx.maximum(y, 0)
y = self.activation(y)
y = self.dropout2(y)
y = self.linear2(y)
x = x + y
y = x + y
return x
else:
y = self.attention(x, x, x, mask)
y = self.dropout1(y)
y = self.ln1(x + y)
y = self.linear1(y)
y = self.activation(y)
y = self.dropout2(y)
y = self.linear2(y)
y = self.ln2(x + y)
return y
class TransformerEncoder(Module):
def __init__(
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
self,
num_layers: int,
dims: int,
num_heads: int,
mlp_dims: Optional[int] = None,
dropout: float = 0.0,
activation=relu,
norm_first: bool = False,
):
super().__init__()
self.layers = [
TransformerEncoderLayer(dims, num_heads, mlp_dims)
TransformerEncoderLayer(
dims, num_heads, mlp_dims, dropout, activation, norm_first
)
for i in range(num_layers)
]
self.ln = LayerNorm(dims)
@ -134,13 +180,19 @@ class TransformerEncoder(Module):
def __call__(self, x, mask):
for l in self.layers:
x = l(x, mask)
x = self.ln(x)
return x
return self.ln(x)
class TransformerDecoderLayer(Module):
def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
def __init__(
self,
dims: int,
num_heads: int,
mlp_dims: Optional[int] = None,
dropout: float = 0.0,
activation: Callable[[Any], Any] = relu,
norm_first: bool = False,
):
super().__init__()
mlp_dims = mlp_dims or dims * 4
self.self_attention = MultiHeadAttention(dims, num_heads)
@ -150,32 +202,65 @@ class TransformerDecoderLayer(Module):
self.ln3 = LayerNorm(dims)
self.linear1 = Linear(dims, mlp_dims)
self.linear2 = Linear(mlp_dims, dims)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.dropout3 = Dropout(dropout)
self.activation = activation
self.norm_first = norm_first
def __call__(self, x, memory, x_mask, memory_mask):
if self.norm_first:
y = self.ln1(x)
y = self.self_attention(y, y, y, x_mask)
y = self.dropout1(y)
x = x + y
y = self.ln2(x)
y = self.cross_attention(y, memory, memory, memory_mask)
y = self.dropout2(y)
x = x + y
y = self.ln3(x)
y = self.linear1(y)
y = mx.maximum(y, 0)
y = self.activation(y)
y = self.dropout3(y)
y = self.linear2(y)
x = x + y
y = x + y
return x
else:
y = self.self_attention(x, x, x, x_mask)
y = self.dropout1(y)
x = self.ln1(x + y)
y = self.cross_attention(y, memory, memory, memory_mask)
y = self.dropout2(y)
x = self.ln1(x + y)
y = self.linear1(x)
y = self.activation(y)
y = self.dropout3(y)
y = self.linear2(y)
y = self.ln3(x + y)
return y
class TransformerDecoder(Module):
def __init__(
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
self,
num_layers: int,
dims: int,
num_heads: int,
mlp_dims: Optional[int] = None,
dropout: float = 0.0,
activation=relu,
norm_first: bool = False,
):
super().__init__()
self.layers = [
TransformerDecoderLayer(dims, num_heads, mlp_dims)
TransformerDecoderLayer(
dims, num_heads, mlp_dims, dropout, activation, norm_first
)
for i in range(num_layers)
]
self.ln = LayerNorm(dims)
@ -183,12 +268,47 @@ class TransformerDecoder(Module):
def __call__(self, x, memory, x_mask, memory_mask):
for l in self.layers:
x = l(x, memory, x_mask, memory_mask)
x = self.ln(x)
return x
return self.ln(x)
class Transformer(Module):
"""
Implements a standard Transformer model.
The implementation is based on `Attention Is All You Need
<https://arxiv.org/abs/1706.03762>`_.
The Transformer model contains an encoder and a decoder. The encoder
processes the input sequence and the decoder generates the output sequence.
The interaction between encoder and decoder happens through the attention
mechanism.
Args:
dims (int, optional): The number of expected features in the
encoder/decoder inputs. Default: ``512``.
num_heads (int, optional): The number of attention heads. Default:
``8``.
num_encoder_layers (int, optional): The number of encoder layers in the
Transformer encoder. Default: ``6``.
num_decoder_layers (int, optional): The number of decoder layers in the
Transformer decoder. Default: ``6``.
mlp_dims (int, optional): The hidden dimension of the MLP block in each
Transformer layer. Defaults to ``4*dims`` if not provided. Default:
``None``.
dropout (float, optional): The dropout value for the Transformer
encoder and decoder. Dropout is used after each attention layer and
the activation in the MLP layer. Default: ``0.0``.
activation (function, optional): the activation function for the MLP
hidden layer. Default: :func:`mlx.nn.relu`.
custom_encoder (nn.Module, optional): A custom encoder to replace the
standard Transformer encoder. Default: ``None``.
custom_decoder (nn.Module, optional): A custom decoder to replace the
standard Transformer decoder. Default: ``None``.
norm_first (bool, optional): if ``True``, encoder and decoder layers
will perform layer normalization before attention and MLP
operations, otherwise after. Default: ``False``.
"""
def __init__(
self,
dims: int = 512,
@ -196,26 +316,39 @@ class Transformer(Module):
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
mlp_dims: Optional[int] = None,
dropout: float = 0.0,
activation: Callable[[Any], Any] = relu,
custom_encoder: Optional[Any] = None,
custom_decoder: Optional[Any] = None,
norm_first: bool = False,
):
super().__init__()
if custom_encoder is not None:
self.encoder = custom_encoder
else:
self.encoder = TransformerEncoder(
num_encoder_layers, dims, num_heads, mlp_dims
num_encoder_layers,
dims,
num_heads,
mlp_dims,
dropout,
activation,
norm_first,
)
if custom_decoder is not None:
self.decoder = custom_decoder
else:
self.decoder = TransformerDecoder(
num_decoder_layers, dims, num_heads, mlp_dims
num_decoder_layers,
dims,
num_heads,
mlp_dims,
dropout,
activation,
norm_first,
)
def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
memory = self.encoder(src, src_mask)
output = self.decoder(tgt, memory, tgt_mask, memory_mask)
return output
return self.decoder(tgt, memory, tgt_mask, memory_mask)