Custom VJP and checkpointing (#541)

* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
This commit is contained in:
Angelos Katharopoulos
2024-01-30 16:04:45 -08:00
committed by GitHub
parent 143e2690d5
commit 0de5988f92
22 changed files with 527 additions and 37 deletions

View File

@@ -35,7 +35,7 @@ class Synchronizer : public Primitive {
int detail::InTracing::tracing_counter{0};
void eval(const std::vector<array>& outputs) {
std::function<void(const array&)> recurse;
std::function<void(const array&, bool)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
@@ -52,21 +52,57 @@ void eval(const std::vector<array>& outputs) {
auto synchronizer =
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
recurse = [&](const array& a) {
recurse = [&](const array& a, bool largest_branch_first) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (auto in : a.inputs()) {
recurse(in);
// If one of the inputs is being computed on a different
// stream, we need to manage the dependency.
// If the input is being computed on a different stream, we need to manage
// the dependency.
auto check_dependency = [&](const array& in) {
if (!in.is_evaled()) {
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.primitive_id(), std::shared_future<void>{}});
}
}
};
// Recurse to the largest or smallest branch first.
size_t num_inputs = a.inputs().size();
if (num_inputs == 1) {
auto& in = a.inputs()[0];
recurse(in, true);
check_dependency(in);
} else if (num_inputs == 2) {
auto depth_1 = a.inputs()[0].graph_depth();
auto depth_2 = a.inputs()[1].graph_depth();
auto& in1 = a.inputs()[static_cast<int>(
!((depth_1 > depth_2) == largest_branch_first))];
auto& in2 = a.inputs()[static_cast<int>(
((depth_1 > depth_2) == largest_branch_first))];
recurse(in1, true);
check_dependency(in1);
recurse(in2, true);
check_dependency(in2);
} else if (num_inputs > 2) {
std::vector<int> recursion_order(a.inputs().size());
std::iota(recursion_order.begin(), recursion_order.end(), 0);
// Order the inputs by graph depth: deepest first when largest_branch_first
// is set, shallowest first otherwise. This mirrors the selection made in the
// two-input case. The two arms must be different comparisons; previously
// both reduced to "depth_i > depth_j", so the flag had no effect here.
std::sort(
    recursion_order.begin(),
    recursion_order.end(),
    [&a, largest_branch_first](int i, int j) {
      auto depth_i = a.inputs()[i].graph_depth();
      auto depth_j = a.inputs()[j].graph_depth();
      return largest_branch_first ? depth_i > depth_j : depth_i < depth_j;
    });
for (int idx : recursion_order) {
auto& in = a.inputs()[idx];
recurse(in, true);
check_dependency(in);
}
}
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
@@ -80,7 +116,7 @@ void eval(const std::vector<array>& outputs) {
}
};
recurse(synchronizer);
recurse(synchronizer, false);
uintptr_t synch_id = synchronizer.primitive_id();
deps.insert({synch_id, std::shared_future<void>{}});
@@ -713,4 +749,58 @@ std::function<array(const array&)> vmap(
return [vfun](const array& a) { return vfun({a})[0]; };
}
// Wrap `fun` in a callable whose outputs are attached to a CustomVJP
// primitive built from `fun_vjp`.
//
// The returned callable:
//   1. evaluates `fun(args)` and passes every output through stop_gradient,
//   2. builds the primitive inputs as `args` followed by those outputs,
//   3. returns fresh arrays with the same shapes and dtypes as the outputs,
//      created through array::make_arrays with the CustomVJP primitive.
//
// `fun_vjp` is stored in the primitive and receives three vectors of arrays;
// how they are filled in is decided by CustomVJP, which is not visible here.
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
    std::function<std::vector<array>(const std::vector<array>&)> fun,
    std::function<std::vector<array>(
        const std::vector<array>&,
        const std::vector<array>&,
        const std::vector<array>&)> fun_vjp) {
  return [fun = std::move(fun),
          fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
    // Compute the outputs and detach them from the gradient graph.
    auto outputs = fun(args);
    for (auto& out : outputs) {
      out = stop_gradient(out);
    }

    // Prepare the inputs to the primitive.
    // We also add the outputs to the primitive so that it can "run" the
    // forward pass.
    std::vector<array> inputs;
    inputs.reserve(args.size() + outputs.size());
    inputs.insert(inputs.end(), args.begin(), args.end());
    inputs.insert(inputs.end(), outputs.begin(), outputs.end());

    // Compute the stream. Maybe do it in a smarter way at some point in the
    // future. Guard against `fun` returning no outputs so that outputs[0] is
    // never read from an empty vector.
    Stream s = (!outputs.empty() && outputs[0].has_primitive())
        ? outputs[0].primitive().stream()
        : default_stream(default_device());

    // Make the output info.
    std::vector<std::vector<int>> shapes;
    std::vector<Dtype> dtypes;
    shapes.reserve(outputs.size());
    dtypes.reserve(outputs.size());
    for (const auto& out : outputs) {
      shapes.emplace_back(out.shape());
      dtypes.emplace_back(out.dtype());
    }

    return array::make_arrays(
        shapes,
        dtypes,
        std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
        inputs);
  };
}
// Build a version of `fun` whose vector-Jacobian product is obtained by
// calling vjp(fun, ...) again when the gradient is requested, instead of
// through the graph recorded during the original call. The wrapping itself
// is delegated to custom_vjp.
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
    std::function<std::vector<array>(const std::vector<array>&)> fun) {
  // `fun` is captured by copy here because it is also handed to custom_vjp
  // below.
  auto vjp_fun = [fun](
                     const std::vector<array>& primals,
                     const std::vector<array>& cotangents,
                     const std::vector<array>& outputs) -> std::vector<array> {
    // The primals are passed through depends(primals, outputs) before the
    // recomputation. Only the second element of vjp's result is returned.
    auto [recomputed, vjps] =
        vjp(fun, depends(primals, outputs), cotangents);
    (void)recomputed;
    // A structured-binding name is not implicitly moved on return, so move
    // explicitly to avoid copying the vector.
    return std::move(vjps);
  };

  return custom_vjp(std::move(fun), std::move(vjp_fun));
}
} // namespace mlx::core