diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index f528c9908..890c4ee5d 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -72,6 +72,8 @@ from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
     Transformer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
     TransformerEncoder,
     TransformerEncoderLayer,
 )
diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 0c98e4c4e..35bafe380 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -147,9 +147,9 @@ class TransformerEncoderLayer(Module):
         else:
             y = self.attention(x, x, x, mask)
             y = self.dropout1(y)
-            y = self.ln1(x + y)
+            x = self.ln1(x + y)
 
-            y = self.linear1(y)
+            y = self.linear1(x)
             y = self.activation(y)
             y = self.dropout2(y)
             y = self.linear2(y)
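
For context, a minimal sketch of the post-norm branch of `TransformerEncoderLayer.__call__` after this change. Names follow the hunk; the final residual/`ln2` line lies outside the hunk and is an assumption based on the surrounding file, not part of this diff.

```python
# Sketch of the corrected post-norm path (norm_first == False).
def _post_norm_forward(self, x, mask):
    y = self.attention(x, x, x, mask)
    y = self.dropout1(y)
    # Reassigning x here is the fix: the feed-forward residual below now
    # uses the normalized attention output rather than the original input.
    x = self.ln1(x + y)

    y = self.linear1(x)
    y = self.activation(y)
    y = self.dropout2(y)
    y = self.linear2(y)
    # Assumed continuation outside the hunk: second residual + layer norm.
    return self.ln2(x + y)
```

With the old code, `x` stayed bound to the layer's original input, so the second residual connection effectively skipped the attention sublayer; the reassignment threads the post-attention state through the rest of the block.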