Fix cross-attention (#210)

* Fix cross-attention

With the current code, `ln2` is a no-op: its output is computed but never used. The normalized output should be passed to the cross-attention layer instead of the un-normalized `x` (a minimal sketch of the corrected sub-block follows below).

* Add name to contributors
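
For context, a minimal, self-contained sketch of the pre-norm cross-attention sub-block that this fix restores, built from `mlx.nn` primitives. The class name `DecoderLayerSketch`, the dimensions, and the head count are illustrative assumptions; only the `ln2` / `cross_attention` attribute names mirror the actual `TransformerDecoderLayer`.

```python
import mlx.core as mx
import mlx.nn as nn


class DecoderLayerSketch(nn.Module):
    """Hypothetical reduction of a decoder layer to its cross-attention step."""

    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.ln2 = nn.LayerNorm(dims)
        self.cross_attention = nn.MultiHeadAttention(dims, num_heads)

    def __call__(self, x, memory, memory_mask=None):
        # Pre-norm: normalize the residual stream before attending.
        y = self.ln2(x)
        # The bug passed the un-normalized `x` as the queries, so ln2 had no effect.
        # The fix queries cross-attention with the normalized `y`.
        y = self.cross_attention(y, memory, memory, memory_mask)
        # Residual connection.
        return x + y


# Tiny smoke test with random inputs (shapes are illustrative).
x = mx.random.normal((1, 4, 32))       # (batch, target_len, dims)
memory = mx.random.normal((1, 6, 32))  # (batch, source_len, dims)
layer = DecoderLayerSketch(32, num_heads=4)
print(layer(x, memory).shape)  # (batch, target_len, dims)
```
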
Author: Juarez Bochi, 2023-12-18 15:27:27 -05:00 (committed by GitHub)
parent 4d4af12c6f
commit f4f6e17d45
2 changed files with 2 additions and 1 deletion

@@ -7,6 +7,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
+- Juarez Bochi: Fixed bug in cross attention.
 # Third-Party Software

@@ -157,7 +157,7 @@ class TransformerDecoderLayer(Module):
         x = x + y
 
         y = self.ln2(x)
-        y = self.cross_attention(x, memory, memory, memory_mask)
+        y = self.cross_attention(y, memory, memory, memory_mask)
         x = x + y
 
         y = self.ln3(x)