fix transformer decoder post norm LN (#1637)

Author: Awni Hannun
Date: 2024-12-02 07:02:17 -08:00
Committed by: GitHub
Parent: 974bb54ab2
Commit: aa86876813


@@ -238,7 +238,7 @@ class TransformerDecoderLayer(Module):
 
             y = self.cross_attention(y, memory, memory, memory_mask)
             y = self.dropout2(y)
-            x = self.ln1(x + y)
+            x = self.ln2(x + y)
 
             y = self.linear1(x)
             y = self.activation(y)
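
In the post-norm configuration, each of the decoder layer's three residual sub-blocks (self-attention, cross-attention, feed-forward) is normalized by its own LayerNorm. Before this fix, the cross-attention residual sum was passed through ln1 a second time and ln2 was never applied. The following is a minimal sketch of the corrected post-norm path; PostNormDecoderLayer is a hypothetical name used for illustration, dropout is omitted for brevity, and ReLU stands in for the layer's configurable activation.

import mlx.core as mx
import mlx.nn as nn

class PostNormDecoderLayer(nn.Module):
    # Illustrative reimplementation mirroring the fixed post-norm logic.
    def __init__(self, dims: int, num_heads: int, mlp_dims: int):
        super().__init__()
        self.self_attention = nn.MultiHeadAttention(dims, num_heads)
        self.cross_attention = nn.MultiHeadAttention(dims, num_heads)
        self.ln1 = nn.LayerNorm(dims)
        self.ln2 = nn.LayerNorm(dims)
        self.ln3 = nn.LayerNorm(dims)
        self.linear1 = nn.Linear(dims, mlp_dims)
        self.linear2 = nn.Linear(mlp_dims, dims)

    def __call__(self, x, memory, x_mask=None, memory_mask=None):
        # Self-attention sub-block: ln1 normalizes its residual sum.
        y = self.self_attention(x, x, x, x_mask)
        x = self.ln1(x + y)

        # Cross-attention sub-block: ln2 (not ln1) normalizes this sum --
        # reusing ln1 here was the bug this commit fixes.
        y = self.cross_attention(x, memory, memory, memory_mask)
        x = self.ln2(x + y)

        # Feed-forward sub-block: ln3 normalizes the final residual sum.
        y = self.linear2(nn.relu(self.linear1(x)))
        return self.ln3(x + y)

A quick shape check under these assumptions:

x = mx.random.normal((1, 10, 64))       # decoder input
memory = mx.random.normal((1, 12, 64))  # encoder output
layer = PostNormDecoderLayer(64, 8, 256)
print(layer(x, memory).shape)           # (1, 10, 64)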