Mirror of https://github.com/ml-explore/mlx.git
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
committed by GitHub
parent 143e2690d5
commit 0de5988f92
@@ -9,6 +9,7 @@ from mlx.nn.layers.base import Module
 from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm
+from mlx.nn.utils import checkpoint


 class MultiHeadAttention(Module):
@@ -167,6 +168,7 @@ class TransformerEncoder(Module):
         dropout: float = 0.0,
         activation=relu,
         norm_first: bool = False,
+        checkpoint: bool = False,
     ):
         super().__init__()
         self.layers = [
@@ -176,10 +178,14 @@ class TransformerEncoder(Module):
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
+        self.checkpoint = checkpoint

     def __call__(self, x, mask):
         for l in self.layers:
-            x = l(x, mask)
+            if self.checkpoint:
+                x = checkpoint(l)(x, mask)
+            else:
+                x = l(x, mask)
         return self.ln(x)

@@ -255,6 +261,7 @@ class TransformerDecoder(Module):
         dropout: float = 0.0,
         activation=relu,
         norm_first: bool = False,
+        checkpoint: bool = False,
     ):
         super().__init__()
         self.layers = [
@@ -264,10 +271,14 @@ class TransformerDecoder(Module):
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
+        self.checkpoint = checkpoint

     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            x = l(x, memory, x_mask, memory_mask)
+            if self.checkpoint:
+                x = checkpoint(l)(x, memory, x_mask, memory_mask)
+            else:
+                x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)

@@ -307,6 +318,9 @@ class Transformer(Module):
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
             operations, otherwise after. Default: ``False``.
+        checkpoint (bool, optional): if ``True`` perform gradient checkpointing
+            to reduce the memory usage at the expense of more computation.
+            Default: ``False``.
     """

     def __init__(
@@ -321,6 +335,7 @@ class Transformer(Module):
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
         norm_first: bool = False,
+        checkpoint: bool = False,
     ):
         super().__init__()
         if custom_encoder is not None:
@@ -334,6 +349,7 @@ class Transformer(Module):
                 dropout,
                 activation,
                 norm_first,
+                checkpoint,
             )

         if custom_decoder is not None:
@@ -347,6 +363,7 @@ class Transformer(Module):
                 dropout,
                 activation,
                 norm_first,
+                checkpoint,
             )

     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
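A minimal usage sketch of the option added above (the `checkpoint` flag comes from this diff; the remaining constructor arguments, the sizes, and the package-level `nn.Transformer` export are assumptions for illustration):

import mlx.core as mx
import mlx.nn as nn

# Recompute per-layer activations during the backward pass instead of
# storing them, trading extra compute for lower memory use.
model = nn.Transformer(dims=64, num_heads=4, checkpoint=True)

src = mx.random.normal((2, 16, 64))
tgt = mx.random.normal((2, 16, 64))
# Masks are omitted for brevity; a real call would pass attention masks.
out = model(src, tgt, None, None, None)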
@@ -1,11 +1,14 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

+from functools import wraps
 from typing import Callable

 import mlx.core as mx

+from .layers.base import Module

-def value_and_grad(model: "mlx.nn.Module", fn: Callable):
+def value_and_grad(model: Module, fn: Callable):
     """Transform the passed function ``fn`` to a function that computes the
     gradients of ``fn`` wrt the model's trainable parameters and also its
     value.
@@ -26,8 +29,42 @@ def value_and_grad(model: "mlx.nn.Module", fn: Callable):

     value_grad_fn = mx.value_and_grad(inner_fn)

+    @wraps(fn)
     def wrapped_value_grad_fn(*args, **kwargs):
         value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
         return value, grad

     return wrapped_value_grad_fn
+
+
+def checkpoint(module: Module, fn: Callable = None):
+    """Transform the passed callable to one that performs gradient
+    checkpointing with respect to the trainable parameters of the module (and
+    the callable's inputs).
+
+    Args:
+        module (mlx.nn.Module): The module for whose parameters we will be
+            performing gradient checkpointing.
+        fn (Callable, optional): The function to checkpoint. If not provided it
+            defaults to the provided module.
+
+    Returns:
+        A callable that saves the inputs and outputs during the forward pass
+        and recomputes all intermediate states during the backward pass.
+    """
+    if fn is None:
+        # Capturing module instead of module.__call__ allows someone to
+        # monkey-patch __call__ later on and the correct method will be used
+        fn = module
+
+    def inner_fn(params, *args, **kwargs):
+        module.update(params)
+        return fn(*args, **kwargs)
+
+    checkpointed_fn = mx.checkpoint(inner_fn)
+
+    @wraps(fn)
+    def wrapped_checkpointed_fn(*args, **kwargs):
+        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
+
+    return wrapped_checkpointed_fn
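A short usage sketch of the helper above, applied to an arbitrary module (the module choice and shapes are illustrative, not part of the change):

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import checkpoint

layer = nn.Linear(32, 32)

def loss_fn(x):
    # checkpoint(layer) behaves like layer, but its intermediate state is
    # recomputed during the backward pass instead of being kept in memory.
    y = checkpoint(layer)(x)
    return mx.mean(y * y)

x = mx.random.normal((8, 32))
grads = mx.grad(loss_fn)(x)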
@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
-#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <algorithm>
@@ -142,7 +141,8 @@ std::vector<array> tree_flatten(py::object tree, bool strict = true) {
     if (py::isinstance<array>(obj)) {
       flat_tree.push_back(py::cast<array>(obj));
     } else if (strict) {
-      throw std::invalid_argument("Argument is not an array");
+      throw std::invalid_argument(
+          "[tree_flatten] The argument should contain only arrays");
     }
   });
@@ -162,12 +162,48 @@ py::object tree_unflatten(
   });
 }

-py::object tree_unflatten_none(
+py::object structure_sentinel() {
+  static py::object sentinel;
+
+  if (sentinel.ptr() == nullptr) {
+    sentinel = py::capsule(&sentinel);
+    // probably not needed but this should make certain that we won't ever
+    // delete the sentinel
+    sentinel.inc_ref();
+  }
+
+  return sentinel;
+}
+
+std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
     py::object tree,
+    bool strict = true) {
+  auto sentinel = structure_sentinel();
+  std::vector<array> flat_tree;
+  auto structure = tree_map(
+      tree,
+      [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
+        if (py::isinstance<array>(obj)) {
+          flat_tree.push_back(py::cast<array>(obj));
+          return sentinel;
+        } else if (!strict) {
+          return py::cast<py::object>(obj);
+        } else {
+          throw std::invalid_argument(
+              "[tree_flatten] The argument should contain only arrays");
+        }
+      });
+
+  return {flat_tree, structure};
+}
+
+py::object tree_unflatten_from_structure(
+    py::object structure,
     const std::vector<array>& values,
     int index = 0) {
-  return tree_map(tree, [&](py::handle obj) {
-    if (py::isinstance<py::none>(obj)) {
+  auto sentinel = structure_sentinel();
+  return tree_map(structure, [&](py::handle obj) {
+    if (obj.is(sentinel)) {
       return py::cast(values[index++]);
     } else {
       return py::cast<py::object>(obj);
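The helpers above flatten an arbitrary Python tree into a vector of arrays plus a reusable "structure", in which every array slot is replaced by a unique sentinel; unflattening walks the structure and substitutes the arrays back in order. A rough Python rendition of the idea (illustrative only, not part of the change):

# Illustrative sketch of tree_flatten_with_structure /
# tree_unflatten_from_structure for lists and dicts.
_SENTINEL = object()

def flatten_with_structure(tree, is_leaf):
    # Collect leaves and return a structure with sentinels in their place.
    flat = []

    def walk(node):
        if is_leaf(node):
            flat.append(node)
            return _SENTINEL
        if isinstance(node, list):
            return [walk(x) for x in node]
        if isinstance(node, dict):
            return {k: walk(v) for k, v in node.items()}
        return node  # other leaves stay in the structure as-is

    return flat, walk(tree)

def unflatten_from_structure(structure, values):
    # Substitute values back wherever the sentinel appears.
    it = iter(values)

    def walk(node):
        if node is _SENTINEL:
            return next(it)
        if isinstance(node, list):
            return [walk(x) for x in node]
        if isinstance(node, dict):
            return {k: walk(v) for k, v in node.items()}
        return node

    return walk(structure)

flat, structure = flatten_with_structure(
    {"a": 1.0, "b": [2.0, "meta"]}, lambda x: isinstance(x, float)
)
restored = unflatten_from_structure(structure, [v * 10 for v in flat])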
@@ -472,14 +508,10 @@ struct PyCompiledFun {

   py::object operator()(const py::args& args) {
     auto compile_fun = [this, &args](const std::vector<array>& a) {
-      // Call the python function
-      py::object py_outputs = this->fun(*tree_unflatten(args, a));
+      // Call the python function and flatten the outputs
+      auto [outputs, py_outputs] = tree_flatten_with_structure(
+          std::move(this->fun(*tree_unflatten(args, a))), true);

-      // Flatten the outputs
-      auto outputs = tree_flatten(py_outputs, true);
-
-      py_outputs =
-          tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
       tree_cache().insert({this->fun_id, py_outputs});
       return outputs;
     };
@@ -492,15 +524,75 @@ struct PyCompiledFun {

     // Put the outputs back in the container
     py::object py_outputs = tree_cache().at(fun_id);
-    return tree_unflatten_none(py_outputs, outputs);
+    return tree_unflatten_from_structure(py_outputs, outputs);
   };

   ~PyCompiledFun() {
+    py::gil_scoped_acquire gil;
+
     tree_cache().erase(fun_id);
     detail::compile_erase(fun_id);
+    fun.release().dec_ref();
   }
 };

+class PyCheckpointedFun {
+ public:
+  PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {}
+
+  ~PyCheckpointedFun() {
+    py::gil_scoped_acquire gil;
+
+    fun_.release().dec_ref();
+  }
+
+  struct InnerFunction {
+    py::object fun_;
+    py::object args_structure_;
+    std::weak_ptr<py::object> output_structure_;
+
+    InnerFunction(
+        py::object fun,
+        py::object args_structure,
+        std::weak_ptr<py::object> output_structure)
+        : fun_(std::move(fun)),
+          args_structure_(std::move(args_structure)),
+          output_structure_(output_structure) {}
+    ~InnerFunction() {
+      py::gil_scoped_acquire gil;
+
+      fun_.release().dec_ref();
+      args_structure_.release().dec_ref();
+    }
+
+    std::vector<array> operator()(const std::vector<array>& inputs) {
+      auto args = py::cast<py::tuple>(
+          tree_unflatten_from_structure(args_structure_, inputs));
+      auto [outputs, output_structure] =
+          tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
+      if (auto s = output_structure_.lock()) {
+        *s = output_structure;
+      }
+      return outputs;
+    }
+  };
+
+  py::object operator()(const py::args& args, const py::kwargs& kwargs) {
+    auto output_structure = std::make_shared<py::object>();
+    auto full_args = py::make_tuple(args, kwargs);
+    auto [inputs, args_structure] =
+        tree_flatten_with_structure(full_args, false);
+
+    auto outputs = checkpoint(
+        InnerFunction(fun_, args_structure, output_structure))(inputs);
+
+    return tree_unflatten_from_structure(*output_structure, outputs);
+  }
+
+ private:
+  py::function fun_;
+};
+
 void init_transforms(py::module_& m) {
   py::options options;
   options.disable_function_signatures();
@@ -802,6 +894,10 @@ void init_transforms(py::module_& m) {
         Globally enable compilation. This will override the environment
         variable ``MLX_DISABLE_COMPILE`` if set.
       )pbdoc");
+  m.def(
+      "checkpoint",
+      [](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
+      "fun"_a);

   // Register static Python object cleanup before the interpreter exits
   auto atexit = py::module_::import("atexit");
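With the binding above, `mx.checkpoint` can be applied to a plain Python function; composing it with `mx.grad` recomputes the function's intermediates when the VJP runs (a minimal sketch; the function and shapes are illustrative):

import mlx.core as mx

def f(x):
    # tanh(x) is an intermediate that checkpointing avoids keeping alive
    return mx.sum(mx.tanh(x) * x)

grad_f = mx.grad(mx.checkpoint(f))
print(grad_f(mx.ones((4,))))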