add docstring, activation, norm_first

junwoo-yun 2023-12-25 07:39:42 +08:00
parent df1f8aa3be
commit 297e69017c


@@ -1,9 +1,10 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Any, Optional
+from typing import Any, Optional, Callable
 
 import mlx.core as mx
+from mlx.nn.layers.activations import relu
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
@@ -98,34 +99,14 @@ class MultiHeadAttention(Module):
 class TransformerEncoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
-        super().__init__()
-        mlp_dims = mlp_dims or dims * 4
-        self.attention = MultiHeadAttention(dims, num_heads)
-        self.ln1 = LayerNorm(dims)
-        self.ln2 = LayerNorm(dims)
-        self.linear1 = Linear(dims, mlp_dims)
-        self.linear2 = Linear(mlp_dims, dims)
-
-    def __call__(self, x, mask):
-        y = self.attention(x, x, x, mask)
-        y = self.ln1(x + y)
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        y = self.ln2(x + y)
-        return y
-
-
-class TransformerEncoderLayerWithDropout(Module):
     def __init__(
         self,
         dims: int,
         num_heads: int,
         mlp_dims: Optional[int] = None,
-        dropout_rate: float = 0.1,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -134,49 +115,46 @@ class TransformerEncoderLayerWithDropout(Module):
         self.ln2 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
-        self.dropout1 = Dropout(dropout_rate)
-        self.dropout2 = Dropout(dropout_rate)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first
 
     def __call__(self, x, mask):
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.attention(y, y, y, mask)
+            y = self.dropout1(y)
+            x = x + y
+
+            y = self.ln2(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout2(y)
+            y = self.linear2(y)
+            y = x + y
+
+        else:
             y = self.attention(x, x, x, mask)
             y = self.dropout1(y)
             y = self.ln1(x + y)
             y = self.linear1(y)
-        y = mx.maximum(y, 0)
+            y = self.activation(y)
-        y = self.linear2(y)
             y = self.dropout2(y)
+            y = self.linear2(y)
             y = self.ln2(x + y)
         return y
 
 
 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims)
+            TransformerEncoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
-            for i in range(num_layers)
-        ]
-        self.ln = LayerNorm(dims)
-
-    def __call__(self, x, mask):
-        for l in self.layers:
-            x = l(x, mask)
-        x = self.ln(x)
-        return x
-
-
-class TransformerEncoderWithDropout(Module):
-    def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
-    ):
-        super().__init__()
-        self.layers = [
-            TransformerEncoderLayerWithDropout(dims, num_heads, mlp_dims)
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
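For reference, a minimal usage sketch of the merged TransformerEncoderLayer with the options this commit introduces (dropout, activation, norm_first). The import path, toy shapes, and the custom activation below are illustrative assumptions, not part of the patch:

import mlx.core as mx
from mlx.nn.layers.transformer import TransformerEncoderLayer  # assumed module path


def squared_relu(x):
    # any callable works as the activation; this one is only for illustration
    return mx.square(mx.maximum(x, 0.0))


# keyword arguments follow the layer signature shown in the diff above
layer = TransformerEncoderLayer(
    dims=64,
    num_heads=4,
    dropout=0.1,
    activation=squared_relu,
    norm_first=True,
)

x = mx.random.normal((2, 16, 64))  # (batch, sequence, dims)
y = layer(x, None)                 # no attention mask
print(y.shape)                     # expected: (2, 16, 64)

With norm_first=True the LayerNorms run before the attention and MLP blocks (the pre-norm ordering commonly preferred for training stability), while norm_first=False keeps the original post-norm behaviour.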
@@ -190,39 +168,14 @@ class TransformerEncoderWithDropout(Module):
 class TransformerDecoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
-        super().__init__()
-        mlp_dims = mlp_dims or dims * 4
-        self.self_attention = MultiHeadAttention(dims, num_heads)
-        self.cross_attention = MultiHeadAttention(dims, num_heads)
-        self.ln1 = LayerNorm(dims)
-        self.ln2 = LayerNorm(dims)
-        self.ln3 = LayerNorm(dims)
-        self.linear1 = Linear(dims, mlp_dims)
-        self.linear2 = Linear(mlp_dims, dims)
-
-    def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.self_attention(x, x, x, x_mask)
-        x = self.ln1(x + y)
-        y = self.cross_attention(y, memory, memory, memory_mask)
-        x = self.ln1(x + y)
-        y = self.linear1(x)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        y = self.ln3(x + y)
-        return y
-
-
-class TransformerDecoderLayerWithDropout(Module):
     def __init__(
         self,
         dims: int,
         num_heads: int,
         mlp_dims: Optional[int] = None,
-        dropout_rate: float = 0.1,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -233,11 +186,32 @@ class TransformerDecoderLayerWithDropout(Module):
         self.ln3 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
-        self.dropout1 = Dropout(dropout_rate)
-        self.dropout2 = Dropout(dropout_rate)
-        self.dropout3 = Dropout(dropout_rate)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.dropout3 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first
 
     def __call__(self, x, memory, x_mask, memory_mask):
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.self_attention(y, y, y, x_mask)
+            y = self.dropout1(y)
+            x = x + y
+
+            y = self.ln2(x)
+            y = self.cross_attention(y, memory, memory, memory_mask)
+            y = self.dropout2(y)
+            x = x + y
+
+            y = self.ln3(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout3(y)
+            y = self.linear2(y)
+            x = x + y
+
+        else:
             y = self.self_attention(x, x, x, x_mask)
             y = self.dropout1(y)
             x = self.ln1(x + y)
@@ -247,9 +221,9 @@ class TransformerDecoderLayerWithDropout(Module):
             x = self.ln1(x + y)
             y = self.linear1(x)
-        y = mx.maximum(y, 0)
+            y = self.activation(y)
-        y = self.linear2(y)
             y = self.dropout3(y)
+            y = self.linear2(y)
             y = self.ln3(x + y)
         return y
@@ -257,30 +231,11 @@ class TransformerDecoderLayerWithDropout(Module):
 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims)
+            TransformerDecoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
-            for i in range(num_layers)
-        ]
-        self.ln = LayerNorm(dims)
-
-    def __call__(self, x, memory, x_mask, memory_mask):
-        for l in self.layers:
-            x = l(x, memory, x_mask, memory_mask)
-        x = self.ln(x)
-        return x
-
-
-class TransformerDecoderWithDropout(Module):
-    def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
-    ):
-        super().__init__()
-        self.layers = [
-            TransformerDecoderLayerWithDropout(dims, num_heads, mlp_dims)
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -294,6 +249,27 @@ class TransformerDecoderWithDropout(Module):
 class Transformer(Module):
+    """
+    Implements a standard Transformer model based on the paper "Attention Is All You Need".
+
+    The Transformer model consists of an encoder and a decoder. The encoder processes
+    the input sequence and the decoder generates the output sequence. The interaction
+    between encoder and decoder happens through the attention mechanism.
+
+    Args:
+        dims (int): The number of expected features in the encoder/decoder inputs.
+        num_heads (int): The number of heads in the multi-head attention models.
+        num_encoder_layers (int): The number of sub-encoder-layers in the Transformer encoder.
+        num_decoder_layers (int): The number of sub-decoder-layers in the Transformer decoder.
+        mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each Transformer layer.
+            Defaults to 4*dims if not provided.
+        dropout (float): The dropout value for the Transformer encoder/decoder.
+        activation (Callable[[Any], Any]): The activation function of the encoder/decoder intermediate layer.
+        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder.
+        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder.
+        norm_first (bool): If ``True``, encoder and decoder layers will perform LayerNorms before
+            other attention and feedforward operations, otherwise after. Default is ``False``.
+    """
+
     def __init__(
         self,
         dims: int = 512,
@@ -301,22 +277,25 @@ class Transformer(Module):
         num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
         mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
+        norm_first: bool = False
     ):
         super().__init__()
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
             self.encoder = TransformerEncoder(
-                num_encoder_layers, dims, num_heads, mlp_dims
+                num_encoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
 
         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
             self.decoder = TransformerDecoder(
-                num_decoder_layers, dims, num_heads, mlp_dims
+                num_decoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
 
     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
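Finally, a hedged end-to-end sketch of driving the Transformer class with the new arguments. The import path, the additive causal-mask helper, and the toy shapes are illustrative assumptions rather than part of the patch; the docstring above describes the parameters, and the __call__ signature suggests the model encodes src and then decodes tgt against that memory:

import mlx.core as mx
from mlx.nn.layers.transformer import Transformer  # assumed module path


def additive_causal_mask(T, dtype=mx.float32):
    # 0 where attention is allowed, a large negative value elsewhere
    idx = mx.arange(T)
    return (idx[:, None] < idx[None]).astype(dtype) * -1e9


# toy configuration exercising the new arguments; values are illustrative
model = Transformer(
    dims=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    mlp_dims=256,
    dropout=0.1,
    norm_first=True,
)

src = mx.random.normal((2, 10, 64))  # (batch, source length, dims)
tgt = mx.random.normal((2, 7, 64))   # (batch, target length, dims)

out = model(src, tgt, None, additive_causal_mask(7), None)  # src_mask, tgt_mask, memory_mask
print(out.shape)  # expected: (2, 7, 64)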