fix transformer (#1327)

Awni Hannun, 2024-08-13 16:04:26 -07:00 (committed by GitHub)
parent eaaea02010
commit 63ae767232
2 changed files with 4 additions and 2 deletions


@@ -72,6 +72,8 @@ from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
     Transformer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
     TransformerEncoder,
     TransformerEncoderLayer,
 )
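
With these two additions, the decoder classes are re-exported alongside the encoder classes that were already listed. A minimal usage sketch; the constructor arguments (dims, num_heads) and the call signature are assumptions mirroring the encoder API, not part of this diff:

import mlx.core as mx
import mlx.nn as nn

# Assumed API: TransformerDecoderLayer mirrors TransformerEncoderLayer
# but also attends over an encoder memory.
layer = nn.TransformerDecoderLayer(dims=64, num_heads=4)
x = mx.random.normal((2, 8, 64))        # (batch, target length, dims)
memory = mx.random.normal((2, 10, 64))  # (batch, source length, dims)
y = layer(x, memory, None, None)        # target and memory masks omitted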


@@ -147,9 +147,9 @@ class TransformerEncoderLayer(Module):
         else:
             y = self.attention(x, x, x, mask)
             y = self.dropout1(y)
-            y = self.ln1(x + y)
+            x = self.ln1(x + y)
 
-            y = self.linear1(y)
+            y = self.linear1(x)
             y = self.activation(y)
             y = self.dropout2(y)
             y = self.linear2(y)
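
This hunk fixes the post-norm branch (norm_first is False) of TransformerEncoderLayer: the output of ln1 was stored only in y and never written back to x, so the residual connection around the MLP block still used the original layer input rather than the attention sub-block's output. A minimal sketch of the corrected branch, using the names from the diff; the trailing ln2 line is assumed from the surrounding method and is not shown in the hunk:

# Post-norm encoder-layer branch after the fix (sketch).
y = self.attention(x, x, x, mask)
y = self.dropout1(y)
x = self.ln1(x + y)   # fix: keep the normed attention residual in x

y = self.linear1(x)   # fix: feed the MLP from x, not the stale y
y = self.activation(y)
y = self.dropout2(y)
y = self.linear2(y)
y = self.ln2(x + y)   # this residual now correctly includes attention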