mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
fix transformer decoder post norm LN (#1637)
This commit is contained in:
@@ -238,7 +238,7 @@ class TransformerDecoderLayer(Module):
|
||||
|
||||
y = self.cross_attention(y, memory, memory, memory_mask)
|
||||
y = self.dropout2(y)
|
||||
x = self.ln1(x + y)
|
||||
x = self.ln2(x + y)
|
||||
|
||||
y = self.linear1(x)
|
||||
y = self.activation(y)
|
||||
|
Reference in New Issue
Block a user