Change the transformer to norm_first by default (#599)

Author: Angelos Katharopoulos (committed by GitHub)
Date: 2024-01-31 12:55:30 -08:00
Parent: 4a5f3b21bb
Commit: ba8d6bf365


@@ -116,7 +116,7 @@ class TransformerEncoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
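
For context, ``norm_first`` only controls where layer normalization sits relative to the residual connections. The sketch below is illustrative rather than the module's actual implementation; ``attention``, ``mlp``, ``ln1``, and ``ln2`` are stand-ins for the layer's sublayers.

def encoder_layer_forward(x, mask, attention, mlp, ln1, ln2, norm_first=True):
    if norm_first:
        # Pre-norm: normalize the input of each sublayer, then add the
        # sublayer output back onto the residual stream.
        x = x + attention(ln1(x), mask)
        x = x + mlp(ln2(x))
    else:
        # Post-norm: apply the sublayer first, then normalize the sum of
        # the residual and the sublayer output.
        x = ln1(x + attention(x, mask))
        x = ln2(x + mlp(x))
    return x
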
@@ -167,7 +167,7 @@ class TransformerEncoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -182,10 +182,8 @@ class TransformerEncoder(Module):
     def __call__(self, x, mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, mask)
-            else:
-                x = l(x, mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, mask)
         return self.ln(x)
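
The loop rewrite above is behavior-preserving: rather than branching on ``self.checkpoint`` at every call site, each layer is optionally wrapped once and then called through a single code path. A standalone sketch of the pattern follows, with a hypothetical ``checkpoint_stub`` standing in for the real gradient-checkpointing wrapper.

def checkpoint_stub(layer):
    # Stand-in only: the real ``checkpoint`` recomputes the layer's
    # activations in the backward pass to save memory; this one just
    # forwards the call unchanged.
    def wrapped(*args):
        return layer(*args)
    return wrapped

def run_layers(layers, x, mask, use_checkpoint=False):
    for l in layers:
        # Same shape as the new code: wrap conditionally, call unconditionally.
        l = checkpoint_stub(l) if use_checkpoint else l
        x = l(x, mask)
    return x
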
@@ -197,7 +195,7 @@ class TransformerDecoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -260,7 +258,7 @@ class TransformerDecoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -275,10 +273,8 @@ class TransformerDecoder(Module):
     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, memory, x_mask, memory_mask)
-            else:
-                x = l(x, memory, x_mask, memory_mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)
@@ -317,7 +313,7 @@ class Transformer(Module):
             standard Transformer decoder. Default: ``None``.
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
-            operations, otherwise after. Default: ``False``.
+            operations, otherwise after. Default: ``True``.
         checkpoint (bool, optional): if ``True`` perform gradient checkpointing
             to reduce the memory usage at the expense of more computation.
             Default: ``False``.
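
Because this flips a default, code that relied on the old post-norm behavior now has to request it explicitly. A minimal usage sketch, assuming the class is exposed as ``mlx.nn.Transformer``; the argument names are the ones visible in this diff.

import mlx.nn as nn

# New default: pre-norm encoder and decoder layers (norm_first=True).
model = nn.Transformer(dims=256, num_heads=8)

# Opt back into the previous post-norm behavior explicitly.
post_norm = nn.Transformer(dims=256, num_heads=8, norm_first=False)
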
@@ -334,37 +330,32 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
-        if custom_encoder is not None:
-            self.encoder = custom_encoder
-        else:
-            self.encoder = TransformerEncoder(
-                num_encoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
-        if custom_decoder is not None:
-            self.decoder = custom_decoder
-        else:
-            self.decoder = TransformerDecoder(
-                num_decoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
+        self.encoder = custom_encoder or TransformerEncoder(
+            num_encoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )
+        self.decoder = custom_decoder or TransformerDecoder(
+            num_decoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )

     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         memory = self.encoder(src, src_mask)