From 297e69017cf37e9cb5de4b93e4b3f0d7adf1b63c Mon Sep 17 00:00:00 2001
From: junwoo-yun
Date: Mon, 25 Dec 2023 07:39:42 +0800
Subject: [PATCH] add docstring, activation, norm_first

---
 python/mlx/nn/layers/transformer.py | 219 +++++++++++++----------------
 1 file changed, 99 insertions(+), 120 deletions(-)

diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 9b70221ff..bbdacbfa6 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -1,9 +1,10 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
+from mlx.nn.layers.activations import relu
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
@@ -98,34 +99,14 @@ class MultiHeadAttention(Module):
 
 
 class TransformerEncoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
-        super().__init__()
-        mlp_dims = mlp_dims or dims * 4
-        self.attention = MultiHeadAttention(dims, num_heads)
-        self.ln1 = LayerNorm(dims)
-        self.ln2 = LayerNorm(dims)
-        self.linear1 = Linear(dims, mlp_dims)
-        self.linear2 = Linear(mlp_dims, dims)
-
-    def __call__(self, x, mask):
-        y = self.attention(x, x, x, mask)
-        y = self.ln1(x + y)
-
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-
-        y = self.ln2(x + y)
-        return y
-
-
-class TransformerEncoderLayerWithDropout(Module):
     def __init__(
         self,
         dims: int,
         num_heads: int,
         mlp_dims: Optional[int] = None,
-        dropout_rate: float = 0.1,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -134,49 +115,46 @@ class TransformerEncoderLayerWithDropout(Module):
         self.ln2 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
-        self.dropout1 = Dropout(dropout_rate)
-        self.dropout2 = Dropout(dropout_rate)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first
 
     def __call__(self, x, mask):
-        y = self.attention(x, x, x, mask)
-        y = self.dropout1(y)
-        y = self.ln1(x + y)
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.attention(y, y, y, mask)
+            y = self.dropout1(y)
+            x = x + y
 
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        y = self.dropout2(y)
+            y = self.ln2(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout2(y)
+            y = self.linear2(y)
+            y = x + y
+
+        else:
+            y = self.attention(x, x, x, mask)
+            y = self.dropout1(y)
+            y = self.ln1(x + y)
 
-        y = self.ln2(x + y)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout2(y)
+            y = self.linear2(y)
+            y = self.ln2(x + y)
+
         return y
 
 
 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, activation: Callable[[Any], Any] = relu, norm_first: bool = False
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims)
-            for i in range(num_layers)
-        ]
-        self.ln = LayerNorm(dims)
-
-    def __call__(self, x, mask):
-        for l in self.layers:
-            x = l(x, mask)
-        x = self.ln(x)
-
-        return x
-
-
-class TransformerEncoderWithDropout(Module):
-    def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
-    ):
-        super().__init__()
-        self.layers = [
-            TransformerEncoderLayerWithDropout(dims, num_heads, mlp_dims)
+            TransformerEncoderLayer(dims, num_heads, mlp_dims, dropout, activation, norm_first)
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -190,39 +168,14 @@ class TransformerEncoderWithDropout(Module):
 
 
 class TransformerDecoderLayer(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
-        super().__init__()
-        mlp_dims = mlp_dims or dims * 4
-        self.self_attention = MultiHeadAttention(dims, num_heads)
-        self.cross_attention = MultiHeadAttention(dims, num_heads)
-        self.ln1 = LayerNorm(dims)
-        self.ln2 = LayerNorm(dims)
-        self.ln3 = LayerNorm(dims)
-        self.linear1 = Linear(dims, mlp_dims)
-        self.linear2 = Linear(mlp_dims, dims)
-
-    def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.self_attention(x, x, x, x_mask)
-        x = self.ln1(x + y)
-
-        y = self.cross_attention(y, memory, memory, memory_mask)
-        x = self.ln1(x + y)
-
-        y = self.linear1(x)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        y = self.ln3(x + y)
-
-        return y
-
-
-class TransformerDecoderLayerWithDropout(Module):
     def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
-        dropout_rate: float = 0.1,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -233,54 +186,56 @@ class TransformerDecoderLayerWithDropout(Module):
         self.ln3 = LayerNorm(dims)
         self.linear1 = Linear(dims, mlp_dims)
         self.linear2 = Linear(mlp_dims, dims)
-        self.dropout1 = Dropout(dropout_rate)
-        self.dropout2 = Dropout(dropout_rate)
-        self.dropout3 = Dropout(dropout_rate)
+        self.dropout1 = Dropout(dropout)
+        self.dropout2 = Dropout(dropout)
+        self.dropout3 = Dropout(dropout)
+        self.activation = activation
+        self.norm_first = norm_first
 
     def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.self_attention(x, x, x, x_mask)
-        y = self.dropout1(y)
-        x = self.ln1(x + y)
+        if self.norm_first:
+            y = self.ln1(x)
+            y = self.self_attention(y, y, y, x_mask)
+            y = self.dropout1(y)
+            x = x + y
 
-        y = self.cross_attention(y, memory, memory, memory_mask)
-        y = self.dropout2(y)
-        x = self.ln1(x + y)
+            y = self.ln2(x)
+            y = self.cross_attention(y, memory, memory, memory_mask)
+            y = self.dropout2(y)
+            x = x + y
 
-        y = self.linear1(x)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        y = self.dropout3(y)
-        y = self.ln3(x + y)
+            y = self.ln3(x)
+            y = self.linear1(y)
+            y = self.activation(y)
+            y = self.dropout3(y)
+            y = self.linear2(y)
+            y = x + y
+
+        else:
+            y = self.self_attention(x, x, x, x_mask)
+            y = self.dropout1(y)
+            x = self.ln1(x + y)
+
+            y = self.cross_attention(x, memory, memory, memory_mask)
+            y = self.dropout2(y)
+            x = self.ln2(x + y)
+
+            y = self.linear1(x)
+            y = self.activation(y)
+            y = self.dropout3(y)
+            y = self.linear2(y)
+            y = self.ln3(x + y)
 
         return y
 
 
 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, activation: Callable[[Any], Any] = relu, norm_first: bool = False
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims)
-            for i in range(num_layers)
-        ]
-        self.ln = LayerNorm(dims)
-
-    def __call__(self, x, memory, x_mask, memory_mask):
-        for l in self.layers:
-            x = l(x, memory, x_mask, memory_mask)
-        x = self.ln(x)
-
-        return x
-
-
-class TransformerDecoderWithDropout(Module):
-    def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
-    ):
-        super().__init__()
-        self.layers = [
-            TransformerDecoderLayerWithDropout(dims, num_heads, mlp_dims)
+            TransformerDecoderLayer(dims, num_heads, mlp_dims, dropout, activation, norm_first)
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -294,6 +249,27 @@ class TransformerDecoderWithDropout(Module):
 
 
 class Transformer(Module):
+    """
+    Implements a standard Transformer model based on the paper "Attention Is All You Need".
+
+    The Transformer model consists of an encoder and a decoder. The encoder processes
+    the input sequence and the decoder generates the output sequence. The interaction
+    between the encoder and decoder happens through the attention mechanism.
+
+    Args:
+        dims (int): The number of expected features in the encoder/decoder inputs.
+        num_heads (int): The number of heads in the multi-head attention models.
+        num_encoder_layers (int): The number of layers in the Transformer encoder.
+        num_decoder_layers (int): The number of layers in the Transformer decoder.
+        mlp_dims (Optional[int]): The dimensionality of the feedforward network in each Transformer layer.
+            Defaults to ``4 * dims`` if not provided.
+        dropout (float): The dropout value for the Transformer encoder and decoder.
+        activation (Callable[[Any], Any]): The activation function for the intermediate layer of each encoder/decoder layer.
+        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder.
+        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder.
+        norm_first (bool): If ``True``, encoder and decoder layers perform layer normalization before
+            attention and feedforward operations, otherwise after. Default is ``False``.
+    """
     def __init__(
         self,
         dims: int = 512,
@@ -301,22 +277,25 @@
         num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
         mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
+        norm_first: bool = False,
     ):
         super().__init__()
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
             self.encoder = TransformerEncoder(
-                num_encoder_layers, dims, num_heads, mlp_dims
+                num_encoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
 
         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
             self.decoder = TransformerDecoder(
-                num_decoder_layers, dims, num_heads, mlp_dims
+                num_decoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
 
     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
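
Usage sketch (for review, not part of the patch): the snippet below shows how the new dropout, activation, and norm_first arguments are intended to be used, assuming the parameter order above and that mlx.nn exports Transformer, gelu, and MultiHeadAttention as in the current package; shapes and mask choices are illustrative only.

    import mlx.core as mx
    import mlx.nn as nn

    # Pre-norm Transformer with GELU and 10% dropout instead of the post-norm/ReLU defaults.
    model = nn.Transformer(
        dims=512,
        num_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dropout=0.1,
        activation=nn.gelu,
        norm_first=True,
    )

    src = mx.random.normal((2, 16, 512))  # (batch, source length, dims)
    tgt = mx.random.normal((2, 20, 512))  # (batch, target length, dims)

    # Causal mask for the decoder self-attention; the other masks are left as None here.
    tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(tgt.shape[1])
    out = model(src, tgt, None, tgt_mask, None)  # (batch, target length, dims)

Passing the activation as a callable keeps the layers agnostic to the choice of nonlinearity, and norm_first=True selects the pre-norm variant, which is typically more stable to train at larger depths.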