Custom VJP and checkpointing (#541)

* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
This commit is contained in:
Angelos Katharopoulos
2024-01-30 16:04:45 -08:00
committed by GitHub
parent 143e2690d5
commit 0de5988f92
22 changed files with 527 additions and 37 deletions

View File

@@ -9,6 +9,7 @@ from mlx.nn.layers.base import Module
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
from mlx.nn.utils import checkpoint
class MultiHeadAttention(Module):
@@ -167,6 +168,7 @@ class TransformerEncoder(Module):
dropout: float = 0.0,
activation=relu,
norm_first: bool = False,
checkpoint: bool = False,
):
super().__init__()
self.layers = [
@@ -176,10 +178,14 @@ class TransformerEncoder(Module):
for i in range(num_layers)
]
self.ln = LayerNorm(dims)
self.checkpoint = checkpoint
def __call__(self, x, mask):
for l in self.layers:
x = l(x, mask)
if self.checkpoint:
x = checkpoint(l)(x, mask)
else:
x = l(x, mask)
return self.ln(x)
@@ -255,6 +261,7 @@ class TransformerDecoder(Module):
dropout: float = 0.0,
activation=relu,
norm_first: bool = False,
checkpoint: bool = False,
):
super().__init__()
self.layers = [
@@ -264,10 +271,14 @@ class TransformerDecoder(Module):
for i in range(num_layers)
]
self.ln = LayerNorm(dims)
self.checkpoint = checkpoint
def __call__(self, x, memory, x_mask, memory_mask):
for l in self.layers:
x = l(x, memory, x_mask, memory_mask)
if self.checkpoint:
x = checkpoint(l)(x, memory, x_mask, memory_mask)
else:
x = l(x, memory, x_mask, memory_mask)
return self.ln(x)
@@ -307,6 +318,9 @@ class Transformer(Module):
norm_first (bool, optional): if ``True``, encoder and decoder layers
will perform layer normalization before attention and MLP
operations, otherwise after. Default: ``False``.
chekpoint (bool, optional): if ``True`` perform gradient checkpointing
to reduce the memory usage at the expense of more computation.
Default: ``False``.
"""
def __init__(
@@ -321,6 +335,7 @@ class Transformer(Module):
custom_encoder: Optional[Any] = None,
custom_decoder: Optional[Any] = None,
norm_first: bool = False,
checkpoint: bool = False,
):
super().__init__()
if custom_encoder is not None:
@@ -334,6 +349,7 @@ class Transformer(Module):
dropout,
activation,
norm_first,
checkpoint,
)
if custom_decoder is not None:
@@ -347,6 +363,7 @@ class Transformer(Module):
dropout,
activation,
norm_first,
checkpoint,
)
def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):

View File

@@ -1,11 +1,14 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
from functools import wraps
from typing import Callable
import mlx.core as mx
from .layers.base import Module
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
def value_and_grad(model: Module, fn: Callable):
"""Transform the passed function ``fn`` to a function that computes the
gradients of ``fn`` wrt the model's trainable parameters and also its
value.
@@ -26,8 +29,42 @@ def value_and_grad(model: "mlx.nn.Module", fn: Callable):
value_grad_fn = mx.value_and_grad(inner_fn)
@wraps(fn)
def wrapped_value_grad_fn(*args, **kwargs):
value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
return value, grad
return wrapped_value_grad_fn
def checkpoint(module: Module, fn: Callable = None):
"""Transform the passed callable to one that performs gradient
checkpointing with respect to the trainable parameters of the module (and
the callable's inputs).
Args:
module (mlx.nn.Module): The module for whose parameters we will be
performing gradient checkpointing.
fn (Callable, optional): The function to checkpoint. If not provided it
defaults to the provided module.
Returns:
A callable that saves the inputs and outputs during the forward pass
and recomputes all intermediate states during the backward pass.
"""
if fn is None:
# Capturing module instead of module.__call__ allows someone to
# monkey-patch __call__ later on and the correct method will be used
fn = module
def inner_fn(params, *args, **kwargs):
module.update(params)
return fn(*args, **kwargs)
checkpointed_fn = mx.checkpoint(inner_fn)
@wraps(fn)
def wrapped_checkpointed_fn(*args, **kwargs):
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
return wrapped_checkpointed_fn