From 4bd3c02c0039b13817f2b5929879fd2f52f0fcde Mon Sep 17 00:00:00 2001
From: Jyun1998
Date: Fri, 15 Dec 2023 03:26:17 +0800
Subject: [PATCH] add transformer with dropout, fix transformer ffn, layernorm order

---
 python/mlx/nn/layers/transformer.py | 171 +++++++++++++++++++++++++---
 1 file changed, 153 insertions(+), 18 deletions(-)

diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index e24957066..15e310b35 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -7,7 +7,49 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm
+from mlx.nn.layers.dropout import Dropout
+from mlx.nn.layers.positional_encoding import SinusoidalPositionalEncoding
+
+
+class MyPosEncoding(SinusoidalPositionalEncoding):
+    def __init__(
+        self,
+        dims: int,
+        min_freq: float = 0.0001,
+        max_freq: float = 1,
+        scale: Optional[float] = None,
+        cos_first: bool = False,
+        full_turns: bool = False,
+    ):
+        super().__init__(
+            dims,
+            min_freq=min_freq,
+            max_freq=max_freq,
+            scale=scale,
+            cos_first=cos_first,
+            full_turns=full_turns,
+        )
+        self.dims = dims
+
+    def __call__(self, x):
+        # x is expected to have shape [batch_size, sequence_length, embedding_dim]
+        seq_length = x.shape[1]
+        position = mx.arange(seq_length)[..., None]
+
+        # Classic fixed sinusoidal encodings with a 10000-based frequency scale
+        div_term = mx.exp(mx.arange(0, self.dims, 2) * -(math.log(10000.0) / self.dims))
+        sinusoid_inp = position * div_term
+
+        y = mx.zeros((seq_length, self.dims))
+        if self.cos_first:
+            y[:, 0::2] = mx.cos(sinusoid_inp)
+            y[:, 1::2] = mx.sin(sinusoid_inp)
+        else:
+            y[:, 0::2] = mx.sin(sinusoid_inp)
+            y[:, 1::2] = mx.cos(sinusoid_inp)
+
+        if self.scale != 1:
+            y = y * self.scale
+
+        return x + y
 
 
 class MultiHeadAttention(Module):
     """Implements the scaled dot product attention with multiple heads.
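A minimal sketch of how the new positional-encoding module might be exercised once the patch is applied; the shapes and values are illustrative, not part of the patch:

    import mlx.core as mx
    from mlx.nn.layers.transformer import MyPosEncoding

    x = mx.zeros((2, 16, 32))  # [batch, sequence, embedding]
    pe = MyPosEncoding(dims=32)
    y = pe(x)                  # adds a fixed sinusoidal encoding per position
    print(y.shape)             # (2, 16, 32)

Unlike the parent SinusoidalPositionalEncoding, which draws its frequencies from min_freq and max_freq, this override recomputes the classic 10000-based frequencies directly, so those two arguments have no effect on the output here.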
@@ -107,17 +149,41 @@ class TransformerEncoderLayer(Module):
         self.linear2 = Linear(mlp_dims, dims)
 
     def __call__(self, x, mask):
-        y = self.ln1(x)
-        y = self.attention(y, y, y, mask)
-        x = x + y
-
-        y = self.ln2(x)
-        y = self.linear1(y)
+        y = self.attention(x, x, x, mask)
+        x = self.ln1(x + y)
+
+        y = self.linear1(x)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
-        x = x + y
-
-        return x
+        y = self.ln2(x + y)
+        return y
+
+
+class TransformerEncoderLayerWithDropout(Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.attention = MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.linear1 = Linear(dims, mlp_dims)
+        self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout_rate)
+        self.dropout2 = Dropout(dropout_rate)
+
+    def __call__(self, x, mask):
+        # Post-norm attention block with dropout before the residual add
+        y = self.attention(x, x, x, mask)
+        y = self.dropout1(y)
+        x = self.ln1(x + y)
+
+        # Feed-forward block with dropout before the residual add
+        y = self.linear1(x)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        y = self.dropout2(y)
+        y = self.ln2(x + y)
+        return y
 
 
 class TransformerEncoder(Module):
@@ -139,6 +205,26 @@ class TransformerEncoder(Module):
         return x
 
 
+class TransformerEncoderWithDropout(Module):
+    def __init__(
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayerWithDropout(dims, num_heads, mlp_dims, dropout_rate)
+            for i in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, mask):
+        for l in self.layers:
+            x = l(x, mask)
+        x = self.ln(x)
+
+        return x
+
+
 class TransformerDecoderLayer(Module):
     def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
         super().__init__()
@@ -152,21 +238,51 @@ class TransformerDecoderLayer(Module):
         self.linear2 = Linear(mlp_dims, dims)
 
     def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.ln1(x)
-        y = self.self_attention(y, y, y, x_mask)
-        x = x + y
+        y = self.self_attention(x, x, x, x_mask)
+        x = self.ln1(x + y)
 
-        y = self.ln2(x)
-        y = self.cross_attention(x, memory, memory, memory_mask)
-        x = x + y
-
-        y = self.ln3(x)
-        y = self.linear1(y)
+        y = self.cross_attention(x, memory, memory, memory_mask)
+        x = self.ln2(x + y)
+
+        y = self.linear1(x)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
-        x = x + y
+        y = self.ln3(x + y)
 
-        return x
+        return y
+
+
+class TransformerDecoderLayerWithDropout(Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.self_attention = MultiHeadAttention(dims, num_heads)
+        self.cross_attention = MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.ln3 = LayerNorm(dims)
+        self.linear1 = Linear(dims, mlp_dims)
+        self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout_rate)
+        self.dropout2 = Dropout(dropout_rate)
+        self.dropout3 = Dropout(dropout_rate)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        # Self-attention, cross-attention, and feed-forward blocks,
+        # each with dropout before its post-norm residual add
+        y = self.self_attention(x, x, x, x_mask)
+        y = self.dropout1(y)
+        x = self.ln1(x + y)
+
+        y = self.cross_attention(x, memory, memory, memory_mask)
+        y = self.dropout2(y)
+        x = self.ln2(x + y)
+
+        y = self.linear1(x)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        y = self.dropout3(y)
+        y = self.ln3(x + y)
+
+        return y
 
 
 class TransformerDecoder(Module):
@@ -188,6 +304,25 @@ class TransformerDecoder(Module):
         return x
 
 
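The reordering above moves the encoder and decoder layers from pre-norm (normalize, attend, then add the residual) to the post-norm layout of the original "Attention Is All You Need" paper (attend, add the residual, then normalize), and the *WithDropout variants apply dropout to each attention and feed-forward output before its residual add. A hedged sketch of driving the dropout encoder, assuming the patch is applied; shapes and hyperparameters are illustrative:

    import mlx.core as mx
    from mlx.nn.layers.transformer import (
        MultiHeadAttention,
        TransformerEncoderWithDropout,
    )

    enc = TransformerEncoderWithDropout(num_layers=2, dims=32, num_heads=4)
    enc.train()  # dropout is active only in training mode

    x = mx.random.normal((2, 16, 32))
    mask = MultiHeadAttention.create_additive_causal_mask(16)
    out = enc(x, mask)  # (2, 16, 32)

    enc.eval()  # disables dropout for deterministic inference

Since mlx.nn's Dropout is an identity in eval mode, the WithDropout classes should behave like their plain counterparts at inference time.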
+class TransformerDecoderWithDropout(Module):
+    def __init__(
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerDecoderLayerWithDropout(dims, num_heads, mlp_dims, dropout_rate)
+            for i in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        for l in self.layers:
+            x = l(x, memory, x_mask, memory_mask)
+        x = self.ln(x)
+
+        return x
+
+
 class Transformer(Module):
     def __init__(
         self,
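A similar sketch for the decoder stack, again assuming the patch is applied and using illustrative shapes:

    import mlx.core as mx
    from mlx.nn.layers.transformer import (
        MultiHeadAttention,
        TransformerDecoderWithDropout,
    )

    dec = TransformerDecoderWithDropout(num_layers=2, dims=32, num_heads=4)
    dec.eval()  # dropout disabled for a deterministic pass

    memory = mx.random.normal((2, 16, 32))  # encoder output
    tgt = mx.random.normal((2, 10, 32))     # target-side embeddings
    tgt_mask = MultiHeadAttention.create_additive_causal_mask(10)
    out = dec(tgt, memory, tgt_mask, None)  # no mask over the memory
    print(out.shape)                        # (2, 10, 32)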