Change the transformer to norm_first by default (#599)

Angelos Katharopoulos 2024-01-31 12:55:30 -08:00 committed by GitHub
parent 4a5f3b21bb
commit ba8d6bf365

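In user terms, the diff below flips the default of norm_first from False to True for TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder, and Transformer, so the layers now apply layer normalization before attention and the MLP (pre-norm). A minimal sketch of what that means at the call site (the model sizes here are illustrative, not from the commit):

import mlx.nn as nn

# Pre-norm is now the default.
model = nn.Transformer(dims=256, num_heads=4, num_encoder_layers=2, num_decoder_layers=2)

# Code that relied on the old post-norm behaviour must now opt in explicitly.
post_norm = nn.Transformer(dims=256, num_heads=4, norm_first=False)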

@@ -116,7 +116,7 @@ class TransformerEncoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -167,7 +167,7 @@ class TransformerEncoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -182,9 +182,7 @@ class TransformerEncoder(Module):
     def __call__(self, x, mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, mask)
-            else:
-                x = l(x, mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, mask)
         return self.ln(x)
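
The two __call__ hunks (here and in TransformerDecoder further down) replace the per-layer if/else with a single conditional wrap, l = checkpoint(l) if self.checkpoint else l, before the layer is applied. A self-contained sketch of that idiom, using mlx.core.checkpoint for the standalone example (whether that is the exact checkpoint helper imported by this file is an assumption here):

import mlx.core as mx

def mlp(x):
    # Toy sub-layer standing in for a transformer block.
    return mx.tanh(x @ mx.ones((4, 4)))

use_checkpointing = True

# Same shape as the refactor: wrap only when checkpointing is requested.
layer = mx.checkpoint(mlp) if use_checkpointing else mlp

loss = lambda x: layer(x).sum()
grads = mx.grad(loss)(mx.ones((2, 4)))  # activations are recomputed in the backward pass
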
@@ -197,7 +195,7 @@ class TransformerDecoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -260,7 +258,7 @@ class TransformerDecoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -275,9 +273,7 @@ class TransformerDecoder(Module):
     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, memory, x_mask, memory_mask)
-            else:
-                x = l(x, memory, x_mask, memory_mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)
@@ -317,7 +313,7 @@ class Transformer(Module):
             standard Transformer decoder. Default: ``None``.
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
-            operations, otherwise after. Default: ``False``.
+            operations, otherwise after. Default: ``True``.
         chekpoint (bool, optional): if ``True`` perform gradient checkpointing
             to reduce the memory usage at the expense of more computation.
             Default: ``False``.
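
For reference, the distinction the norm_first flag controls, reduced to one residual sub-block (a simplified sketch, not the layer code from this file; the sub-layer below is a stand-in for attention or the MLP):

import mlx.core as mx
import mlx.nn as nn

ln = nn.LayerNorm(8)
sublayer = lambda x: 0.5 * x  # stand-in for attention or the MLP

x = mx.ones((2, 8))
y_pre = x + sublayer(ln(x))    # norm_first=True: normalize the input, then add the residual
y_post = ln(x + sublayer(x))   # norm_first=False: add the residual, then normalize the output
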
@@ -334,14 +330,12 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
-        if custom_encoder is not None:
-            self.encoder = custom_encoder
-        else:
-            self.encoder = TransformerEncoder(
+        self.encoder = custom_encoder or TransformerEncoder(
             num_encoder_layers,
             dims,
             num_heads,
@@ -352,10 +346,7 @@ class Transformer(Module):
             checkpoint,
         )
-        if custom_decoder is not None:
-            self.decoder = custom_decoder
-        else:
-            self.decoder = TransformerDecoder(
+        self.decoder = custom_decoder or TransformerDecoder(
             num_decoder_layers,
             dims,
             num_heads,
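
The last two hunks also swap the explicit None check for the shorter or-defaulting form, self.encoder = custom_encoder or TransformerEncoder(...). A small sketch of the idiom (generic Python, not repo code); it behaves the same as the None check as long as a supplied module is truthy:

def make_encoder(custom_encoder=None):
    # Falls back to the default only when nothing (or something falsy) is passed.
    return custom_encoder or "default-encoder"

print(make_encoder())              # "default-encoder"
print(make_encoder("my-encoder"))  # "my-encoder"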