fix transformer (#1327)

This commit is contained in:
Awni Hannun 2024-08-13 16:04:26 -07:00 committed by GitHub
parent eaaea02010
commit 63ae767232
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 2 deletions

View File

@@ -72,6 +72,8 @@ from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
     Transformer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
     TransformerEncoder,
     TransformerEncoderLayer,
 )

View File

@@ -147,9 +147,9 @@ class TransformerEncoderLayer(Module):
         else:
             y = self.attention(x, x, x, mask)
             y = self.dropout1(y)
-            y = self.ln1(x + y)
-            y = self.linear1(y)
+            x = self.ln1(x + y)
+            y = self.linear1(x)
             y = self.activation(y)
             y = self.dropout2(y)
             y = self.linear2(y)