diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 51a7f55f4..ee27042eb 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -147,7 +147,6 @@ class TransformerEncoderLayer(Module):
         return y
 
 
-
 class TransformerEncoder(Module):
     def __init__(
         self,
@@ -275,18 +274,18 @@ class Transformer(Module):
     between encoder and decoder happens through the attention mechanism.
 
     Args:
-        dims (int): The number of expected features in the encoder/decoder inputs.
-        num_heads (int): The number of heads in the multi-head attention models.
-        num_encoder_layers (int): The number of sub-encoder-layers in the Transformer encoder.
-        num_decoder_layers (int): The number of sub-decoder-layers in the Transformer decoder.
-        mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each Transformer layer.
-            Defaults to 4*dims if not provided.
-        dropout (float): The dropout value for Transformer encoder/decoder.
-        activation (Callable[[Any], Any]): the activation function of encoder/decoder intermediate layer
-        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder.
-        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder.
-        norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
-            other attention and feedforward operations, otherwise after. Default is``False``.
+        dims (int): The number of expected features in the encoder/decoder inputs (default: 512).
+        num_heads (int): The number of heads in the multi-head attention models (default: 8).
+        num_encoder_layers (int): The number of sub-encoder-layers in the Transformer encoder (default: 6).
+        num_decoder_layers (int): The number of sub-decoder-layers in the Transformer decoder (default: 6).
+        mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each Transformer layer.
+            Defaults to 4*dims if not provided (default: None).
+        dropout (float): The dropout value for the Transformer encoder/decoder (default: 0.0).
+        activation (Callable[[Any], Any]): The activation function for the encoder/decoder intermediate layer (default: relu).
+        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder (default: None).
+        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder (default: None).
+        norm_first (bool): If ``True``, encoder and decoder layers will perform layer normalization before
+            attention and feedforward operations, otherwise after (default: False).
     """
 
     def __init__(
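
For reference, a minimal usage sketch of the defaults documented above (not part of this diff; it assumes the constructor accepts these keyword arguments exactly as listed in the docstring):

    import mlx.nn as nn

    # Construct a Transformer with the documented defaults spelled out.
    # mlp_dims=None falls back to 4 * dims (2048 here), per the docstring.
    model = nn.Transformer(
        dims=512,
        num_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        mlp_dims=None,
        dropout=0.0,
        norm_first=False,
    )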