Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 17:31:16 +08:00
fix transformer (#1327)

commit 63ae767232
parent eaaea02010
@@ -72,6 +72,8 @@ from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
     Transformer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
     TransformerEncoder,
     TransformerEncoderLayer,
 )
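This hunk adds the two decoder classes to the package re-exports; the context lines suggest the file is mlx/nn/layers/__init__.py, where previously only the encoder-side transformer names were listed. A usage sketch (assuming an MLX build that includes this fix):

# Package-level imports of the decoder classes (a sketch; before this
# commit these presumably failed with ImportError, since the classes were
# defined in mlx/nn/layers/transformer.py but not re-exported here).
from mlx.nn.layers import TransformerDecoder, TransformerDecoderLayer

print(TransformerDecoder, TransformerDecoderLayer)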
@@ -147,9 +147,9 @@ class TransformerEncoderLayer(Module):
         else:
             y = self.attention(x, x, x, mask)
             y = self.dropout1(y)
-            y = self.ln1(x + y)
+            x = self.ln1(x + y)

-            y = self.linear1(y)
+            y = self.linear1(x)
             y = self.activation(y)
             y = self.dropout2(y)
             y = self.linear2(y)
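The bug this hunk fixes: in the post-norm branch, the first add-and-norm result was stored in y while x kept the layer's raw input, so the second residual connection (just past the end of this hunk, presumably y = self.ln2(x + y)) added the wrong tensor. Reassigning the normalized result to x makes both the MLP input and the following residual read the output of the attention sublayer, which is the standard post-norm transformer layout. A self-contained sketch of the corrected forward pass (a sketch only, not MLX's actual class; ln2 and the constructor parameters are assumptions based on that standard layout):

import mlx.core as mx
import mlx.nn as nn

class PostNormEncoderLayer(nn.Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: int, dropout: float = 0.0):
        super().__init__()
        self.attention = nn.MultiHeadAttention(dims, num_heads)
        self.ln1 = nn.LayerNorm(dims)
        self.ln2 = nn.LayerNorm(dims)
        self.linear1 = nn.Linear(dims, mlp_dims)
        self.linear2 = nn.Linear(mlp_dims, dims)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.relu

    def __call__(self, x, mask=None):
        y = self.attention(x, x, x, mask)
        y = self.dropout1(y)
        x = self.ln1(x + y)      # the fix: keep the normalized residual as x

        y = self.linear1(x)      # the MLP now reads the post-ln1 stream
        y = self.activation(y)
        y = self.dropout2(y)
        y = self.linear2(y)
        return self.ln2(x + y)   # second residual adds the post-ln1 x, not the raw input

x = mx.random.normal((1, 8, 64))
print(PostNormEncoderLayer(64, 4, 256)(x).shape)  # (1, 8, 64)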