mirror of https://github.com/ml-explore/mlx.git
	fix transformer (#1327)
@@ -72,6 +72,8 @@ from mlx.nn.layers.recurrent import GRU, LSTM, RNN
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
     Transformer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
     TransformerEncoder,
     TransformerEncoderLayer,
 )
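The first hunk widens the package's exports: TransformerDecoder and TransformerDecoderLayer join the import list (which is why the hunk grows from 6 lines to 8), making the decoder classes reachable from the package namespace alongside the encoder ones. A quick smoke test, assuming the usual mlx.nn re-export chain:

    # Hypothetical smoke test: with the new exports in place, the decoder
    # classes can be imported next to the encoder classes.
    from mlx.nn import TransformerDecoder, TransformerDecoderLayer

    print(TransformerDecoder, TransformerDecoderLayer)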
@@ -147,9 +147,9 @@ class TransformerEncoderLayer(Module):
         else:
             y = self.attention(x, x, x, mask)
             y = self.dropout1(y)
-            y = self.ln1(x + y)
+            x = self.ln1(x + y)

-            y = self.linear1(y)
+            y = self.linear1(x)
             y = self.activation(y)
             y = self.dropout2(y)
             y = self.linear2(y)
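The second hunk is the actual fix, in the post-norm (norm_first=False) branch of TransformerEncoderLayer. Previously the output of ln1 was bound to y, so while the feed-forward block consumed the normalized value, the variable x still held the raw layer input; the later residual connection (presumably ln2(x + y), just past the end of this hunk) therefore added the original input and skipped the attention sublayer entirely. Rebinding x = self.ln1(x + y) and feeding x into linear1 makes both the MLP input and the downstream residual use the normalized post-attention activations. Below is a minimal self-contained sketch of the corrected branch, written against the public mlx.nn API; the ln2 residual at the end is an assumption based on the standard post-norm layout, not something shown in the hunk:

    import mlx.core as mx
    import mlx.nn as nn


    class PostNormEncoderLayer(nn.Module):
        """Sketch of the corrected post-norm path; names mirror the hunk."""

        def __init__(self, dims: int, num_heads: int, mlp_dims: int, dropout: float = 0.0):
            super().__init__()
            self.attention = nn.MultiHeadAttention(dims, num_heads)
            self.ln1 = nn.LayerNorm(dims)
            self.ln2 = nn.LayerNorm(dims)  # assumed second norm, standard post-norm layout
            self.linear1 = nn.Linear(dims, mlp_dims)
            self.linear2 = nn.Linear(mlp_dims, dims)
            self.dropout1 = nn.Dropout(dropout)
            self.dropout2 = nn.Dropout(dropout)
            self.activation = nn.relu

        def __call__(self, x, mask=None):
            y = self.attention(x, x, x, mask)
            y = self.dropout1(y)
            # The fix: rebind x so the second residual below adds the
            # normalized post-attention value, not the raw layer input.
            x = self.ln1(x + y)

            y = self.linear1(x)
            y = self.activation(y)
            y = self.dropout2(y)
            y = self.linear2(y)
            return self.ln2(x + y)


    # Tiny usage check on a (batch, sequence, dims) input.
    layer = PostNormEncoderLayer(dims=64, num_heads=4, mlp_dims=256)
    out = layer(mx.random.normal((2, 16, 64)))
    print(out.shape)  # (2, 16, 64)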
Awni Hannun