diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index f5b31c2ff..35132d514 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -7,6 +7,7 @@ with a short description of your contribution(s) below. For example: MLX was developed with contributions from the following individuals: +- Juarez Bochi: Fixed bug in cross attention. # Third-Party Software diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py index e24957066..8d9efe171 100644 --- a/python/mlx/nn/layers/transformer.py +++ b/python/mlx/nn/layers/transformer.py @@ -157,7 +157,7 @@ class TransformerDecoderLayer(Module): x = x + y y = self.ln2(x) - y = self.cross_attention(x, memory, memory, memory_mask) + y = self.cross_attention(y, memory, memory, memory_mask) x = x + y y = self.ln3(x)