mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing * Add a dependency management primitive * Change the eval order to deep branches first * Add graph depth tracking to the array
This commit is contained in:

committed by
GitHub

parent
143e2690d5
commit
0de5988f92
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <algorithm>
|
||||
@@ -142,7 +141,8 @@ std::vector<array> tree_flatten(py::object tree, bool strict = true) {
|
||||
if (py::isinstance<array>(obj)) {
|
||||
flat_tree.push_back(py::cast<array>(obj));
|
||||
} else if (strict) {
|
||||
throw std::invalid_argument("Argument is not an array");
|
||||
throw std::invalid_argument(
|
||||
"[tree_flatten] The argument should contain only arrays");
|
||||
}
|
||||
});
|
||||
|
||||
@@ -162,12 +162,48 @@ py::object tree_unflatten(
|
||||
});
|
||||
}
|
||||
|
||||
py::object tree_unflatten_none(
|
||||
py::object structure_sentinel() {
|
||||
static py::object sentinel;
|
||||
|
||||
if (sentinel.ptr() == nullptr) {
|
||||
sentinel = py::capsule(&sentinel);
|
||||
// probably not needed but this should make certain that we won't ever
|
||||
// delete the sentinel
|
||||
sentinel.inc_ref();
|
||||
}
|
||||
|
||||
return sentinel;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
|
||||
py::object tree,
|
||||
bool strict = true) {
|
||||
auto sentinel = structure_sentinel();
|
||||
std::vector<array> flat_tree;
|
||||
auto structure = tree_map(
|
||||
tree,
|
||||
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
|
||||
if (py::isinstance<array>(obj)) {
|
||||
flat_tree.push_back(py::cast<array>(obj));
|
||||
return sentinel;
|
||||
} else if (!strict) {
|
||||
return py::cast<py::object>(obj);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[tree_flatten] The argument should contain only arrays");
|
||||
}
|
||||
});
|
||||
|
||||
return {flat_tree, structure};
|
||||
}
|
||||
|
||||
py::object tree_unflatten_from_structure(
|
||||
py::object structure,
|
||||
const std::vector<array>& values,
|
||||
int index = 0) {
|
||||
return tree_map(tree, [&](py::handle obj) {
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
auto sentinel = structure_sentinel();
|
||||
return tree_map(structure, [&](py::handle obj) {
|
||||
if (obj.is(sentinel)) {
|
||||
return py::cast(values[index++]);
|
||||
} else {
|
||||
return py::cast<py::object>(obj);
|
||||
@@ -472,14 +508,10 @@ struct PyCompiledFun {
|
||||
|
||||
py::object operator()(const py::args& args) {
|
||||
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
||||
// Call the python function and flatten the outputs
|
||||
auto [outputs, py_outputs] = tree_flatten_with_structure(
|
||||
std::move(this->fun(*tree_unflatten(args, a))), true);
|
||||
|
||||
// Flatten the outputs
|
||||
auto outputs = tree_flatten(py_outputs, true);
|
||||
|
||||
py_outputs =
|
||||
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
||||
tree_cache().insert({this->fun_id, py_outputs});
|
||||
return outputs;
|
||||
};
|
||||
@@ -492,15 +524,75 @@ struct PyCompiledFun {
|
||||
|
||||
// Put the outputs back in the container
|
||||
py::object py_outputs = tree_cache().at(fun_id);
|
||||
return tree_unflatten_none(py_outputs, outputs);
|
||||
return tree_unflatten_from_structure(py_outputs, outputs);
|
||||
};
|
||||
|
||||
~PyCompiledFun() {
|
||||
py::gil_scoped_acquire gil;
|
||||
|
||||
tree_cache().erase(fun_id);
|
||||
detail::compile_erase(fun_id);
|
||||
fun.release().dec_ref();
|
||||
}
|
||||
};
|
||||
|
||||
class PyCheckpointedFun {
|
||||
public:
|
||||
PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {}
|
||||
|
||||
~PyCheckpointedFun() {
|
||||
py::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
}
|
||||
|
||||
struct InnerFunction {
|
||||
py::object fun_;
|
||||
py::object args_structure_;
|
||||
std::weak_ptr<py::object> output_structure_;
|
||||
|
||||
InnerFunction(
|
||||
py::object fun,
|
||||
py::object args_structure,
|
||||
std::weak_ptr<py::object> output_structure)
|
||||
: fun_(std::move(fun)),
|
||||
args_structure_(std::move(args_structure)),
|
||||
output_structure_(output_structure) {}
|
||||
~InnerFunction() {
|
||||
py::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
args_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
||||
auto args = py::cast<py::tuple>(
|
||||
tree_unflatten_from_structure(args_structure_, inputs));
|
||||
auto [outputs, output_structure] =
|
||||
tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
|
||||
if (auto s = output_structure_.lock()) {
|
||||
*s = output_structure;
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
|
||||
auto output_structure = std::make_shared<py::object>();
|
||||
auto full_args = py::make_tuple(args, kwargs);
|
||||
auto [inputs, args_structure] =
|
||||
tree_flatten_with_structure(full_args, false);
|
||||
|
||||
auto outputs = checkpoint(
|
||||
InnerFunction(fun_, args_structure, output_structure))(inputs);
|
||||
|
||||
return tree_unflatten_from_structure(*output_structure, outputs);
|
||||
}
|
||||
|
||||
private:
|
||||
py::function fun_;
|
||||
};
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
@@ -802,6 +894,10 @@ void init_transforms(py::module_& m) {
|
||||
Globally enable compilation. This will override the environment
|
||||
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"checkpoint",
|
||||
[](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
|
||||
"fun"_a);
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
|
Reference in New Issue
Block a user