diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index d2c224cbf..2c586cd3e 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Optional
+from typing import Optional, Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -136,3 +136,85 @@ class TransformerEncoder(Module):
         x = self.ln(x)
 
         return x
+
+
+class TransformerDecoderLayer(Module):
+    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.self_attention = MultiHeadAttention(dims, num_heads)
+        self.cross_attention = MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.ln3 = LayerNorm(dims)
+        self.linear1 = Linear(dims, mlp_dims)
+        self.linear2 = Linear(mlp_dims, dims)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        y = self.ln1(x)
+        y = self.self_attention(y, y, y, x_mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.cross_attention(y, memory, memory, memory_mask)
+        x = x + y
+
+        y = self.ln3(x)
+        y = self.linear1(y)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        x = x + y
+
+        return x
+
+
+class TransformerDecoder(Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerDecoderLayer(dims, num_heads, mlp_dims)
+            for i in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        for l in self.layers:
+            x = l(x, memory, x_mask, memory_mask)
+        x = self.ln(x)
+
+        return x
+
+
+class Transformer(Module):
+    def __init__(
+        self,
+        dims: int = 512,
+        num_heads: int = 8,
+        num_encoder_layers: int = 6,
+        num_decoder_layers: int = 6,
+        mlp_dims: Optional[int] = None,
+        custom_encoder: Optional[Any] = None,
+        custom_decoder: Optional[Any] = None,
+    ):
+        super().__init__()
+        if custom_encoder is not None:
+            self.encoder = custom_encoder
+        else:
+            self.encoder = TransformerEncoder(
+                num_encoder_layers, dims, num_heads, mlp_dims
+            )
+
+        if custom_decoder is not None:
+            self.decoder = custom_decoder
+        else:
+            self.decoder = TransformerDecoder(
+                num_decoder_layers, dims, num_heads, mlp_dims
+            )
+
+    def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
+        memory = self.encoder(src, src_mask)
+        output = self.decoder(tgt, memory, tgt_mask, memory_mask)
+
+        return output
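
For reference, a minimal usage sketch of the new Transformer (not part of the patch). It assumes the new classes get exported from mlx.nn alongside the existing encoder classes, reuses the existing MultiHeadAttention.create_additive_causal_mask helper for the target mask, and uses arbitrary toy shapes:

import mlx.core as mx
import mlx.nn as nn

# Toy batch: 2 source sequences of length 10 and 2 target sequences of length 7,
# already embedded into the model dimension (dims defaults to 512).
src = mx.random.normal((2, 10, 512))
tgt = mx.random.normal((2, 7, 512))

model = nn.Transformer(dims=512, num_heads=8)

# Causal mask so each target position attends only to itself and earlier positions;
# the source and encoder-decoder attentions are left unmasked (None) in this sketch.
tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(tgt.shape[1])

out = model(src, tgt, src_mask=None, tgt_mask=tgt_mask, memory_mask=None)
print(out.shape)  # (2, 7, 512)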