From aa86876813efd8666a6c74f1fd98735836f97974 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 2 Dec 2024 07:02:17 -0800 Subject: [PATCH] fix transformer decoder post norm LN (#1637) --- python/mlx/nn/layers/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py index 4b170e88f..79e87dc1d 100644 --- a/python/mlx/nn/layers/transformer.py +++ b/python/mlx/nn/layers/transformer.py @@ -238,7 +238,7 @@ class TransformerDecoderLayer(Module): y = self.cross_attention(y, memory, memory, memory_mask) y = self.dropout2(y) - x = self.ln1(x + y) + x = self.ln2(x + y) y = self.linear1(x) y = self.activation(y)