Repository: https://github.com/ml-explore/mlx.git (mirror)
Transformer fix (#167)

* add transformer with dropout, fix transformer ffm, layernorm order
* precommit changes
* precommit changes
* add docstring, activation, norm_first
* run precommit
* run precommit
* add doctstring
* precommit
* style nits in docs

Co-authored-by: junwoo-yun <junwoo.yun@bagelcode.com>
Co-authored-by: Awni Hannun <awni@apple.com>

Commit: 4417e37ede (parent: 79c95b6919)
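Note (illustrative, not part of the commit): the sketch below shows roughly how the options introduced here (dropout, activation, norm_first) could be exercised through mlx.nn once this change lands. The layer names and argument order follow the diff below; the shapes, hyperparameters, and the choice of nn.gelu are made up for the example.

    import mlx.core as mx
    import mlx.nn as nn

    # Hypothetical sizes, purely for illustration.
    B, L_src, L_tgt, D = 2, 10, 7, 256

    model = nn.Transformer(
        dims=D,
        num_heads=8,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dropout=0.1,         # new in this commit
        activation=nn.gelu,  # new in this commit (the default remains relu)
        norm_first=True,     # new in this commit: pre-norm residual blocks
    )

    src = mx.random.normal((B, L_src, D))
    tgt = mx.random.normal((B, L_tgt, D))

    # Attention masks are additive and broadcastable with
    # (batch, num_heads, # queries, # keys); None disables masking here.
    out = model(src, tgt, src_mask=None, tgt_mask=None, memory_mask=None)
    print(out.shape)  # (B, L_tgt, D)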
docs: nn layers autosummary (.rst)

@@ -9,7 +9,7 @@ Layers
    :toctree: _autosummary
    :template: nn-module-template.rst

-   Embedding
+   Sequential
    ReLU
    PReLU
    GELU
@@ -17,17 +17,19 @@ Layers
    Step
    SELU
    Mish
+   Embedding
    Linear
+   QuantizedLinear
    Conv1d
    Conv2d
    BatchNorm
    LayerNorm
    RMSNorm
    GroupNorm
-   RoPE
-   MultiHeadAttention
-   Sequential
-   QuantizedLinear
    Dropout
    Dropout2d
+   Transformer
+   MultiHeadAttention
+   ALiBi
+   RoPE
+   SinusoidalPositionalEncoding
mlx/nn/layers/__init__.py

@@ -14,6 +14,7 @@ from mlx.nn.layers.activations import (
     SiLU,
     Softplus,
     Step,
+    Tanh,
     celu,
     elu,
     gelu,
@@ -29,6 +30,7 @@ from mlx.nn.layers.activations import (
     silu,
     softplus,
     step,
+    tanh,
 )
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
@@ -41,6 +43,7 @@ from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
+    Transformer,
     TransformerEncoder,
     TransformerEncoderLayer,
 )
mlx/nn/layers/activations.py

@@ -179,12 +179,12 @@ def selu(x):


 def prelu(x: mx.array, alpha: mx.array) -> mx.array:
-    r"""Applies the element-wise function:
+    r"""Applies the element-wise parametric ReLU.

     .. math::
         \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

-    Here :math:`a` is an array.
+    where :math:`a` is an array.
     """
     return mx.maximum(0, x) + alpha * mx.minimum(0, x)

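Note (illustrative, not part of the commit): the function documented above scales negative inputs by alpha and leaves positive inputs unchanged. A minimal check, assuming prelu is exported from mlx.nn:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-2.0, -0.5, 0.0, 1.5])
    alpha = mx.array([0.25])  # slope applied to the negative part

    # max(0, x) + 0.25 * min(0, x) -> expected [-0.5, -0.125, 0.0, 1.5]
    print(nn.prelu(x, alpha))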
mlx/nn/layers/positional_encoding.py

@@ -8,21 +8,22 @@ from mlx.nn.layers.base import Module


 class RoPE(Module):
-    """Implements the rotary positional encoding [1].
+    """Implements the rotary positional encoding.

     The traditional implementation rotates consecutive pairs of elements in the
     feature dimension while the default implementation rotates pairs with
     stride half the feature dimensions for efficiency.

-    [1]: https://arxiv.org/abs/2104.09864
+    For more details see `RoFormer: Enhanced Transformer with Rotary Position
+    Embedding <https://arxiv.org/abs/2104.09864>`_.

     Args:
         dims (int): The feature dimensions to be rotated. If the input feature
             is larger than dims then the rest is left unchanged.
         traditional (bool, optional): If set to True choose the traditional
-            implementation which is slightly less efficient. Default: ``False``
+            implementation which is slightly less efficient. Default: ``False``.
         base (float, optional): The base used to compute angular frequency for
-            each dimension in the positional encodings. Default: ``10000``
+            each dimension in the positional encodings. Default: ``10000``.
     """

     def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
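Note (illustrative, not part of the commit): a small usage sketch for the layer whose docstring changes above. It assumes RoPE is applied as rope(x) with the last two axes being (sequence, feature); the shapes are made up.

    import mlx.core as mx
    import mlx.nn as nn

    rope = nn.RoPE(dims=32)            # default: the efficient (non-traditional) rotation
    x = mx.random.normal((1, 10, 32))  # assumed (batch, sequence, feature) layout
    y = rope(x)                        # same shape; positions encoded by rotating feature pairs
    print(y.shape)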
@@ -89,19 +90,23 @@ class RoPE(Module):


 class SinusoidalPositionalEncoding(Module):
-    """Implements sinusoidal positional encoding similar to [1].
+    r"""Implements sinusoidal positional encoding.

-    [1]: https://arxiv.org/abs/1706.03762
+    For more details see the paper `Attention Is All You Need
+    <https://arxiv.org/abs/1706.03762>`_.

     Args:
         dims (int): The dimensionality of the resulting positional embeddings.
-        min_freq (float): The minimum frequency expected (default: 0.0001)
-        max_freq (float): The maximum frequency expected (default: 1)
-        scale (float): Scale the embeddings by that number (default: sqrt(dims//2))
-        cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
-            instead of the other way around (default: False)
-        full_turns (bool): If set to True multiply the frequencies
-            with ``2 pi`` (default: False)
+        min_freq (float, optional): The minimum frequency expected. Default:
+            ``0.0001``.
+        max_freq (float, optional): The maximum frequency expected. Default:
+            ``1``.
+        scale (float, optional): A multiplicative scale for the embeddings.
+            Default: ``sqrt(dims//2)``.
+        cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``
+            instead of the reverse. Default: ``False``.
+        full_turns (bool, optional): If ``True`` multiply the frequencies with
+            :math:`2\pi`. Default: ``False``.
     """

     def __init__(
mlx/nn/layers/transformer.py

@@ -1,10 +1,12 @@
 # Copyright © 2023 Apple Inc.

 import math
-from typing import Any, Optional
+from typing import Any, Callable, Optional

 import mlx.core as mx
+from mlx.nn.layers.activations import relu
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm

@@ -12,26 +14,34 @@ from mlx.nn.layers.normalization import LayerNorm
 class MultiHeadAttention(Module):
     """Implements the scaled dot product attention with multiple heads.

-    Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
-    new values by aggregating information from the input values according to
-    the similarities of the input queries and keys.
+    Given inputs for queries, keys and values the ``MultiHeadAttention``
+    produces new values by aggregating information from the input values
+    according to the similarities of the input queries and keys.

-    All inputs as well as the output are linearly projected without biases.
+    All inputs as well as the output are linearly projected without biases by
+    default.

-    MultiHeadAttention also expects an additive attention mask that should be
-    broadcastable with (batch, num_heads, # queries, # keys). The mask should
-    have ``-inf`` or very negative numbers to the positions that should *not* be
-    attended to.
+    ``MultiHeadAttention`` also takes an optional additive attention mask that
+    should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The
+    mask should have ``-inf`` or very large negative numbers at the positions
+    that should *not* be attended to.

     Args:
-        dims (int): The model dimensions. If no other dims are provided then
-            dims is used for queries, keys, values and the output.
-        num_heads (int): How many attention heads to use
-        query_input_dims (int, optional): The input dimensions of the queries (default: dims).
-        key_input_dims (int, optional): The input dimensions of the keys (default: dims).
-        value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
-        value_dims (int, optional): The dimensions of the values after the projection (default: dims).
-        value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
+        dims (int): The model dimensions. This is also the default
+            value for the queries, keys, values, and the output.
+        num_heads (int): The number of attention heads to use.
+        query_input_dims (int, optional): The input dimensions of the queries.
+            Default: ``dims``.
+        key_input_dims (int, optional): The input dimensions of the keys.
+            Default: ``dims``.
+        value_input_dims (int, optional): The input dimensions of the values.
+            Default: ``key_input_dims``.
+        value_dims (int, optional): The dimensions of the values after the
+            projection. Default: ``dims``.
+        value_output_dims (int, optional): The dimensions the new values will
+            be projected to. Default: ``dims``.
+        bias (bool, optional): Whether or not to use a bias in the projections.
+            Default: ``False``.
     """

     def __init__(
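Note (illustrative, not part of the commit): the mask convention described in the updated docstring can be exercised with a small self-attention sketch. The causal-mask construction below is one way to build an additive mask with mlx.core; the shapes are made up.

    import mlx.core as mx
    import mlx.nn as nn

    B, L, D, H = 2, 6, 64, 8
    x = mx.random.normal((B, L, D))

    attn = nn.MultiHeadAttention(dims=D, num_heads=H)

    # Additive causal mask: 0 where attention is allowed, a large negative
    # number at future positions so they vanish after the softmax.
    idx = mx.arange(L)
    mask = mx.where(idx.reshape(-1, 1) >= idx.reshape(1, -1), 0.0, -1e9)

    y = attn(x, x, x, mask)  # self-attention; output has shape (B, L, D)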
@@ -49,7 +59,8 @@ class MultiHeadAttention(Module):

         if (dims % num_heads) != 0:
             raise ValueError(
-                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
             )

         query_input_dims = query_input_dims or dims
@@ -97,7 +108,15 @@ class MultiHeadAttention(Module):


 class TransformerEncoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
+    ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
         self.attention = MultiHeadAttention(dims, num_heads)
@@ -105,28 +124,55 @@ class TransformerEncoderLayer(Module):
         self.ln2 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first

     def __call__(self, x, mask):
-        y = self.ln1(x)
-        y = self.attention(y, y, y, mask)
-        x = x + y
-
-        y = self.ln2(x)
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        x = x + y
-
-        return x
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.attention(y, y, y, mask)
+            y = self.dropout1(y)
+            x = x + y
+
+            y = self.ln2(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout2(y)
+            y = self.linear2(y)
+            y = x + y
+
+        else:
+            y = self.attention(x, x, x, mask)
+            y = self.dropout1(y)
+            y = self.ln1(x + y)
+
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout2(y)
+            y = self.linear2(y)
+            y = self.ln2(x + y)
+
+        return y


 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation=relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims)
+            TransformerEncoderLayer(
+                dims, num_heads, mlp_dims, dropout, activation, norm_first
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
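Note (illustrative, not part of the commit): the behavioral switch added above is norm_first. For the attention sub-block the two orderings reduce to the following generic pattern (in the MLP sub-block of this commit the dropout sits between the activation and the second linear layer); the helper names here are made up.

    def pre_norm_block(x, norm, sublayer, dropout):
        # norm_first=True: normalize, transform, apply dropout, then add the residual
        return x + dropout(sublayer(norm(x)))

    def post_norm_block(x, norm, sublayer, dropout):
        # norm_first=False: transform, apply dropout, add the residual, then normalize
        return norm(x + dropout(sublayer(x)))

Pre-norm keeps the residual path free of normalization, which is commonly reported to train more stably for deep stacks; post-norm matches the original Transformer formulation.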
@@ -134,13 +180,19 @@ class TransformerEncoder(Module):
     def __call__(self, x, mask):
         for l in self.layers:
             x = l(x, mask)
-        x = self.ln(x)
-
-        return x
+        return self.ln(x)


 class TransformerDecoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
+    ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
         self.self_attention = MultiHeadAttention(dims, num_heads)
@@ -150,32 +202,65 @@ class TransformerDecoderLayer(Module):
         self.ln3 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.dropout3 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first

     def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.ln1(x)
-        y = self.self_attention(y, y, y, x_mask)
-        x = x + y
-
-        y = self.ln2(x)
-        y = self.cross_attention(y, memory, memory, memory_mask)
-        x = x + y
-
-        y = self.ln3(x)
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        x = x + y
-
-        return x
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.self_attention(y, y, y, x_mask)
+            y = self.dropout1(y)
+            x = x + y
+
+            y = self.ln2(x)
+            y = self.cross_attention(y, memory, memory, memory_mask)
+            y = self.dropout2(y)
+            x = x + y
+
+            y = self.ln3(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout3(y)
+            y = self.linear2(y)
+            y = x + y
+
+        else:
+            y = self.self_attention(x, x, x, x_mask)
+            y = self.dropout1(y)
+            x = self.ln1(x + y)
+
+            y = self.cross_attention(y, memory, memory, memory_mask)
+            y = self.dropout2(y)
+            x = self.ln1(x + y)
+
+            y = self.linear1(x)
+            y = self.activation(y)
+            y = self.dropout3(y)
+            y = self.linear2(y)
+            y = self.ln3(x + y)
+
+        return y


 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation=relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims)
+            TransformerDecoderLayer(
+                dims, num_heads, mlp_dims, dropout, activation, norm_first
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -183,12 +268,47 @@ class TransformerDecoder(Module):
     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
             x = l(x, memory, x_mask, memory_mask)
-        x = self.ln(x)
-
-        return x
+        return self.ln(x)


 class Transformer(Module):
+    """
+    Implements a standard Transformer model.
+
+    The implementation is based on `Attention Is All You Need
+    <https://arxiv.org/abs/1706.03762>`_.
+
+    The Transformer model contains an encoder and a decoder. The encoder
+    processes the input sequence and the decoder generates the output sequence.
+    The interaction between encoder and decoder happens through the attention
+    mechanism.
+
+    Args:
+        dims (int, optional): The number of expected features in the
+            encoder/decoder inputs. Default: ``512``.
+        num_heads (int, optional): The number of attention heads. Default:
+            ``8``.
+        num_encoder_layers (int, optional): The number of encoder layers in the
+            Transformer encoder. Default: ``6``.
+        num_decoder_layers (int, optional): The number of decoder layers in the
+            Transformer decoder. Default: ``6``.
+        mlp_dims (int, optional): The hidden dimension of the MLP block in each
+            Transformer layer. Defaults to ``4*dims`` if not provided. Default:
+            ``None``.
+        dropout (float, optional): The dropout value for the Transformer
+            encoder and decoder. Dropout is used after each attention layer and
+            the activation in the MLP layer. Default: ``0.0``.
+        activation (function, optional): the activation function for the MLP
+            hidden layer. Default: :func:`mlx.nn.relu`.
+        custom_encoder (nn.Module, optional): A custom encoder to replace the
+            standard Transformer encoder. Default: ``None``.
+        custom_decoder (nn.Module, optional): A custom decoder to replace the
+            standard Transformer decoder. Default: ``None``.
+        norm_first (bool, optional): if ``True``, encoder and decoder layers
+            will perform layer normalization before attention and MLP
+            operations, otherwise after. Default: ``False``.
+    """
+
     def __init__(
         self,
         dims: int = 512,
@@ -196,26 +316,39 @@ class Transformer(Module):
         num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
         mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
+        norm_first: bool = False,
     ):
         super().__init__()
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
             self.encoder = TransformerEncoder(
-                num_encoder_layers, dims, num_heads, mlp_dims
+                num_encoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )

         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
             self.decoder = TransformerDecoder(
-                num_decoder_layers, dims, num_heads, mlp_dims
+                num_decoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )

     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         memory = self.encoder(src, src_mask)
-        output = self.decoder(tgt, memory, tgt_mask, memory_mask)
-
-        return output
+        return self.decoder(tgt, memory, tgt_mask, memory_mask)