Mirror of https://github.com/ml-explore/mlx.git
Change the transformer to norm_first by default (#599)
Parent: 4a5f3b21bb
Commit: ba8d6bf365
@@ -116,7 +116,7 @@ class TransformerEncoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -167,7 +167,7 @@ class TransformerEncoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -182,10 +182,8 @@ class TransformerEncoder(Module):
 
     def __call__(self, x, mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, mask)
-            else:
-                x = l(x, mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, mask)
         return self.ln(x)
 
 
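The encoder loop (and the decoder loop further down) now wraps each layer once instead of branching on every call. A minimal sketch of the pattern, using a no-op stand-in for `checkpoint` — in the MLX source it returns a version of the layer that recomputes activations during the backward pass:

```python
from typing import Callable

def checkpoint(layer: Callable) -> Callable:
    # Stand-in for illustration only: the real wrapper trades compute for
    # memory by recomputing the layer's activations in the backward pass.
    return layer

def encoder_forward(layers, ln, x, mask, use_checkpoint: bool = False):
    for l in layers:
        # Wrap the layer once when checkpointing is enabled, then call it
        # the same way in either case (this replaces the old if/else).
        l = checkpoint(l) if use_checkpoint else l
        x = l(x, mask)
    return ln(x)
```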
@@ -197,7 +195,7 @@ class TransformerDecoderLayer(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -260,7 +258,7 @@ class TransformerDecoder(Module):
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -275,10 +273,8 @@ class TransformerDecoder(Module):
 
     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, memory, x_mask, memory_mask)
-            else:
-                x = l(x, memory, x_mask, memory_mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)
 
 
@@ -317,7 +313,7 @@ class Transformer(Module):
             standard Transformer decoder. Default: ``None``.
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
-            operations, otherwise after. Default: ``False``.
+            operations, otherwise after. Default: ``True``.
         chekpoint (bool, optional): if ``True`` perform gradient checkpointing
             to reduce the memory usage at the expense of more computation.
             Default: ``False``.
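For reference, the flag controls where layer normalization sits relative to the residual connections. A schematic sketch, not the MLX implementation — `attention`, `mlp`, `ln1`, and `ln2` stand in for a layer's sublayers:

```python
def pre_norm_block(x, mask, attention, mlp, ln1, ln2):
    # norm_first=True: normalize the input, apply the sublayer, then add
    # the residual ("pre-norm" formulation).
    x = x + attention(ln1(x), mask)
    x = x + mlp(ln2(x))
    return x

def post_norm_block(x, mask, attention, mlp, ln1, ln2):
    # norm_first=False: apply the sublayer, add the residual, then
    # normalize the sum ("post-norm", as in the original Transformer).
    x = ln1(x + attention(x, mask))
    x = ln2(x + mlp(x))
    return x
```

Pre-norm stacks generally train more stably, which is presumably the motivation for making it the default here.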
@@ -334,37 +330,32 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
-        if custom_encoder is not None:
-            self.encoder = custom_encoder
-        else:
-            self.encoder = TransformerEncoder(
-                num_encoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
-
-        if custom_decoder is not None:
-            self.decoder = custom_decoder
-        else:
-            self.decoder = TransformerDecoder(
-                num_decoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
+        self.encoder = custom_encoder or TransformerEncoder(
+            num_encoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )
+
+        self.decoder = custom_decoder or TransformerDecoder(
+            num_decoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )
 
     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         memory = self.encoder(src, src_mask)
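With this change, downstream code gets pre-norm layers unless it asks otherwise. A quick usage sketch against the constructor shown above; the input shapes and the `None` masks in the forward pass are assumptions for illustration, not part of this commit:

```python
import mlx.core as mx
import mlx.nn as nn

# New default: pre-norm encoder/decoder layers.
model = nn.Transformer(dims=512, num_heads=8)

# Code that depended on the old post-norm behaviour must now opt in.
post_norm = nn.Transformer(dims=512, num_heads=8, norm_first=False)

src = mx.random.normal((2, 10, 512))  # (batch, source length, dims)
tgt = mx.random.normal((2, 7, 512))   # (batch, target length, dims)
tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(7)
out = model(src, tgt, None, tgt_mask, None)  # src_mask and memory_mask unset
print(out.shape)  # batch x target length x dims
```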