diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 15e310b35..971a2ad56 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -5,51 +5,9 @@ from typing import Any, Optional
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm
-from mlx.nn.layers.dropout import Dropout
-from mlx.nn.layers.positional_encoding import SinusoidalPositionalEncoding
-
-class MyPosEncoding(SinusoidalPositionalEncoding):
-    def __init__(
-        self,
-        dims: int,
-        min_freq: float = 0.0001,
-        max_freq: float = 1,
-        scale: Optional[float] = None,
-        cos_first: bool = False,
-        full_turns: bool = False,
-    ):
-        super().__init__(
-            dims,
-            min_freq=min_freq,
-            max_freq=max_freq,
-            scale=scale,
-            cos_first=cos_first,
-            full_turns=full_turns
-        )
-        self.dims = dims
-
-    def __call__(self, x):
-        seq_length = x.shape[1]  # Assuming x.shape [batch_size, sequence_length, embedding_dim]
-        position = mx.arange(seq_length)[..., None] * self._sigmas
-
-        # Generate positional encodings
-        div_term = mx.exp(mx.arange(0, self.dims, 2) * -(math.log(10000.0) / self.dims))
-        sinusoid_inp = position * div_term
-
-        y = mx.zeros((seq_length, self.dims))
-        if self.cos_first:
-            y[:, 0::2] = mx.cos(sinusoid_inp)
-            y[:, 1::2] = mx.sin(sinusoid_inp)
-        else:
-            y[:, 0::2] = mx.sin(sinusoid_inp)
-            y[:, 1::2] = mx.cos(sinusoid_inp)
-
-        if self.scale != 1:
-            y = y * self.scale
-
-        return x + y
 
 
 class MultiHeadAttention(Module):
     """Implements the scaled dot product attention with multiple heads.
@@ -151,7 +109,7 @@ class TransformerEncoderLayer(Module):
     def __call__(self, x, mask):
         y = self.attention(x, x, x, mask)
         y = self.ln1(x + y)
-
+
         y = self.linear1(y)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
@@ -161,7 +119,13 @@ class TransformerEncoderLayer(Module):
 
 
 class TransformerEncoderLayerWithDropout(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout_rate: float = 0.1):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
         self.attention = MultiHeadAttention(dims, num_heads)
@@ -176,7 +140,7 @@ class TransformerEncoderLayerWithDropout(Module):
         y = self.attention(x, x, x, mask)
         y = self.dropout1(y)
         y = self.ln1(x + y)
-
+
         y = self.linear1(y)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
@@ -224,7 +188,6 @@ class TransformerEncoderWithDropout(Module):
 
         return x
 
-
 
 class TransformerDecoderLayer(Module):
     def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
         super().__init__()
@@ -240,7 +203,7 @@ class TransformerDecoderLayer(Module):
     def __call__(self, x, memory, x_mask, memory_mask):
         y = self.self_attention(x, x, x, x_mask)
         x = self.ln1(x + y)
-
+
         y = self.cross_attention(y, memory, memory, memory_mask)
         x = self.ln1(x + y)
 
@@ -253,7 +216,13 @@ class TransformerDecoderLayer(Module):
 
 
 class TransformerDecoderLayerWithDropout(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout_rate: float = 0.1):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
         self.self_attention = MultiHeadAttention(dims, num_heads)
@@ -271,7 +240,7 @@ class TransformerDecoderLayerWithDropout(Module):
         y = self.self_attention(x, x, x, x_mask)
         y = self.dropout1(y)
         x = self.ln1(x + y)
-
+
         y = self.cross_attention(y, memory, memory, memory_mask)
         y = self.dropout2(y)
         x = self.ln1(x + y)
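
For reference, a minimal usage sketch of the dropout-enabled encoder layer kept by this diff (not part of the patch itself). It assumes the class is importable from `mlx.nn.layers.transformer`, where this file defines it, and that dropout follows the usual `Module.train()`/`Module.eval()` switch.

```python
# Usage sketch (assumption: TransformerEncoderLayerWithDropout is imported
# directly from mlx.nn.layers.transformer, where this patch defines it).
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.transformer import TransformerEncoderLayerWithDropout

layer = TransformerEncoderLayerWithDropout(dims=64, num_heads=4, dropout_rate=0.1)
layer.train()  # Dropout is only applied while the module is in training mode

x = mx.random.normal((2, 16, 64))  # (batch, sequence, dims)
# Any additive attention mask works here; a causal mask is used for illustration.
mask = nn.MultiHeadAttention.create_additive_causal_mask(16)
y = layer(x, mask)  # output has the same shape as x: (2, 16, 64)
```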