mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00

add docstring, activation, norm_first

This commit is contained in:
    parent df1f8aa3be
    commit 297e69017c
@@ -1,9 +1,10 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Any, Optional
+from typing import Any, Optional, Callable
 
 import mlx.core as mx
+from mlx.nn.layers.activations import relu
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
@@ -98,34 +99,14 @@ class MultiHeadAttention(Module):
 
 
 class TransformerEncoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
-        super().__init__()
-        mlp_dims = mlp_dims or dims * 4
-        self.attention = MultiHeadAttention(dims, num_heads)
-        self.ln1 = LayerNorm(dims)
-        self.ln2 = LayerNorm(dims)
-        self.linear1 = Linear(dims, mlp_dims)
-        self.linear2 = Linear(mlp_dims, dims)
-
-    def __call__(self, x, mask):
-        y = self.attention(x, x, x, mask)
-        y = self.ln1(x + y)
-
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-
-        y = self.ln2(x + y)
-        return y
-
-
-class TransformerEncoderLayerWithDropout(Module):
     def __init__(
         self,
         dims: int,
         num_heads: int,
         mlp_dims: Optional[int] = None,
-        dropout_rate: float = 0.1,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -134,49 +115,46 @@ class TransformerEncoderLayerWithDropout(Module):
         self.ln2 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
-        self.dropout1 = Dropout(dropout_rate)
-        self.dropout2 = Dropout(dropout_rate)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first
 
     def __call__(self, x, mask):
-        y = self.attention(x, x, x, mask)
-        y = self.dropout1(y)
-        y = self.ln1(x + y)
-
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        y = self.dropout2(y)
-        y = self.ln2(x + y)
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.attention(y, y, y, mask)
+            y = self.dropout1(y)
+            x = x + y
+
+            y = self.ln2(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout2(y)
+            y = self.linear2(y)
+            y = x + y
+
+        else:
+            y = self.attention(x, x, x, mask)
+            y = self.dropout1(y)
+            y = self.ln1(x + y)
+
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout2(y)
+            y = self.linear2(y)
+            y = self.ln2(x + y)
 
         return y
 
 
 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims)
-            for i in range(num_layers)
-        ]
-        self.ln = LayerNorm(dims)
-
-    def __call__(self, x, mask):
-        for l in self.layers:
-            x = l(x, mask)
-        x = self.ln(x)
-
-        return x
-
-
-class TransformerEncoderWithDropout(Module):
-    def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
-    ):
-        super().__init__()
-        self.layers = [
-            TransformerEncoderLayerWithDropout(dims, num_heads, mlp_dims)
+            TransformerEncoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
             for i in range(num_layers)
         ]
        self.ln = LayerNorm(dims)
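A minimal usage sketch of the merged encoder layer (not part of the diff). It assumes the module path mlx.nn.layers.transformer implied by the imports above, that MultiHeadAttention tolerates a mask of None, and purely illustrative shapes. Because the layer constructor declares the new parameters as (dropout, activation, norm_first) while the hunk above shows TransformerEncoder forwarding them positionally as (dropout, norm_first, activation), keyword arguments are used here.

    import mlx.core as mx
    from mlx.nn.layers.activations import relu
    from mlx.nn.layers.transformer import TransformerEncoderLayer  # path assumed

    x = mx.random.normal((2, 16, 64))  # (batch, sequence, dims), illustrative only

    # Post-norm (default): attention/MLP run first, LayerNorm is applied to the residual sum.
    post_norm = TransformerEncoderLayer(dims=64, num_heads=4, dropout=0.1, activation=relu)
    y_post = post_norm(x, mask=None)  # mask=None assumed acceptable for MultiHeadAttention

    # Pre-norm: LayerNorm runs first, then attention/MLP, and the residual is added afterwards.
    pre_norm = TransformerEncoderLayer(dims=64, num_heads=4, dropout=0.1, norm_first=True)
    y_pre = pre_norm(x, mask=None)

    print(y_post.shape, y_pre.shape)  # both (2, 16, 64)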
@@ -190,39 +168,14 @@ class TransformerEncoderWithDropout(Module):
 
 
 class TransformerDecoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
-        super().__init__()
-        mlp_dims = mlp_dims or dims * 4
-        self.self_attention = MultiHeadAttention(dims, num_heads)
-        self.cross_attention = MultiHeadAttention(dims, num_heads)
-        self.ln1 = LayerNorm(dims)
-        self.ln2 = LayerNorm(dims)
-        self.ln3 = LayerNorm(dims)
-        self.linear1 = Linear(dims, mlp_dims)
-        self.linear2 = Linear(mlp_dims, dims)
-
-    def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.self_attention(x, x, x, x_mask)
-        x = self.ln1(x + y)
-
-        y = self.cross_attention(y, memory, memory, memory_mask)
-        x = self.ln1(x + y)
-
-        y = self.linear1(x)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        y = self.ln3(x + y)
-
-        return y
-
-
-class TransformerDecoderLayerWithDropout(Module):
     def __init__(
         self,
         dims: int,
         num_heads: int,
         mlp_dims: Optional[int] = None,
-        dropout_rate: float = 0.1,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -233,11 +186,32 @@ class TransformerDecoderLayerWithDropout(Module):
         self.ln3 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
-        self.dropout1 = Dropout(dropout_rate)
-        self.dropout2 = Dropout(dropout_rate)
-        self.dropout3 = Dropout(dropout_rate)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.dropout3 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first
 
     def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.self_attention(x, x, x, x_mask)
-        y = self.dropout1(y)
-        x = self.ln1(x + y)
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.self_attention(y, y, y, x_mask)
+            y = self.dropout1(y)
+            x = x + y
+
+            y = self.ln2(x)
+            y = self.cross_attention(y, memory, memory, memory_mask)
+            y = self.dropout2(y)
+            x = x + y
+
+            y = self.ln3(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout3(y)
+            y = self.linear2(y)
+            x = x + y
+
+        else:
+            y = self.self_attention(x, x, x, x_mask)
+            y = self.dropout1(y)
+            x = self.ln1(x + y)
@@ -247,9 +221,9 @@ class TransformerDecoderLayerWithDropout(Module):
             x = self.ln1(x + y)
 
             y = self.linear1(x)
-            y = mx.maximum(y, 0)
-            y = self.linear2(y)
+            y = self.activation(y)
             y = self.dropout3(y)
+            y = self.linear2(y)
             y = self.ln3(x + y)
 
         return y
@@ -257,30 +231,11 @@ class TransformerDecoderLayerWithDropout(Module):
 
 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims)
-            for i in range(num_layers)
-        ]
-        self.ln = LayerNorm(dims)
-
-    def __call__(self, x, memory, x_mask, memory_mask):
-        for l in self.layers:
-            x = l(x, memory, x_mask, memory_mask)
-        x = self.ln(x)
-
-        return x
-
-
-class TransformerDecoderWithDropout(Module):
-    def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
-    ):
-        super().__init__()
-        self.layers = [
-            TransformerDecoderLayerWithDropout(dims, num_heads, mlp_dims)
+            TransformerDecoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
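A corresponding sketch for the decoder side (again not part of the diff, with the same assumed import path and illustrative shapes as the encoder sketch above). As with the encoder, TransformerDecoder passes the new options positionally as (dropout, norm_first, activation) while TransformerDecoderLayer declares them as (dropout, activation, norm_first), so the sketch sticks to keyword arguments.

    import mlx.core as mx
    from mlx.nn.layers.transformer import TransformerDecoderLayer  # path assumed

    memory = mx.random.normal((2, 16, 64))  # encoder output: (batch, src_len, dims)
    tgt = mx.random.normal((2, 8, 64))      # decoder input:  (batch, tgt_len, dims)

    # Pre-norm decoder layer: self-attention, cross-attention over `memory`, then the MLP,
    # each preceded by LayerNorm and followed by dropout and a residual connection.
    layer = TransformerDecoderLayer(dims=64, num_heads=4, dropout=0.1, norm_first=True)
    out = layer(tgt, memory, x_mask=None, memory_mask=None)  # None masks assumed acceptable
    print(out.shape)  # (2, 8, 64)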
@@ -294,6 +249,27 @@ class TransformerDecoderWithDropout(Module):
 
 
 class Transformer(Module):
+    """
+    Implements a standard Transformer model based on the paper "Attention Is All You Need".
+
+    The Transformer model consists of an encoder and a decoder. The encoder processes
+    the input sequence and the decoder generates the output sequence. The interaction
+    between encoder and decoder happens through the attention mechanism.
+
+    Args:
+        dims (int): The number of expected features in the encoder/decoder inputs.
+        num_heads (int): The number of heads in the multi-head attention models.
+        num_encoder_layers (int): The number of sub-encoder-layers in the Transformer encoder.
+        num_decoder_layers (int): The number of sub-decoder-layers in the Transformer decoder.
+        mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each
+            Transformer layer. Defaults to 4*dims if not provided.
+        dropout (float): The dropout value for Transformer encoder/decoder.
+        activation (Callable[[Any], Any]): The activation function of the encoder/decoder
+            intermediate layer.
+        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder.
+        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder.
+        norm_first (bool): If ``True``, encoder and decoder layers will perform LayerNorms before
+            other attention and feedforward operations, otherwise after. Default is ``False``.
+    """
     def __init__(
         self,
         dims: int = 512,
@@ -301,22 +277,25 @@ class Transformer(Module):
         num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
         mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
+        norm_first: bool = False
     ):
         super().__init__()
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
             self.encoder = TransformerEncoder(
-                num_encoder_layers, dims, num_heads, mlp_dims
+                num_encoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
 
         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
             self.decoder = TransformerDecoder(
-                num_decoder_layers, dims, num_heads, mlp_dims
+                num_decoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
 
     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
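A short end-to-end sketch of the Transformer class as described by the new docstring (not part of the diff; the import path, shapes, and the expected output shape are assumptions, and None masks are assumed to be accepted by the attention layers).

    import mlx.core as mx
    from mlx.nn.layers.transformer import Transformer  # path assumed

    src = mx.random.normal((2, 16, 512))  # (batch, src_len, dims), dims matching the default 512
    tgt = mx.random.normal((2, 10, 512))  # (batch, tgt_len, dims)

    # Pre-norm Transformer with dropout, mirroring the Args documented in the new docstring.
    model = Transformer(
        dims=512,
        num_heads=8,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dropout=0.1,
        norm_first=True,
    )

    # Masks are forwarded to the attention layers; None is assumed to be accepted since
    # MultiHeadAttention only adds the mask to the attention scores when one is given.
    out = model(src, tgt, src_mask=None, tgt_mask=None, memory_mask=None)
    print(out.shape)  # expected (2, 10, 512): one feature vector per target position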