mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 20:07:59 +08:00
fix transformer decoder post norm LN (#1637)
This commit is contained in:
parent
974bb54ab2
commit
aa86876813
@@ -238,7 +238,7 @@ class TransformerDecoderLayer(Module):
 
 y = self.cross_attention(y, memory, memory, memory_mask)
 y = self.dropout2(y)
-x = self.ln1(x + y)
+x = self.ln2(x + y)
 
 y = self.linear1(x)
 y = self.activation(y)
Loading…
Reference in New Issue
Block a user