Custom VJP and checkpointing (#541)

* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
This commit is contained in:
Angelos Katharopoulos
2024-01-30 16:04:45 -08:00
committed by GitHub
parent 143e2690d5
commit 0de5988f92
22 changed files with 527 additions and 37 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <algorithm>
@@ -142,7 +141,8 @@ std::vector<array> tree_flatten(py::object tree, bool strict = true) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument("Argument is not an array");
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
@@ -162,12 +162,48 @@ py::object tree_unflatten(
});
}
py::object tree_unflatten_none(
py::object structure_sentinel() {
static py::object sentinel;
if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
}
return sentinel;
}
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict = true) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return {flat_tree, structure};
}
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index = 0) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<py::none>(obj)) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
@@ -472,14 +508,10 @@ struct PyCompiledFun {
py::object operator()(const py::args& args) {
auto compile_fun = [this, &args](const std::vector<array>& a) {
// Call the python function
py::object py_outputs = this->fun(*tree_unflatten(args, a));
// Call the python function and flatten the outputs
auto [outputs, py_outputs] = tree_flatten_with_structure(
std::move(this->fun(*tree_unflatten(args, a))), true);
// Flatten the outputs
auto outputs = tree_flatten(py_outputs, true);
py_outputs =
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
tree_cache().insert({this->fun_id, py_outputs});
return outputs;
};
@@ -492,15 +524,75 @@ struct PyCompiledFun {
// Put the outputs back in the container
py::object py_outputs = tree_cache().at(fun_id);
return tree_unflatten_none(py_outputs, outputs);
return tree_unflatten_from_structure(py_outputs, outputs);
};
~PyCompiledFun() {
py::gil_scoped_acquire gil;
tree_cache().erase(fun_id);
detail::compile_erase(fun_id);
fun.release().dec_ref();
}
};
class PyCheckpointedFun {
public:
PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {}
~PyCheckpointedFun() {
py::gil_scoped_acquire gil;
fun_.release().dec_ref();
}
struct InnerFunction {
py::object fun_;
py::object args_structure_;
std::weak_ptr<py::object> output_structure_;
InnerFunction(
py::object fun,
py::object args_structure,
std::weak_ptr<py::object> output_structure)
: fun_(std::move(fun)),
args_structure_(std::move(args_structure)),
output_structure_(output_structure) {}
~InnerFunction() {
py::gil_scoped_acquire gil;
fun_.release().dec_ref();
args_structure_.release().dec_ref();
}
std::vector<array> operator()(const std::vector<array>& inputs) {
auto args = py::cast<py::tuple>(
tree_unflatten_from_structure(args_structure_, inputs));
auto [outputs, output_structure] =
tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
if (auto s = output_structure_.lock()) {
*s = output_structure;
}
return outputs;
}
};
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
auto output_structure = std::make_shared<py::object>();
auto full_args = py::make_tuple(args, kwargs);
auto [inputs, args_structure] =
tree_flatten_with_structure(full_args, false);
auto outputs = checkpoint(
InnerFunction(fun_, args_structure, output_structure))(inputs);
return tree_unflatten_from_structure(*output_structure, outputs);
}
private:
py::function fun_;
};
void init_transforms(py::module_& m) {
py::options options;
options.disable_function_signatures();
@@ -802,6 +894,10 @@ void init_transforms(py::module_& m) {
Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc");
m.def(
"checkpoint",
[](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
"fun"_a);
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");