From ba8d6bf365a4847edcdb8e6112bd0bdd66edb210 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 31 Jan 2024 12:55:30 -0800
Subject: [PATCH] Change the transformer to norm_first by default (#599)

---
 python/mlx/nn/layers/transformer.py | 71 +++++++++++++----------------
 1 file changed, 31 insertions(+), 40 deletions(-)

diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 6b88d4a8f..0c98e4c4e 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -116,7 +116,7 @@ class TransformerEncoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -167,7+167,7 @@ class TransformerEncoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -182,10 +182,8 @@ class TransformerEncoder(Module):
 
     def __call__(self, x, mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, mask)
-            else:
-                x = l(x, mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, mask)
         return self.ln(x)
 
 
@@ -197,7 +195,7 @@ class TransformerDecoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -260,7 +258,7 @@ class TransformerDecoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -275,10 +273,8 @@ class TransformerDecoder(Module):
 
     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, memory, x_mask, memory_mask)
-            else:
-                x = l(x, memory, x_mask, memory_mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)
 
 
@@ -317,7 +313,7 @@ class Transformer(Module):
             standard Transformer decoder. Default: ``None``.
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
-            operations, otherwise after. Default: ``False``.
+            operations, otherwise after. Default: ``True``.
         chekpoint (bool, optional): if ``True`` perform gradient checkpointing
             to reduce the memory usage at the expense of more computation.
             Default: ``False``.
@@ -334,37 +330,32 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
 
-        if custom_encoder is not None:
-            self.encoder = custom_encoder
-        else:
-            self.encoder = TransformerEncoder(
-                num_encoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
-        if custom_decoder is not None:
-            self.decoder = custom_decoder
-        else:
-            self.decoder = TransformerDecoder(
-                num_decoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
+        self.encoder = custom_encoder or TransformerEncoder(
+            num_encoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )
+
+        self.decoder = custom_decoder or TransformerDecoder(
+            num_decoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )
 
     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         memory = self.encoder(src, src_mask)
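
Usage note (not part of the patch): a minimal sketch of the effect of the new default. After this change the Transformer and its encoder/decoder layers are pre-norm (norm_first=True) unless the caller says otherwise. The shapes, the mask helper, and the keyword-style call below are illustrative assumptions rather than anything taken from this diff.

import mlx.core as mx
import mlx.nn as nn

# Default construction now applies LayerNorm before attention and the MLP (pre-norm).
model = nn.Transformer(dims=512, num_heads=8)

# The previous post-norm behaviour must now be requested explicitly.
post_norm_model = nn.Transformer(dims=512, num_heads=8, norm_first=False)

# Gradient checkpointing in the encoder/decoder stacks stays opt-in.
ckpt_model = nn.Transformer(dims=512, num_heads=8, checkpoint=True)

# Illustrative forward pass (assumed shapes: batch=2, sequence length=16, dims=512).
src = mx.random.normal((2, 16, 512))
tgt = mx.random.normal((2, 16, 512))
tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(16)

out = model(src, tgt, src_mask=None, tgt_mask=tgt_mask, memory_mask=None)
print(out.shape)  # (2, 16, 512)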