mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 10:51:21 +08:00
fix transformer decoder post norm LN (#1637)
This commit is contained in:
parent
974bb54ab2
commit
aa86876813
@@ -238,7 +238,7 @@ class TransformerDecoderLayer(Module):
|
||||
|
||||
y = self.cross_attention(y, memory, memory, memory_mask)
|
||||
y = self.dropout2(y)
|
||||
-x = self.ln1(x + y)
|
||||
+x = self.ln2(x + y)
|
||||
|
||||
y = self.linear1(x)
|
||||
y = self.activation(y)
|
||||
|
Loading…
Reference in New Issue
Block a user