fix transformer decoder post norm LN (#1637)

This commit is contained in:
Awni Hannun 2024-12-02 07:02:17 -08:00 committed by GitHub
parent 974bb54ab2
commit aa86876813
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -238,7 +238,7 @@ class TransformerDecoderLayer(Module):
y = self.cross_attention(y, memory, memory, memory_mask)
y = self.dropout2(y)
x = self.ln1(x + y)
x = self.ln2(x + y)
y = self.linear1(x)
y = self.activation(y)