mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-30 23:38:09 +08:00
fix transformer decoder post norm LN (#1637)
This commit is contained in:
@@ -238,7 +238,7 @@ class TransformerDecoderLayer(Module):
|
||||
|
||||
y = self.cross_attention(y, memory, memory, memory_mask)
|
||||
y = self.dropout2(y)
|
||||
x = self.ln1(x + y)
|
||||
x = self.ln2(x + y)
|
||||
|
||||
y = self.linear1(x)
|
||||
y = self.activation(y)
|
||||
|
||||
Reference in New Issue
Block a user