From f4f6e17d453635d43ed742cab13009b424c42094 Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Mon, 18 Dec 2023 15:27:27 -0500
Subject: [PATCH] Fix cross-attention (#210)

* Fix cross-attention

With the current code, ln2 is a no-op. Its output should be passed to
the cross-attention layer

* Add name to contributors
---
 ACKNOWLEDGMENTS.md                  | 1 +
 python/mlx/nn/layers/transformer.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index f5b31c2ff..35132d514 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -7,6 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
+- Juarez Bochi: Fixed bug in cross attention.
 
 # Third-Party Software
 
diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index e24957066..8d9efe171 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -157,7 +157,7 @@ class TransformerDecoderLayer(Module):
         x = x + y
 
         y = self.ln2(x)
-        y = self.cross_attention(x, memory, memory, memory_mask)
+        y = self.cross_attention(y, memory, memory, memory_mask)
         x = x + y
 
         y = self.ln3(x)
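
For context (not part of the patch): the hunk changes only the query argument passed to `cross_attention`. The sketch below is a rough reconstruction of the surrounding decoder-layer flow, assuming a standard pre-norm layout; the `ln1`, `self_attention`, and `ffn` steps are inferred from convention, not taken from the diff, and this is not MLX's actual implementation. It shows why passing `x` instead of `y` makes `ln2` a no-op: its normalized output is computed but never consumed.

```python
# Minimal sketch of a pre-norm decoder-layer forward pass.
# Names (ln2, cross_attention, memory, memory_mask) mirror the diff;
# the layers are passed in as plain callables so the data flow is explicit.

def decoder_layer_forward(x, memory, ln1, ln2, ln3,
                          self_attention, cross_attention, ffn,
                          x_mask=None, memory_mask=None):
    # Self-attention block: normalize, attend, add residual.
    y = ln1(x)
    y = self_attention(y, y, y, x_mask)
    x = x + y

    # Cross-attention block: the query must be the normalized activations `y`.
    # Using `x` here (the pre-patch code) means ln2's output is discarded,
    # so the normalization silently has no effect.
    y = ln2(x)
    y = cross_attention(y, memory, memory, memory_mask)
    x = x + y

    # Feed-forward block.
    y = ln3(x)
    y = ffn(y)
    return x + y
```

With identity functions substituted for every layer, the buggy and fixed versions produce identical outputs; the difference only appears once `ln2` actually changes its input, which is why the bug is easy to miss in smoke tests.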