Custom VJP and checkpointing (#541)

* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
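Checkpointing, as implemented here, wraps a callable so its intermediate activations are freed after the forward pass and recomputed during the backward pass, trading compute for memory. A minimal sketch of the new `mlx.nn.utils.checkpoint` wrapper that this commit uses below; the layer, input shape, and loss function are illustrative assumptions, not part of the diff:

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.nn.utils import checkpoint

    layer = nn.Linear(16, 16)      # hypothetical layer, for illustration only
    x = mx.random.normal((4, 16))  # hypothetical input batch

    def loss_fn(x):
        # checkpoint(layer) behaves like layer, but its intermediate
        # activations are recomputed during backprop instead of stored.
        y = checkpoint(layer)(x)
        return (y * y).mean()

    grad = mx.grad(loss_fn)(x)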
commit 0de5988f92
parent 143e2690d5
@@ -9,6 +9,7 @@ from mlx.nn.layers.base import Module
 from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm
+from mlx.nn.utils import checkpoint


 class MultiHeadAttention(Module):
@@ -167,6 +168,7 @@ class TransformerEncoder(Module):
         dropout: float = 0.0,
         activation=relu,
         norm_first: bool = False,
+        checkpoint: bool = False,
     ):
         super().__init__()
         self.layers = [
@@ -176,10 +178,14 @@ class TransformerEncoder(Module):
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
+        self.checkpoint = checkpoint

     def __call__(self, x, mask):
         for l in self.layers:
-            x = l(x, mask)
+            if self.checkpoint:
+                x = checkpoint(l)(x, mask)
+            else:
+                x = l(x, mask)
         return self.ln(x)


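The encoder change above leaves the math untouched: wrapping a layer with `checkpoint` only changes when activations are materialized, so gradients should match the unwrapped call. A hedged sanity check under that assumption; the layer and input are made up for illustration:

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.nn.utils import checkpoint

    layer = nn.Linear(8, 8)
    x = mx.random.normal((2, 8))

    g_plain = mx.grad(lambda x: layer(x).sum())(x)
    g_ckpt = mx.grad(lambda x: checkpoint(layer)(x).sum())(x)

    # Same gradients either way; only peak memory and recompute differ.
    assert mx.allclose(g_plain, g_ckpt)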
@@ -255,6 +261,7 @@ class TransformerDecoder(Module):
         dropout: float = 0.0,
         activation=relu,
         norm_first: bool = False,
+        checkpoint: bool = False,
     ):
         super().__init__()
         self.layers = [
@@ -264,10 +271,14 @@ class TransformerDecoder(Module):
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
+        self.checkpoint = checkpoint

     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            x = l(x, memory, x_mask, memory_mask)
+            if self.checkpoint:
+                x = checkpoint(l)(x, memory, x_mask, memory_mask)
+            else:
+                x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)


@@ -307,6 +318,9 @@ class Transformer(Module):
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
             operations, otherwise after. Default: ``False``.
+        checkpoint (bool, optional): if ``True`` perform gradient checkpointing
+            to reduce the memory usage at the expense of more computation.
+            Default: ``False``.
     """

     def __init__(
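As the docstring notes, the flag is set once at construction and threaded through both the encoder and decoder stacks. A minimal usage sketch, assuming the constructor's remaining arguments keep their defaults; only `norm_first` and `checkpoint` actually appear in this diff:

    import mlx.nn as nn

    # checkpoint=True trades forward recomputation during backprop
    # for lower peak memory in every encoder and decoder layer.
    model = nn.Transformer(norm_first=True, checkpoint=True)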
@@ -321,6 +335,7 @@ class Transformer(Module):
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
         norm_first: bool = False,
+        checkpoint: bool = False,
     ):
         super().__init__()
         if custom_encoder is not None:
@@ -334,6 +349,7 @@ class Transformer(Module):
                 dropout,
                 activation,
                 norm_first,
+                checkpoint,
             )

         if custom_decoder is not None:
@@ -347,6 +363,7 @@ class Transformer(Module):
                 dropout,
                 activation,
                 norm_first,
+                checkpoint,
             )

     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):