mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix transformer (#1327)
This commit is contained in:
parent
eaaea02010
commit
63ae767232
@@ -72,6 +72,8 @@ from mlx.nn.layers.recurrent import GRU, LSTM, RNN
|
||||
from mlx.nn.layers.transformer import (
|
||||
MultiHeadAttention,
|
||||
Transformer,
|
||||
TransformerDecoder,
|
||||
TransformerDecoderLayer,
|
||||
TransformerEncoder,
|
||||
TransformerEncoderLayer,
|
||||
)
|
||||
|
@@ -147,9 +147,9 @@ class TransformerEncoderLayer(Module):
|
||||
else:
|
||||
y = self.attention(x, x, x, mask)
|
||||
y = self.dropout1(y)
|
||||
- y = self.ln1(x + y)
|
||||
+ x = self.ln1(x + y)
|
||||
|
||||
- y = self.linear1(y)
|
||||
+ y = self.linear1(x)
|
||||
y = self.activation(y)
|
||||
y = self.dropout2(y)
|
||||
y = self.linear2(y)
|
||||
|
Loading…
Reference in New Issue
Block a user