diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 4b170e88f..79e87dc1d 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -238,7 +238,7 @@ class TransformerDecoderLayer(Module):
 
         y = self.cross_attention(y, memory, memory, memory_mask)
         y = self.dropout2(y)
-        x = self.ln1(x + y)
+        x = self.ln2(x + y)
 
         y = self.linear1(x)
         y = self.activation(y)
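
For context, a minimal sketch (not the repository code) of the post-norm decoder block this hunk touches. The names `ln1`/`ln2`/`ln3`, `dropout1`–`dropout3`, `linear1`/`linear2`, `self_attention`, and `cross_attention` mirror the diff; everything else (constructor arguments, the ReLU activation, and passing `x` as the cross-attention query) is an assumption made for illustration. The point of the fix: in a post-norm layout each residual block is normalized by its own LayerNorm, so the cross-attention residual should use `ln2` instead of reusing `ln1`.

```python
# Illustrative sketch of a post-norm decoder block; names mirror the diff,
# the rest is reconstructed and not taken from the repository.
import mlx.core as mx
import mlx.nn as nn


class PostNormDecoderBlockSketch(nn.Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: int, dropout: float = 0.0):
        super().__init__()
        self.self_attention = nn.MultiHeadAttention(dims, num_heads)
        self.cross_attention = nn.MultiHeadAttention(dims, num_heads)
        self.ln1 = nn.LayerNorm(dims)  # norm for the self-attention residual
        self.ln2 = nn.LayerNorm(dims)  # norm for the cross-attention residual
        self.ln3 = nn.LayerNorm(dims)  # norm for the feed-forward residual
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.linear1 = nn.Linear(dims, mlp_dims)
        self.linear2 = nn.Linear(mlp_dims, dims)

    def __call__(self, x, memory, x_mask=None, memory_mask=None):
        # Self-attention residual, normalized by ln1.
        y = self.self_attention(x, x, x, x_mask)
        y = self.dropout1(y)
        x = self.ln1(x + y)

        # Cross-attention residual, normalized by ln2 (the fix: this block
        # gets its own LayerNorm instead of reusing ln1).
        y = self.cross_attention(x, memory, memory, memory_mask)
        y = self.dropout2(y)
        x = self.ln2(x + y)

        # Feed-forward residual, normalized by ln3.
        y = self.linear1(x)
        y = nn.relu(y)  # activation is an assumption for this sketch
        y = self.dropout3(y)
        y = self.linear2(y)
        return self.ln3(x + y)


if __name__ == "__main__":
    # Quick shape check with random inputs.
    block = PostNormDecoderBlockSketch(dims=64, num_heads=4, mlp_dims=256)
    x = mx.random.normal((2, 10, 64))       # decoder input
    memory = mx.random.normal((2, 12, 64))  # encoder output
    print(block(x, memory).shape)           # (2, 10, 64)
```

Before the fix, the cross-attention residual was normalized by `ln1` a second time, so the same affine parameters served two different residual streams while `ln2`'s parameters went unused.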