Transformer fix (#167)
* add transformer with dropout, fix transformer ffn, layernorm order
* precommit changes
* precommit changes
* add docstring, activation, norm_first
* run precommit
* run precommit
* add docstring
* precommit
* style nits in docs

---------

Co-authored-by: junwoo-yun <junwoo.yun@bagelcode.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in: parent 79c95b6919, commit 4417e37ede
@@ -9,7 +9,7 @@ Layers
   :toctree: _autosummary
   :template: nn-module-template.rst

   Embedding
   Sequential
   ReLU
   PReLU
   GELU
@@ -17,17 +17,19 @@ Layers
   Step
   SELU
   Mish
   Embedding
   Linear
   QuantizedLinear
   Conv1d
   Conv2d
   BatchNorm
   LayerNorm
   RMSNorm
   GroupNorm
   RoPE
   MultiHeadAttention
   Sequential
   QuantizedLinear
   Dropout
   Dropout2d

   Transformer
   MultiHeadAttention
   ALiBi
   RoPE
   SinusoidalPositionalEncoding
@@ -14,6 +14,7 @@ from mlx.nn.layers.activations import (
    SiLU,
    Softplus,
    Step,
    Tanh,
    celu,
    elu,
    gelu,
@@ -29,6 +30,7 @@ from mlx.nn.layers.activations import (
    silu,
    softplus,
    step,
    tanh,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
@@ -41,6 +43,7 @@ from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (
    MultiHeadAttention,
    Transformer,
    TransformerEncoder,
    TransformerEncoderLayer,
)
@@ -179,12 +179,12 @@ def selu(x):


def prelu(x: mx.array, alpha: mx.array) -> mx.array:
    r"""Applies the element-wise function:
    r"""Applies the element-wise parametric ReLU.

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    Here :math:`a` is an array.
    where :math:`a` is an array.
    """
    return mx.maximum(0, x) + alpha * mx.minimum(0, x)
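As a quick check of the PReLU formula above, a minimal sketch, assuming the function is exposed as `mlx.nn.prelu`; the input values are made up for illustration:

```python
import mlx.core as mx
import mlx.nn as nn

# Negative inputs are scaled by alpha, positive inputs pass through unchanged.
x = mx.array([-2.0, -0.5, 0.0, 1.5])
alpha = mx.array([0.25])  # a single slope, broadcast over x

y = nn.prelu(x, alpha)
print(y)  # expected: [-0.5, -0.125, 0.0, 1.5]
```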
@@ -8,21 +8,22 @@ from mlx.nn.layers.base import Module


class RoPE(Module):
    """Implements the rotary positional encoding [1].
    """Implements the rotary positional encoding.

    The traditional implementation rotates consecutive pairs of elements in the
    feature dimension while the default implementation rotates pairs with
    stride half the feature dimensions for efficiency.

    [1]: https://arxiv.org/abs/2104.09864
    For more details see `RoFormer: Enhanced Transformer with Rotary Position
    Embedding <https://arxiv.org/abs/2104.09864>`_.

    Args:
        dims (int): The feature dimensions to be rotated. If the input feature
            is larger than dims then the rest is left unchanged.
        traditional (bool, optional): If set to True choose the traditional
            implementation which is slightly less efficient. Default: ``False``
            implementation which is slightly less efficient. Default: ``False``.
        base (float, optional): The base used to compute angular frequency for
            each dimension in the positional encodings. Default: ``10000``
            each dimension in the positional encodings. Default: ``10000``.
    """

    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
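A brief usage sketch of the `RoPE` layer documented above; the shapes are arbitrary and only meant to show that the rotation is applied along the last (feature) axis:

```python
import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(dims=64)                  # rotate the first 64 feature dimensions
x = mx.random.normal((1, 8, 100, 64))    # e.g. (batch, heads, sequence, head_dim)
y = rope(x)                              # same shape as x, positions encoded
print(y.shape)
```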
@@ -89,19 +90,23 @@ class RoPE(Module):


class SinusoidalPositionalEncoding(Module):
    """Implements sinusoidal positional encoding similar to [1].
    r"""Implements sinusoidal positional encoding.

    [1]: https://arxiv.org/abs/1706.03762
    For more details see the paper `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        dims (int): The dimensionality of the resulting positional embeddings.
        min_freq (float): The minimum frequency expected (default: 0.0001)
        max_freq (float): The maximum frequency expected (default: 1)
        scale (float): Scale the embeddings by that number (default: sqrt(dims//2))
        cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
            instead of the other way around (default: False)
        full_turns (bool): If set to True multiply the frequencies
            with ``2 pi`` (default: False)
        min_freq (float, optional): The minimum frequency expected. Default:
            ``0.0001``.
        max_freq (float, optional): The maximum frequency expected. Default:
            ``1``.
        scale (float, optional): A multiplicative scale for the embeddings.
            Default: ``sqrt(dims//2)``.
        cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``
            instead of the reverse. Default: ``False``.
        full_turns (bool, optional): If ``True`` multiply the frequencies with
            :math:`2\pi`. Default: ``False``.
    """

    def __init__(
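A small sketch of adding the sinusoidal encoding to token embeddings, assuming the layer maps integer positions to `dims`-sized vectors as described; the sizes are illustrative:

```python
import mlx.core as mx
import mlx.nn as nn

dims = 32
pe = nn.SinusoidalPositionalEncoding(dims)

positions = mx.arange(10)            # positions 0..9
pos_emb = pe(positions)              # shape: (10, 32)

tokens = mx.random.normal((10, dims))
x = tokens + pos_emb                 # add positional information to embeddings
print(x.shape)
```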
@@ -1,10 +1,12 @@
# Copyright © 2023 Apple Inc.

import math
from typing import Any, Optional
from typing import Any, Callable, Optional

import mlx.core as mx
from mlx.nn.layers.activations import relu
from mlx.nn.layers.base import Module
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
@@ -12,26 +14,34 @@ from mlx.nn.layers.normalization import LayerNorm
class MultiHeadAttention(Module):
    """Implements the scaled dot product attention with multiple heads.

    Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
    new values by aggregating information from the input values according to
    the similarities of the input queries and keys.
    Given inputs for queries, keys and values the ``MultiHeadAttention``
    produces new values by aggregating information from the input values
    according to the similarities of the input queries and keys.

    All inputs as well as the output are linearly projected without biases.
    All inputs as well as the output are linearly projected without biases by
    default.

    MultiHeadAttention also expects an additive attention mask that should be
    broadcastable with (batch, num_heads, # queries, # keys). The mask should
    have ``-inf`` or very negative numbers to the positions that should *not* be
    attended to.
    ``MultiHeadAttention`` also takes an optional additive attention mask that
    should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The
    mask should have ``-inf`` or very large negative numbers at the positions
    that should *not* be attended to.

    Args:
        dims (int): The model dimensions. If no other dims are provided then
            dims is used for queries, keys, values and the output.
        num_heads (int): How many attention heads to use
        query_input_dims (int, optional): The input dimensions of the queries (default: dims).
        key_input_dims (int, optional): The input dimensions of the keys (default: dims).
        value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
        value_dims (int, optional): The dimensions of the values after the projection (default: dims).
        value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
        dims (int): The model dimensions. This is also the default
            value for the queries, keys, values, and the output.
        num_heads (int): The number of attention heads to use.
        query_input_dims (int, optional): The input dimensions of the queries.
            Default: ``dims``.
        key_input_dims (int, optional): The input dimensions of the keys.
            Default: ``dims``.
        value_input_dims (int, optional): The input dimensions of the values.
            Default: ``key_input_dims``.
        value_dims (int, optional): The dimensions of the values after the
            projection. Default: ``dims``.
        value_output_dims (int, optional): The dimensions the new values will
            be projected to. Default: ``dims``.
        bias (bool, optional): Whether or not to use a bias in the projections.
            Default: ``False``.
    """

    def __init__(
@@ -49,7 +59,8 @@ class MultiHeadAttention(Module):

        if (dims % num_heads) != 0:
            raise ValueError(
                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
                "The input feature dimensions should be divisible by the "
                f"number of heads ({dims} % {num_heads}) != 0"
            )

        query_input_dims = query_input_dims or dims
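To make the additive mask convention concrete, a hedged sketch that builds a causal mask by hand and runs it through `MultiHeadAttention`; all shapes are illustrative, and `dims` is chosen divisible by `num_heads` per the check above:

```python
import mlx.core as mx
import mlx.nn as nn

batch, num_heads, seq, dims = 2, 4, 6, 32    # dims must be divisible by num_heads
attn = nn.MultiHeadAttention(dims, num_heads)

queries = mx.random.normal((batch, seq, dims))
keys = mx.random.normal((batch, seq, dims))
values = mx.random.normal((batch, seq, dims))

# Additive causal mask: 0 where attention is allowed, a large negative
# number where position j > i should not be attended to.
idx = mx.arange(seq)
mask = mx.where(idx.reshape(1, -1) > idx.reshape(-1, 1), mx.array(-1e9), mx.array(0.0))

out = attn(queries, keys, values, mask)      # shape: (batch, seq, dims)
print(out.shape)
```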
@@ -97,7 +108,15 @@ class MultiHeadAttention(Module):


class TransformerEncoderLayer(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation: Callable[[Any], Any] = relu,
        norm_first: bool = False,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = MultiHeadAttention(dims, num_heads)
@@ -105,28 +124,55 @@ class TransformerEncoderLayer(Module):
        self.ln2 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.activation = activation
        self.norm_first = norm_first

    def __call__(self, x, mask):
        if self.norm_first:
            y = self.ln1(x)
            y = self.attention(y, y, y, mask)
            y = self.dropout1(y)
            x = x + y

            y = self.ln2(x)
            y = self.linear1(y)
            y = mx.maximum(y, 0)
            y = self.activation(y)
            y = self.dropout2(y)
            y = self.linear2(y)
            x = x + y
            y = x + y

            return x
        else:
            y = self.attention(x, x, x, mask)
            y = self.dropout1(y)
            y = self.ln1(x + y)

            y = self.linear1(y)
            y = self.activation(y)
            y = self.dropout2(y)
            y = self.linear2(y)
            y = self.ln2(x + y)

            return y


class TransformerEncoder(Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
        self,
        num_layers: int,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation=relu,
        norm_first: bool = False,
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            TransformerEncoderLayer(
                dims, num_heads, mlp_dims, dropout, activation, norm_first
            )
            for i in range(num_layers)
        ]
        self.ln = LayerNorm(dims)
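The encoder layer above now takes `dropout`, `activation`, and `norm_first` (pre-norm residual ordering). A hedged sketch of a small pre-norm encoder stack; the hyperparameters are arbitrary:

```python
import mlx.core as mx
import mlx.nn as nn

# A tiny pre-norm encoder with dropout; all sizes are examples only.
encoder = nn.TransformerEncoder(
    num_layers=2,
    dims=64,
    num_heads=4,
    dropout=0.1,
    norm_first=True,
)

x = mx.random.normal((2, 10, 64))    # (batch, sequence, dims)
mask = None                          # or an additive mask as in the attention example
y = encoder(x, mask)
print(y.shape)                       # (2, 10, 64)
```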
@@ -134,13 +180,19 @@ class TransformerEncoder(Module):
    def __call__(self, x, mask):
        for l in self.layers:
            x = l(x, mask)
        x = self.ln(x)

        return x
        return self.ln(x)


class TransformerDecoderLayer(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation: Callable[[Any], Any] = relu,
        norm_first: bool = False,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.self_attention = MultiHeadAttention(dims, num_heads)
@@ -150,32 +202,65 @@ class TransformerDecoderLayer(Module):
        self.ln3 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)
        self.activation = activation
        self.norm_first = norm_first

    def __call__(self, x, memory, x_mask, memory_mask):
        if self.norm_first:
            y = self.ln1(x)
            y = self.self_attention(y, y, y, x_mask)
            y = self.dropout1(y)
            x = x + y

            y = self.ln2(x)
            y = self.cross_attention(y, memory, memory, memory_mask)
            y = self.dropout2(y)
            x = x + y

            y = self.ln3(x)
            y = self.linear1(y)
            y = mx.maximum(y, 0)
            y = self.activation(y)
            y = self.dropout3(y)
            y = self.linear2(y)
            x = x + y
            y = x + y

            return x
        else:
            y = self.self_attention(x, x, x, x_mask)
            y = self.dropout1(y)
            x = self.ln1(x + y)

            y = self.cross_attention(y, memory, memory, memory_mask)
            y = self.dropout2(y)
            x = self.ln1(x + y)

            y = self.linear1(x)
            y = self.activation(y)
            y = self.dropout3(y)
            y = self.linear2(y)
            y = self.ln3(x + y)

            return y


class TransformerDecoder(Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
        self,
        num_layers: int,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation=relu,
        norm_first: bool = False,
    ):
        super().__init__()
        self.layers = [
            TransformerDecoderLayer(dims, num_heads, mlp_dims)
            TransformerDecoderLayer(
                dims, num_heads, mlp_dims, dropout, activation, norm_first
            )
            for i in range(num_layers)
        ]
        self.ln = LayerNorm(dims)
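A hedged sketch of calling a single decoder layer with encoder output as `memory`, matching the `__call__(x, memory, x_mask, memory_mask)` signature above. It imports from the module path shown in this diff, since only the encoder-side classes appear in the package imports earlier; shapes are illustrative:

```python
import mlx.core as mx
from mlx.nn.layers.transformer import TransformerDecoderLayer

dims, num_heads = 64, 4
layer = TransformerDecoderLayer(dims, num_heads, dropout=0.1, norm_first=True)

batch, tgt_len, src_len = 2, 7, 10
x = mx.random.normal((batch, tgt_len, dims))        # decoder input
memory = mx.random.normal((batch, src_len, dims))   # encoder output

# Causal mask over the target sequence; the memory is left unmasked here.
idx = mx.arange(tgt_len)
x_mask = mx.where(idx.reshape(1, -1) > idx.reshape(-1, 1), mx.array(-1e9), mx.array(0.0))

y = layer(x, memory, x_mask, memory_mask=None)
print(y.shape)                                       # (2, 7, 64)
```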
@@ -183,12 +268,47 @@ class TransformerDecoder(Module):
    def __call__(self, x, memory, x_mask, memory_mask):
        for l in self.layers:
            x = l(x, memory, x_mask, memory_mask)
        x = self.ln(x)

        return x
        return self.ln(x)


class Transformer(Module):
    """
    Implements a standard Transformer model.

    The implementation is based on `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_.

    The Transformer model contains an encoder and a decoder. The encoder
    processes the input sequence and the decoder generates the output sequence.
    The interaction between encoder and decoder happens through the attention
    mechanism.

    Args:
        dims (int, optional): The number of expected features in the
            encoder/decoder inputs. Default: ``512``.
        num_heads (int, optional): The number of attention heads. Default:
            ``8``.
        num_encoder_layers (int, optional): The number of encoder layers in the
            Transformer encoder. Default: ``6``.
        num_decoder_layers (int, optional): The number of decoder layers in the
            Transformer decoder. Default: ``6``.
        mlp_dims (int, optional): The hidden dimension of the MLP block in each
            Transformer layer. Defaults to ``4*dims`` if not provided. Default:
            ``None``.
        dropout (float, optional): The dropout value for the Transformer
            encoder and decoder. Dropout is used after each attention layer and
            the activation in the MLP layer. Default: ``0.0``.
        activation (function, optional): the activation function for the MLP
            hidden layer. Default: :func:`mlx.nn.relu`.
        custom_encoder (nn.Module, optional): A custom encoder to replace the
            standard Transformer encoder. Default: ``None``.
        custom_decoder (nn.Module, optional): A custom decoder to replace the
            standard Transformer decoder. Default: ``None``.
        norm_first (bool, optional): if ``True``, encoder and decoder layers
            will perform layer normalization before attention and MLP
            operations, otherwise after. Default: ``False``.
    """

    def __init__(
        self,
        dims: int = 512,
@@ -196,26 +316,39 @@ class Transformer(Module):
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation: Callable[[Any], Any] = relu,
        custom_encoder: Optional[Any] = None,
        custom_decoder: Optional[Any] = None,
        norm_first: bool = False,
    ):
        super().__init__()
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            self.encoder = TransformerEncoder(
                num_encoder_layers, dims, num_heads, mlp_dims
                num_encoder_layers,
                dims,
                num_heads,
                mlp_dims,
                dropout,
                activation,
                norm_first,
            )

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            self.decoder = TransformerDecoder(
                num_decoder_layers, dims, num_heads, mlp_dims
                num_decoder_layers,
                dims,
                num_heads,
                mlp_dims,
                dropout,
                activation,
                norm_first,
            )

    def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
        memory = self.encoder(src, src_mask)
        output = self.decoder(tgt, memory, tgt_mask, memory_mask)

        return output
        return self.decoder(tgt, memory, tgt_mask, memory_mask)
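Putting the pieces together, a hedged end-to-end sketch of the `Transformer` module after this change; sizes are deliberately tiny, and all masks are left as `None` to keep the example short (a causal `tgt_mask` like the one built in the attention example would normally be passed):

```python
import mlx.core as mx
import mlx.nn as nn

# A small model; the documented defaults are dims=512, 8 heads, 6+6 layers.
model = nn.Transformer(
    dims=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dropout=0.1,
    norm_first=True,
)

batch, src_len, tgt_len, dims = 2, 10, 7, 64
src = mx.random.normal((batch, src_len, dims))
tgt = mx.random.normal((batch, tgt_len, dims))

# __call__(src, tgt, src_mask, tgt_mask, memory_mask)
out = model(src, tgt, None, None, None)
print(out.shape)   # (batch, tgt_len, dims)
```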