diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index 5ef45d60d..aa59e0af2 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -9,7 +9,7 @@ Layers :toctree: _autosummary :template: nn-module-template.rst - Embedding + Sequential ReLU PReLU GELU @@ -17,17 +17,19 @@ Layers Step SELU Mish + Embedding Linear + QuantizedLinear Conv1d Conv2d BatchNorm LayerNorm RMSNorm GroupNorm - RoPE - MultiHeadAttention - Sequential - QuantizedLinear Dropout Dropout2d - + Transformer + MultiHeadAttention + ALiBi + RoPE + SinusoidalPositionalEncoding diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 5ac82356a..4dbe96eb6 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -14,6 +14,7 @@ from mlx.nn.layers.activations import ( SiLU, Softplus, Step, + Tanh, celu, elu, gelu, @@ -29,6 +30,7 @@ from mlx.nn.layers.activations import ( silu, softplus, step, + tanh, ) from mlx.nn.layers.base import Module from mlx.nn.layers.containers import Sequential @@ -41,6 +43,7 @@ from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalE from mlx.nn.layers.quantized import QuantizedLinear from mlx.nn.layers.transformer import ( MultiHeadAttention, + Transformer, TransformerEncoder, TransformerEncoderLayer, ) diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index 25d1c5268..20294380c 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -179,12 +179,12 @@ def selu(x): def prelu(x: mx.array, alpha: mx.array) -> mx.array: - r"""Applies the element-wise function: + r"""Applies the element-wise parametric ReLU. .. math:: \text{PReLU}(x) = \max(0,x) + a * \min(0,x) - Here :math:`a` is an array. + where :math:`a` is an array. """ return mx.maximum(0, x) + alpha * mx.minimum(0, x) diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py index db436f407..6c363e368 100644 --- a/python/mlx/nn/layers/positional_encoding.py +++ b/python/mlx/nn/layers/positional_encoding.py @@ -8,21 +8,22 @@ from mlx.nn.layers.base import Module class RoPE(Module): - """Implements the rotary positional encoding [1]. + """Implements the rotary positional encoding. The traditional implementation rotates consecutive pairs of elements in the feature dimension while the default implementation rotates pairs with stride half the feature dimensions for efficiency. - [1]: https://arxiv.org/abs/2104.09864 + For more details see `RoFormer: Enhanced Transformer with Rotary Position + Embedding `_. Args: dims (int): The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged. traditional (bool, optional): If set to True choose the traditional - implementation which is slightly less efficient. Default: ``False`` + implementation which is slightly less efficient. Default: ``False``. base (float, optional): The base used to compute angular frequency for - each dimension in the positional encodings. Default: ``10000`` + each dimension in the positional encodings. Default: ``10000``. """ def __init__(self, dims: int, traditional: bool = False, base: float = 10000): @@ -89,19 +90,23 @@ class RoPE(Module): class SinusoidalPositionalEncoding(Module): - """Implements sinusoidal positional encoding similar to [1]. + r"""Implements sinusoidal positional encoding. - [1]: https://arxiv.org/abs/1706.03762 + For more details see the paper `Attention Is All You Need + `_. Args: dims (int): The dimensionality of the resulting positional embeddings. - min_freq (float): The minimum frequency expected (default: 0.0001) - max_freq (float): The maximum frequency expected (default: 1) - scale (float): Scale the embeddings by that number (default: sqrt(dims//2)) - cos_first (bool): If set to True embed using ``[cos(x); sin(x)]`` - instead of the other way around (default: False) - full_turns (bool): If set to True multiply the frequencies - with ``2 pi`` (default: False) + min_freq (float, optional): The minimum frequency expected. Default: + ``0.0001``. + max_freq (float, optional): The maximum frequency expected. Default: + ``1``. + scale (float, optional): A multiplicative scale for the embeddings. + Default: ``sqrt(dims//2)``. + cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]`` + instead of the reverse. Default: ``False``. + full_turns (bool, optional): If ``True`` multiply the frequencies with + :math:`2\pi`. Default: ``False``. """ def __init__( diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py index 8d9efe171..c61f5405f 100644 --- a/python/mlx/nn/layers/transformer.py +++ b/python/mlx/nn/layers/transformer.py @@ -1,10 +1,12 @@ # Copyright © 2023 Apple Inc. import math -from typing import Any, Optional +from typing import Any, Callable, Optional import mlx.core as mx +from mlx.nn.layers.activations import relu from mlx.nn.layers.base import Module +from mlx.nn.layers.dropout import Dropout from mlx.nn.layers.linear import Linear from mlx.nn.layers.normalization import LayerNorm @@ -12,26 +14,34 @@ from mlx.nn.layers.normalization import LayerNorm class MultiHeadAttention(Module): """Implements the scaled dot product attention with multiple heads. - Given inputs for queries, keys and values the ``MultiHeadAttention`` produces - new values by aggregating information from the input values according to - the similarities of the input queries and keys. + Given inputs for queries, keys and values the ``MultiHeadAttention`` + produces new values by aggregating information from the input values + according to the similarities of the input queries and keys. - All inputs as well as the output are linearly projected without biases. + All inputs as well as the output are linearly projected without biases by + default. - MultiHeadAttention also expects an additive attention mask that should be - broadcastable with (batch, num_heads, # queries, # keys). The mask should - have ``-inf`` or very negative numbers to the positions that should *not* be - attended to. + ``MultiHeadAttention`` also takes an optional additive attention mask that + should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The + mask should have ``-inf`` or very large negative numbers at the positions + that should *not* be attended to. Args: - dims (int): The model dimensions. If no other dims are provided then - dims is used for queries, keys, values and the output. - num_heads (int): How many attention heads to use - query_input_dims (int, optional): The input dimensions of the queries (default: dims). - key_input_dims (int, optional): The input dimensions of the keys (default: dims). - value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims). - value_dims (int, optional): The dimensions of the values after the projection (default: dims). - value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims). + dims (int): The model dimensions. This is also the default + value for the queries, keys, values, and the output. + num_heads (int): The number of attention heads to use. + query_input_dims (int, optional): The input dimensions of the queries. + Default: ``dims``. + key_input_dims (int, optional): The input dimensions of the keys. + Default: ``dims``. + value_input_dims (int, optional): The input dimensions of the values. + Default: ``key_input_dims``. + value_dims (int, optional): The dimensions of the values after the + projection. Default: ``dims``. + value_output_dims (int, optional): The dimensions the new values will + be projected to. Default: ``dims``. + bias (bool, optional): Whether or not to use a bias in the projections. + Default: ``False``. """ def __init__( @@ -49,7 +59,8 @@ class MultiHeadAttention(Module): if (dims % num_heads) != 0: raise ValueError( - f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0" + "The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0" ) query_input_dims = query_input_dims or dims @@ -97,7 +108,15 @@ class MultiHeadAttention(Module): class TransformerEncoderLayer(Module): - def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None): + def __init__( + self, + dims: int, + num_heads: int, + mlp_dims: Optional[int] = None, + dropout: float = 0.0, + activation: Callable[[Any], Any] = relu, + norm_first: bool = False, + ): super().__init__() mlp_dims = mlp_dims or dims * 4 self.attention = MultiHeadAttention(dims, num_heads) @@ -105,28 +124,55 @@ class TransformerEncoderLayer(Module): self.ln2 = LayerNorm(dims) self.linear1 = Linear(dims, mlp_dims) self.linear2 = Linear(mlp_dims, dims) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.activation = activation + self.norm_first = norm_first def __call__(self, x, mask): - y = self.ln1(x) - y = self.attention(y, y, y, mask) - x = x + y + if self.norm_first: + y = self.ln1(x) + y = self.attention(y, y, y, mask) + y = self.dropout1(y) + x = x + y - y = self.ln2(x) - y = self.linear1(y) - y = mx.maximum(y, 0) - y = self.linear2(y) - x = x + y + y = self.ln2(x) + y = self.linear1(y) + y = self.activation(y) + y = self.dropout2(y) + y = self.linear2(y) + y = x + y - return x + else: + y = self.attention(x, x, x, mask) + y = self.dropout1(y) + y = self.ln1(x + y) + + y = self.linear1(y) + y = self.activation(y) + y = self.dropout2(y) + y = self.linear2(y) + y = self.ln2(x + y) + + return y class TransformerEncoder(Module): def __init__( - self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None + self, + num_layers: int, + dims: int, + num_heads: int, + mlp_dims: Optional[int] = None, + dropout: float = 0.0, + activation=relu, + norm_first: bool = False, ): super().__init__() self.layers = [ - TransformerEncoderLayer(dims, num_heads, mlp_dims) + TransformerEncoderLayer( + dims, num_heads, mlp_dims, dropout, activation, norm_first + ) for i in range(num_layers) ] self.ln = LayerNorm(dims) @@ -134,13 +180,19 @@ class TransformerEncoder(Module): def __call__(self, x, mask): for l in self.layers: x = l(x, mask) - x = self.ln(x) - - return x + return self.ln(x) class TransformerDecoderLayer(Module): - def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None): + def __init__( + self, + dims: int, + num_heads: int, + mlp_dims: Optional[int] = None, + dropout: float = 0.0, + activation: Callable[[Any], Any] = relu, + norm_first: bool = False, + ): super().__init__() mlp_dims = mlp_dims or dims * 4 self.self_attention = MultiHeadAttention(dims, num_heads) @@ -150,32 +202,65 @@ class TransformerDecoderLayer(Module): self.ln3 = LayerNorm(dims) self.linear1 = Linear(dims, mlp_dims) self.linear2 = Linear(mlp_dims, dims) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + self.activation = activation + self.norm_first = norm_first def __call__(self, x, memory, x_mask, memory_mask): - y = self.ln1(x) - y = self.self_attention(y, y, y, x_mask) - x = x + y + if self.norm_first: + y = self.ln1(x) + y = self.self_attention(y, y, y, x_mask) + y = self.dropout1(y) + x = x + y - y = self.ln2(x) - y = self.cross_attention(y, memory, memory, memory_mask) - x = x + y + y = self.ln2(x) + y = self.cross_attention(y, memory, memory, memory_mask) + y = self.dropout2(y) + x = x + y - y = self.ln3(x) - y = self.linear1(y) - y = mx.maximum(y, 0) - y = self.linear2(y) - x = x + y + y = self.ln3(x) + y = self.linear1(y) + y = self.activation(y) + y = self.dropout3(y) + y = self.linear2(y) + y = x + y - return x + else: + y = self.self_attention(x, x, x, x_mask) + y = self.dropout1(y) + x = self.ln1(x + y) + + y = self.cross_attention(y, memory, memory, memory_mask) + y = self.dropout2(y) + x = self.ln1(x + y) + + y = self.linear1(x) + y = self.activation(y) + y = self.dropout3(y) + y = self.linear2(y) + y = self.ln3(x + y) + + return y class TransformerDecoder(Module): def __init__( - self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None + self, + num_layers: int, + dims: int, + num_heads: int, + mlp_dims: Optional[int] = None, + dropout: float = 0.0, + activation=relu, + norm_first: bool = False, ): super().__init__() self.layers = [ - TransformerDecoderLayer(dims, num_heads, mlp_dims) + TransformerDecoderLayer( + dims, num_heads, mlp_dims, dropout, activation, norm_first + ) for i in range(num_layers) ] self.ln = LayerNorm(dims) @@ -183,12 +268,47 @@ class TransformerDecoder(Module): def __call__(self, x, memory, x_mask, memory_mask): for l in self.layers: x = l(x, memory, x_mask, memory_mask) - x = self.ln(x) - - return x + return self.ln(x) class Transformer(Module): + """ + Implements a standard Transformer model. + + The implementation is based on `Attention Is All You Need + `_. + + The Transformer model contains an encoder and a decoder. The encoder + processes the input sequence and the decoder generates the output sequence. + The interaction between encoder and decoder happens through the attention + mechanism. + + Args: + dims (int, optional): The number of expected features in the + encoder/decoder inputs. Default: ``512``. + num_heads (int, optional): The number of attention heads. Default: + ``8``. + num_encoder_layers (int, optional): The number of encoder layers in the + Transformer encoder. Default: ``6``. + num_decoder_layers (int, optional): The number of decoder layers in the + Transformer decoder. Default: ``6``. + mlp_dims (int, optional): The hidden dimension of the MLP block in each + Transformer layer. Defaults to ``4*dims`` if not provided. Default: + ``None``. + dropout (float, optional): The dropout value for the Transformer + encoder and decoder. Dropout is used after each attention layer and + the activation in the MLP layer. Default: ``0.0``. + activation (function, optional): the activation function for the MLP + hidden layer. Default: :func:`mlx.nn.relu`. + custom_encoder (nn.Module, optional): A custom encoder to replace the + standard Transformer encoder. Default: ``None``. + custom_decoder (nn.Module, optional): A custom decoder to replace the + standard Transformer decoder. Default: ``None``. + norm_first (bool, optional): if ``True``, encoder and decoder layers + will perform layer normalization before attention and MLP + operations, otherwise after. Default: ``False``. + """ + def __init__( self, dims: int = 512, @@ -196,26 +316,39 @@ class Transformer(Module): num_encoder_layers: int = 6, num_decoder_layers: int = 6, mlp_dims: Optional[int] = None, + dropout: float = 0.0, + activation: Callable[[Any], Any] = relu, custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None, + norm_first: bool = False, ): super().__init__() if custom_encoder is not None: self.encoder = custom_encoder else: self.encoder = TransformerEncoder( - num_encoder_layers, dims, num_heads, mlp_dims + num_encoder_layers, + dims, + num_heads, + mlp_dims, + dropout, + activation, + norm_first, ) if custom_decoder is not None: self.decoder = custom_decoder else: self.decoder = TransformerDecoder( - num_decoder_layers, dims, num_heads, mlp_dims + num_decoder_layers, + dims, + num_heads, + mlp_dims, + dropout, + activation, + norm_first, ) def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask): memory = self.encoder(src, src_mask) - output = self.decoder(tgt, memory, tgt_mask, memory_mask) - - return output + return self.decoder(tgt, memory, tgt_mask, memory_mask)