mirror of https://github.com/ml-explore/mlx.git

Change the transformer to norm_first by default (#599)

parent 4a5f3b21bb
commit ba8d6bf365
@@ -116,7 +116,7 @@ class TransformerEncoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -167,7 +167,7 @@ class TransformerEncoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -182,10 +182,8 @@ class TransformerEncoder(Module):

     def __call__(self, x, mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, mask)
-            else:
-                x = l(x, mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, mask)
         return self.ln(x)

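The hunk above (and the matching one in TransformerDecoder further down) collapses the per-call if/else into a single conditional rewrap of each layer before invoking it. A minimal, framework-agnostic sketch of that loop pattern follows; fake_checkpoint and run_layers are illustrative stand-ins, not the mlx checkpoint helper.

from typing import Callable, List

def fake_checkpoint(layer: Callable) -> Callable:
    # Stand-in for a gradient-checkpointing wrapper: a real one would drop
    # intermediate activations here and recompute them in the backward pass,
    # trading extra compute for lower memory use.
    def wrapped(*args):
        return layer(*args)
    return wrapped

def run_layers(layers: List[Callable], x, mask, use_checkpoint: bool = False):
    for l in layers:
        # Rewrap per layer only when checkpointing is enabled, then call
        # every layer through the same code path.
        l = fake_checkpoint(l) if use_checkpoint else l
        x = l(x, mask)
    return x

# Toy usage with two stand-in "layers".
layers = [lambda x, m: x + m, lambda x, m: 2 * x + m]
print(run_layers(layers, 1, 10, use_checkpoint=True))  # -> 32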
@@ -197,7 +195,7 @@ class TransformerDecoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -260,7 +258,7 @@ class TransformerDecoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -275,10 +273,8 @@ class TransformerDecoder(Module):

     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, memory, x_mask, memory_mask)
-            else:
-                x = l(x, memory, x_mask, memory_mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)

@@ -317,7 +313,7 @@ class Transformer(Module):
             standard Transformer decoder. Default: ``None``.
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
-            operations, otherwise after. Default: ``False``.
+            operations, otherwise after. Default: ``True``.
         chekpoint (bool, optional): if ``True`` perform gradient checkpointing
             to reduce the memory usage at the expense of more computation.
             Default: ``False``.
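The docstring above describes the two orderings that norm_first toggles. A minimal, framework-agnostic sketch of the difference is shown below; encoder_layer, attention, mlp, norm1, and norm2 are illustrative stand-ins passed in as callables, not names from the mlx implementation (which also handles masks and dropout).

def encoder_layer(x, attention, mlp, norm1, norm2, norm_first=True):
    if norm_first:
        # Pre-norm (the new default): normalize before attention and MLP,
        # then add the residual.
        x = x + attention(norm1(x))
        x = x + mlp(norm2(x))
    else:
        # Post-norm (the old default): add the residual first, then normalize.
        x = norm1(x + attention(x))
        x = norm2(x + mlp(x))
    return x

Pre-norm stacks are generally reported to be easier to train at depth, which is why pre-norm is the more common default in recent transformer implementations.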
@@ -334,37 +330,32 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
-        if custom_encoder is not None:
-            self.encoder = custom_encoder
-        else:
-            self.encoder = TransformerEncoder(
-                num_encoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
-
-        if custom_decoder is not None:
-            self.decoder = custom_decoder
-        else:
-            self.decoder = TransformerDecoder(
-                num_decoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
+
+        self.encoder = custom_encoder or TransformerEncoder(
+            num_encoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )
+
+        self.decoder = custom_decoder or TransformerDecoder(
+            num_decoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )

     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         memory = self.encoder(src, src_mask)
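With this change, constructing mlx.nn.Transformer without arguments builds pre-norm encoder and decoder stacks; passing norm_first=False restores the previous post-norm behaviour. A small usage sketch follows, assuming only the keyword names visible in the diff; the sizes are arbitrary.

import mlx.nn as nn

# Pre-norm by default after this commit.
model = nn.Transformer(
    dims=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
)

# Opt back into the old post-norm behaviour explicitly.
legacy = nn.Transformer(dims=64, num_heads=4, norm_first=False)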