Change the transformer to norm_first by default (#599)

Angelos Katharopoulos 2024-01-31 12:55:30 -08:00 committed by GitHub
parent 4a5f3b21bb
commit ba8d6bf365


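The user-visible effect: TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder, and Transformer now default to the pre-norm arrangement. A minimal sketch of what that means for callers, assuming these classes are exposed under mlx.nn with the keyword arguments shown in the diff below:

    import mlx.nn as nn

    # New default after this commit: pre-norm (LayerNorm before attention and the MLP).
    model = nn.Transformer(dims=64, num_heads=4)

    # Code that relied on the previous post-norm behaviour must now opt in explicitly.
    post_norm_model = nn.Transformer(dims=64, num_heads=4, norm_first=False)
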
@@ -116,7 +116,7 @@ class TransformerEncoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -167,7 +167,7 @@ class TransformerEncoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -182,10 +182,8 @@ class TransformerEncoder(Module):
     def __call__(self, x, mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, mask)
-            else:
-                x = l(x, mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, mask)
         return self.ln(x)
@@ -197,7 +195,7 @@ class TransformerDecoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -260,7 +258,7 @@ class TransformerDecoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -275,10 +273,8 @@ class TransformerDecoder(Module):
     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, memory, x_mask, memory_mask)
-            else:
-                x = l(x, memory, x_mask, memory_mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)
@@ -317,7 +313,7 @@ class Transformer(Module):
             standard Transformer decoder. Default: ``None``.
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
-            operations, otherwise after. Default: ``False``.
+            operations, otherwise after. Default: ``True``.
         chekpoint (bool, optional): if ``True`` perform gradient checkpointing
             to reduce the memory usage at the expense of more computation.
             Default: ``False``.
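The docstring above describes the two orderings only in words. Below is a hypothetical, stripped-down residual block illustrating them; it is not the code in this file (the real layers also handle masks, dropout, and cross-attention in the decoder):

    import mlx.core as mx
    import mlx.nn as nn

    dims, num_heads = 32, 4
    ln1, ln2 = nn.LayerNorm(dims), nn.LayerNorm(dims)
    attn = nn.MultiHeadAttention(dims, num_heads)
    mlp = nn.Sequential(nn.Linear(dims, 4 * dims), nn.ReLU(), nn.Linear(4 * dims, dims))

    def pre_norm_block(x):
        # norm_first=True (new default): normalize, then attend / MLP, then add.
        y = ln1(x)
        x = x + attn(y, y, y)
        return x + mlp(ln2(x))

    def post_norm_block(x):
        # norm_first=False (old default): attend / MLP, add the residual, then normalize.
        x = ln1(x + attn(x, x, x))
        return ln2(x + mlp(x))

    x = mx.random.normal((1, 16, dims))
    print(pre_norm_block(x).shape, post_norm_block(x).shape)  # both (1, 16, 32)
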
@@ -334,37 +330,32 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
-        if custom_encoder is not None:
-            self.encoder = custom_encoder
-        else:
-            self.encoder = TransformerEncoder(
-                num_encoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
+        self.encoder = custom_encoder or TransformerEncoder(
+            num_encoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )

-        if custom_decoder is not None:
-            self.decoder = custom_decoder
-        else:
-            self.decoder = TransformerDecoder(
-                num_decoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
+        self.decoder = custom_decoder or TransformerDecoder(
+            num_decoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )

     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         memory = self.encoder(src, src_mask)
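An end-to-end sketch of calling the refactored Transformer with the signature shown in the last hunk. The shapes and layer counts are illustrative, and passing None for the masks assumes MultiHeadAttention treats a missing mask as a no-op:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Transformer(
        dims=64,
        num_heads=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
        checkpoint=False,  # set True to trade extra compute for lower memory
    )

    src = mx.random.normal((1, 10, 64))  # (batch, source length, dims)
    tgt = mx.random.normal((1, 12, 64))  # (batch, target length, dims)
    out = model(src, tgt, None, None, None)  # src_mask, tgt_mask, memory_mask
    print(out.shape)  # (1, 12, 64)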