Transformer fix (#167)

* add transformer with dropout, fix transformer ffn, layernorm order

* precommit changes

* precommit changes

* add docstring, activation, norm_first

* run precommit

* run precommit

* add docstring

* precommit

* style nits in docs

---------

Co-authored-by: junwoo-yun <junwoo.yun@bagelcode.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Committed by YUN, Junwoo on 2023-12-28 00:48:36 +08:00 (via GitHub)
Commit 4417e37ede, parent 79c95b6919
5 changed files with 220 additions and 77 deletions

View File

@@ -9,7 +9,7 @@ Layers
    :toctree: _autosummary
    :template: nn-module-template.rst

-   Embedding
+   Sequential
    ReLU
    PReLU
    GELU
@@ -17,17 +17,19 @@ Layers
    Step
    SELU
    Mish
+   Embedding
    Linear
-   QuantizedLinear
    Conv1d
    Conv2d
    BatchNorm
    LayerNorm
    RMSNorm
    GroupNorm
-   RoPE
-   MultiHeadAttention
-   Sequential
+   QuantizedLinear
    Dropout
    Dropout2d
+   Transformer
+   MultiHeadAttention
+   ALiBi
+   RoPE
    SinusoidalPositionalEncoding

View File

@@ -14,6 +14,7 @@ from mlx.nn.layers.activations import (
     SiLU,
     Softplus,
     Step,
+    Tanh,
     celu,
     elu,
     gelu,
@@ -29,6 +30,7 @@ from mlx.nn.layers.activations import (
     silu,
     softplus,
     step,
+    tanh,
 )
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
@@ -41,6 +43,7 @@ from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalE
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
+    Transformer,
     TransformerEncoder,
     TransformerEncoderLayer,
 )

View File

@@ -179,12 +179,12 @@ def selu(x):
 def prelu(x: mx.array, alpha: mx.array) -> mx.array:
-    r"""Applies the element-wise function:
+    r"""Applies the element-wise parametric ReLU.

     .. math::
         \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

-    Here :math:`a` is an array.
+    where :math:`a` is an array.
     """
     return mx.maximum(0, x) + alpha * mx.minimum(0, x)
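Not part of the diff itself, but a minimal sketch of how the documented formula plays out in practice; the input values are illustrative and the import path is the activations module shown above:

import mlx.core as mx
from mlx.nn.layers.activations import prelu

x = mx.array([-2.0, -0.5, 0.0, 1.5])
alpha = mx.array([0.25])  # slope for the negative part, broadcast over x

# PReLU(x) = max(0, x) + alpha * min(0, x)
print(prelu(x, alpha))  # -> [-0.5, -0.125, 0.0, 1.5]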

View File

@@ -8,21 +8,22 @@ from mlx.nn.layers.base import Module

 class RoPE(Module):
-    """Implements the rotary positional encoding [1].
+    """Implements the rotary positional encoding.

     The traditional implementation rotates consecutive pairs of elements in the
     feature dimension while the default implementation rotates pairs with
     stride half the feature dimensions for efficiency.

-    [1]: https://arxiv.org/abs/2104.09864
+    For more details see `RoFormer: Enhanced Transformer with Rotary Position
+    Embedding <https://arxiv.org/abs/2104.09864>`_.

     Args:
         dims (int): The feature dimensions to be rotated. If the input feature
             is larger than dims then the rest is left unchanged.
         traditional (bool, optional): If set to True choose the traditional
-            implementation which is slightly less efficient. Default: ``False``
+            implementation which is slightly less efficient. Default: ``False``.
         base (float, optional): The base used to compute angular frequency for
-            each dimension in the positional encodings. Default: ``10000``
+            each dimension in the positional encodings. Default: ``10000``.
     """

     def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
@@ -89,19 +90,23 @@ class RoPE(Module):

 class SinusoidalPositionalEncoding(Module):
-    """Implements sinusoidal positional encoding similar to [1].
+    r"""Implements sinusoidal positional encoding.

-    [1]: https://arxiv.org/abs/1706.03762
+    For more details see the paper `Attention Is All You Need
+    <https://arxiv.org/abs/1706.03762>`_.

     Args:
         dims (int): The dimensionality of the resulting positional embeddings.
-        min_freq (float): The minimum frequency expected (default: 0.0001)
-        max_freq (float): The maximum frequency expected (default: 1)
-        scale (float): Scale the embeddings by that number (default: sqrt(dims//2))
-        cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
-            instead of the other way around (default: False)
-        full_turns (bool): If set to True multiply the frequencies
-            with ``2 pi`` (default: False)
+        min_freq (float, optional): The minimum frequency expected. Default:
+            ``0.0001``.
+        max_freq (float, optional): The maximum frequency expected. Default:
+            ``1``.
+        scale (float, optional): A multiplicative scale for the embeddings.
+            Default: ``sqrt(dims//2)``.
+        cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``
+            instead of the reverse. Default: ``False``.
+        full_turns (bool, optional): If ``True`` multiply the frequencies with
+            :math:`2\pi`. Default: ``False``.
     """

     def __init__(
View File

@@ -1,10 +1,12 @@
 # Copyright © 2023 Apple Inc.

 import math
-from typing import Any, Optional
+from typing import Any, Callable, Optional

 import mlx.core as mx
+from mlx.nn.layers.activations import relu
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm
@@ -12,26 +14,34 @@ from mlx.nn.layers.normalization import LayerNorm
 class MultiHeadAttention(Module):
     """Implements the scaled dot product attention with multiple heads.

-    Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
-    new values by aggregating information from the input values according to
-    the similarities of the input queries and keys.
+    Given inputs for queries, keys and values the ``MultiHeadAttention``
+    produces new values by aggregating information from the input values
+    according to the similarities of the input queries and keys.

-    All inputs as well as the output are linearly projected without biases.
+    All inputs as well as the output are linearly projected without biases by
+    default.

-    MultiHeadAttention also expects an additive attention mask that should be
-    broadcastable with (batch, num_heads, # queries, # keys). The mask should
-    have ``-inf`` or very negative numbers to the positions that should *not* be
-    attended to.
+    ``MultiHeadAttention`` also takes an optional additive attention mask that
+    should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The
+    mask should have ``-inf`` or very large negative numbers at the positions
+    that should *not* be attended to.

     Args:
-        dims (int): The model dimensions. If no other dims are provided then
-            dims is used for queries, keys, values and the output.
-        num_heads (int): How many attention heads to use
-        query_input_dims (int, optional): The input dimensions of the queries (default: dims).
-        key_input_dims (int, optional): The input dimensions of the keys (default: dims).
-        value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
-        value_dims (int, optional): The dimensions of the values after the projection (default: dims).
-        value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
+        dims (int): The model dimensions. This is also the default
+            value for the queries, keys, values, and the output.
+        num_heads (int): The number of attention heads to use.
+        query_input_dims (int, optional): The input dimensions of the queries.
+            Default: ``dims``.
+        key_input_dims (int, optional): The input dimensions of the keys.
+            Default: ``dims``.
+        value_input_dims (int, optional): The input dimensions of the values.
+            Default: ``key_input_dims``.
+        value_dims (int, optional): The dimensions of the values after the
+            projection. Default: ``dims``.
+        value_output_dims (int, optional): The dimensions the new values will
+            be projected to. Default: ``dims``.
+        bias (bool, optional): Whether or not to use a bias in the projections.
+            Default: ``False``.
     """

     def __init__(
@@ -49,7 +59,8 @@ class MultiHeadAttention(Module):
         if (dims % num_heads) != 0:
             raise ValueError(
-                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
             )

         query_input_dims = query_input_dims or dims
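Not from the diff, but a short usage sketch of the attention interface documented above; the shapes and the hand-built additive causal mask are illustrative:

import mlx.core as mx
import mlx.nn as nn

dims, num_heads, L = 64, 8, 6
attn = nn.MultiHeadAttention(dims, num_heads)

x = mx.random.normal((1, L, dims))  # queries = keys = values (self-attention)

# Additive mask: 0 where attention is allowed, -inf above the diagonal.
idx = mx.arange(L)
mask = mx.where(idx[:, None] >= idx[None], 0.0, float("-inf"))

y = attn(x, x, x, mask)  # shape (1, L, dims)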
@@ -97,7 +108,15 @@ class MultiHeadAttention(Module):
 class TransformerEncoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
+    ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
         self.attention = MultiHeadAttention(dims, num_heads)
@@ -105,28 +124,55 @@ class TransformerEncoderLayer:
         self.ln2 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first

     def __call__(self, x, mask):
-        y = self.ln1(x)
-        y = self.attention(y, y, y, mask)
-        x = x + y
-
-        y = self.ln2(x)
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        x = x + y
-
-        return x
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.attention(y, y, y, mask)
+            y = self.dropout1(y)
+            x = x + y
+
+            y = self.ln2(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout2(y)
+            y = self.linear2(y)
+            y = x + y
+
+        else:
+            y = self.attention(x, x, x, mask)
+            y = self.dropout1(y)
+            y = self.ln1(x + y)
+
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout2(y)
+            y = self.linear2(y)
+            y = self.ln2(x + y)
+
+        return y


 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation=relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims)
+            TransformerEncoderLayer(
+                dims, num_heads, mlp_dims, dropout, activation, norm_first
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
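A quick usage sketch of the encoder with the new options (not from the commit; sizes are illustrative). With ``norm_first=True`` each sub-block normalizes before attention/MLP and adds the residual afterwards; with the default ``False`` the residual is added first and then normalized:

import mlx.core as mx
import mlx.nn as nn

encoder = nn.TransformerEncoder(
    num_layers=2,
    dims=64,
    num_heads=8,
    dropout=0.1,
    activation=nn.gelu,
    norm_first=True,  # pre-LayerNorm variant added in this PR
)
x = mx.random.normal((1, 16, 64))
y = encoder(x, mask=None)  # (1, 16, 64); no mask needed for full attention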
@@ -134,13 +180,19 @@ class TransformerEncoder(Module):
     def __call__(self, x, mask):
         for l in self.layers:
             x = l(x, mask)
-        x = self.ln(x)
-        return x
+        return self.ln(x)


 class TransformerDecoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
+    ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
         self.self_attention = MultiHeadAttention(dims, num_heads)
@@ -150,32 +202,65 @@ class TransformerDecoderLayer:
         self.ln3 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.dropout3 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first

     def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.ln1(x)
-        y = self.self_attention(y, y, y, x_mask)
-        x = x + y
-
-        y = self.ln2(x)
-        y = self.cross_attention(y, memory, memory, memory_mask)
-        x = x + y
-
-        y = self.ln3(x)
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        x = x + y
-
-        return x
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.self_attention(y, y, y, x_mask)
+            y = self.dropout1(y)
+            x = x + y
+
+            y = self.ln2(x)
+            y = self.cross_attention(y, memory, memory, memory_mask)
+            y = self.dropout2(y)
+            x = x + y
+
+            y = self.ln3(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout3(y)
+            y = self.linear2(y)
+            y = x + y
+
+        else:
+            y = self.self_attention(x, x, x, x_mask)
+            y = self.dropout1(y)
+            x = self.ln1(x + y)
+
+            y = self.cross_attention(y, memory, memory, memory_mask)
+            y = self.dropout2(y)
+            x = self.ln1(x + y)
+
+            y = self.linear1(x)
+            y = self.activation(y)
+            y = self.dropout3(y)
+            y = self.linear2(y)
+            y = self.ln3(x + y)
+
+        return y


 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation=relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims)
+            TransformerDecoderLayer(
+                dims, num_heads, mlp_dims, dropout, activation, norm_first
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -183,12 +268,47 @@ class TransformerDecoder:
     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
             x = l(x, memory, x_mask, memory_mask)
-        x = self.ln(x)
-        return x
+        return self.ln(x)


 class Transformer(Module):
+    """
+    Implements a standard Transformer model.
+
+    The implementation is based on `Attention Is All You Need
+    <https://arxiv.org/abs/1706.03762>`_.
+
+    The Transformer model contains an encoder and a decoder. The encoder
+    processes the input sequence and the decoder generates the output sequence.
+    The interaction between encoder and decoder happens through the attention
+    mechanism.
+
+    Args:
+        dims (int, optional): The number of expected features in the
+            encoder/decoder inputs. Default: ``512``.
+        num_heads (int, optional): The number of attention heads. Default:
+            ``8``.
+        num_encoder_layers (int, optional): The number of encoder layers in the
+            Transformer encoder. Default: ``6``.
+        num_decoder_layers (int, optional): The number of decoder layers in the
+            Transformer decoder. Default: ``6``.
+        mlp_dims (int, optional): The hidden dimension of the MLP block in each
+            Transformer layer. Defaults to ``4*dims`` if not provided. Default:
+            ``None``.
+        dropout (float, optional): The dropout value for the Transformer
+            encoder and decoder. Dropout is used after each attention layer and
+            the activation in the MLP layer. Default: ``0.0``.
+        activation (function, optional): the activation function for the MLP
+            hidden layer. Default: :func:`mlx.nn.relu`.
+        custom_encoder (nn.Module, optional): A custom encoder to replace the
+            standard Transformer encoder. Default: ``None``.
+        custom_decoder (nn.Module, optional): A custom decoder to replace the
+            standard Transformer decoder. Default: ``None``.
+        norm_first (bool, optional): if ``True``, encoder and decoder layers
+            will perform layer normalization before attention and MLP
+            operations, otherwise after. Default: ``False``.
+    """
+
     def __init__(
         self,
         dims: int = 512,
@@ -196,26 +316,39 @@ class Transformer(Module):
         num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
         mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
+        norm_first: bool = False,
     ):
         super().__init__()
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
             self.encoder = TransformerEncoder(
-                num_encoder_layers, dims, num_heads, mlp_dims
+                num_encoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )

         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
             self.decoder = TransformerDecoder(
-                num_decoder_layers, dims, num_heads, mlp_dims
+                num_decoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )

     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         memory = self.encoder(src, src_mask)
-        output = self.decoder(tgt, memory, tgt_mask, memory_mask)
-        return output
+        return self.decoder(tgt, memory, tgt_mask, memory_mask)
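To close out, an end-to-end usage sketch of the new options (not from the commit; the toy sizes, random inputs, and hand-built causal mask are illustrative):

import mlx.core as mx
import mlx.nn as nn

model = nn.Transformer(
    dims=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dropout=0.1,
    activation=nn.gelu,
    norm_first=True,
)

src = mx.random.normal((1, 12, 64))  # encoder input (batch, src_len, dims)
tgt = mx.random.normal((1, 7, 64))   # decoder input (batch, tgt_len, dims)

# Additive causal mask over the target sequence; None disables masking.
idx = mx.arange(tgt.shape[1])
tgt_mask = mx.where(idx[:, None] >= idx[None], 0.0, float("-inf"))

out = model(src, tgt, None, tgt_mask, None)  # shape (1, 7, 64)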