From 63ae767232fb3f51c4a1e136f99e5e6b03cb496c Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 13 Aug 2024 16:04:26 -0700
Subject: [PATCH] fix transformer (#1327)

---
 python/mlx/nn/layers/__init__.py    | 2 ++
 python/mlx/nn/layers/transformer.py | 4 ++--
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index f528c9908..890c4ee5d 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -72,6 +72,8 @@ from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
     Transformer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
     TransformerEncoder,
     TransformerEncoderLayer,
 )
diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 0c98e4c4e..35bafe380 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -147,9 +147,9 @@ class TransformerEncoderLayer(Module):
         else:
             y = self.attention(x, x, x, mask)
         y = self.dropout1(y)
-        y = self.ln1(x + y)
+        x = self.ln1(x + y)
 
-        y = self.linear1(y)
+        y = self.linear1(x)
         y = self.activation(y)
         y = self.dropout2(y)
         y = self.linear2(y)
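
Note on the transformer.py hunk: in the post-norm branch, the output of ln1 was previously
kept in y, so the residual connection applied after the feed-forward block (just below this
hunk) added the layer's original input rather than the normalized attention output. Storing
the ln1 result back in x makes that second residual use the correct tensor. The sketch below
shows how the post-norm branch of TransformerEncoderLayer.__call__ reads after the patch;
only the lines inside the hunk are verbatim from this diff, while the method signature,
the pre-norm branch, and the trailing ln2 residual are assumed from surrounding context
for illustration.

    # Post-norm branch of TransformerEncoderLayer.__call__ after this patch (sketch).
    # Lines outside the hunk (signature, pre-norm branch, ln2 residual, return) are
    # assumptions, not taken from the diff above.
    def __call__(self, x, mask):
        if self.norm_first:
            ...  # pre-norm branch, unchanged by this patch
        else:
            y = self.attention(x, x, x, mask)
            y = self.dropout1(y)
            # Keep the normalized residual in x so the second residual connection
            # below adds the ln1 output, not the layer's original input.
            x = self.ln1(x + y)

            y = self.linear1(x)
            y = self.activation(y)
            y = self.dropout2(y)
            y = self.linear2(y)
            # Assumed context: the second residual + norm that consumes x.
            y = self.ln2(x + y)

        return y

The __init__.py hunk simply re-exports TransformerDecoder and TransformerDecoderLayer from
mlx.nn.layers so they can be imported alongside the encoder classes.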