Pass ln2 to cross attention

This commit is contained in:
Juarez Bochi 2023-12-18 15:05:05 -05:00
parent e899271275
commit 64e53e8415
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@ -210,7 +210,7 @@ class TransformerDecoderLayer(nn.Module):
x = x + y
y = self.ln2(x)
y, _ = self.cross_attention(x, memory, memory, memory_mask)
y, _ = self.cross_attention(y, memory, memory, memory_mask)
x = x + y
y = self.ln3(x)