diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index bbdacbfa6..f1d9b42de 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Any, Optional, Callable
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.activations import relu
@@ -133,7 +133,7 @@ class TransformerEncoderLayer(Module):
             y = self.dropout2(y)
             y = self.linear2(y)
             y = x + y
-
+
         else:
             y = self.attention(x, x, x, mask)
             y = self.dropout1(y)
@@ -144,17 +144,26 @@ class TransformerEncoderLayer(Module):
             y = self.dropout2(y)
             y = self.linear2(y)
             y = self.ln2(x + y)
-
+
         return y
 
 
 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerEncoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -210,7 +219,7 @@ class TransformerDecoderLayer(Module):
             y = self.dropout3(y)
             y = self.linear2(y)
             x = x + y
-
+
         else:
             y = self.self_attention(x, x, x, x_mask)
             y = self.dropout1(y)
@@ -231,11 +240,20 @@ class TransformerDecoderLayer(Module):
 
 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerDecoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -261,15 +279,16 @@ class Transformer(Module):
         num_heads (int): The number of heads in the multi-head attention models.
         num_encoder_layers (int): The number of sub-encoder-layers in the Transformer encoder.
         num_decoder_layers (int): The number of sub-decoder-layers in the Transformer decoder.
-        mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each Transformer layer.
+        mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each Transformer layer. Defaults to 4*dims if not provided.
         dropout (float): The dropout value for Transformer encoder/decoder.
         activation (Callable[[Any], Any]): the activation function of encoder/decoder intermediate layer
-        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder. 
-        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder. 
+        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder.
+        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder.
         norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before other attention and feedforward operations, otherwise after. Default is``False``.
""" + def __init__( self, dims: int = 512, @@ -281,21 +300,33 @@ class Transformer(Module): activation: Callable[[Any], Any] = relu, custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None, - norm_first: bool = False + norm_first: bool = False, ): super().__init__() if custom_encoder is not None: self.encoder = custom_encoder else: self.encoder = TransformerEncoder( - num_encoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first + num_encoder_layers, + dims, + num_heads, + mlp_dims, + dropout, + activation, + norm_first, ) if custom_decoder is not None: self.decoder = custom_decoder else: self.decoder = TransformerDecoder( - num_decoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first + num_decoder_layers, + dims, + num_heads, + mlp_dims, + dropout, + activation, + norm_first, ) def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):