diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index ee27042eb..431a01096 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -147,6 +147,7 @@ class TransformerEncoderLayer(Module):
 
         return y
 
+
 class TransformerEncoder(Module):
     def __init__(
         self,
@@ -279,13 +280,13 @@ class Transformer(Module):
         num_encoder_layers (int): The number of sub-encoder-layers in the Transformer encoder (default: 6)
         num_decoder_layers (int): The number of sub-decoder-layers in the Transformer decoder (default: 6)
         mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each Transformer layer,
-            Defaults to 4*dims if not provided (default: None).
+            Defaults to 4*dims if not provided (default: None)
         dropout (float): The dropout value for Transformer encoder/decoder (default: 0.0)
         activation (Callable[[Any], Any]): the activation function of encoder/decoder intermediate layer (default: relu)
         custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder (default: None)
         custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder (default: None)
-        norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
-            other attention and feedforward operations, otherwise after. (default: False)
+        norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
+            other attention and feedforward operations, otherwise after (default: False)
     """
 
     def __init__(
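
For context, a minimal sketch of how the `Transformer` documented above might be constructed and called. The constructor arguments mirror the docstring touched by this diff; the batch/sequence shapes, the `create_additive_causal_mask` helper, and the five-argument call signature `(src, tgt, src_mask, tgt_mask, memory_mask)` are assumptions about this revision, not something the diff itself shows.

```python
import mlx.core as mx
import mlx.nn as nn

dims = 512  # model dimensionality; matches the documented default

# Constructor arguments taken from the docstring in this diff.
model = nn.Transformer(
    dims=dims,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    mlp_dims=None,     # falls back to 4 * dims, per the docstring
    dropout=0.0,
    norm_first=False,  # post-norm, the documented default here
)

# Assumed input shapes: (batch, sequence length, dims).
src = mx.random.normal((2, 10, dims))
tgt = mx.random.normal((2, 12, dims))

# Additive causal mask for the decoder's self-attention (helper assumed
# to be available on nn.MultiHeadAttention at this revision).
tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(tgt.shape[1])

# Assumed call signature: (src, tgt, src_mask, tgt_mask, memory_mask);
# None disables masking for the encoder and the cross-attention.
out = model(src, tgt, None, tgt_mask, None)
print(out.shape)  # (2, 12, 512)
```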