mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
This commit is contained in:
committed by
GitHub
parent
143e2690d5
commit
0de5988f92
@@ -35,7 +35,7 @@ class Synchronizer : public Primitive {
|
||||
int detail::InTracing::tracing_counter{0};
|
||||
|
||||
void eval(const std::vector<array>& outputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::function<void(const array&, bool)> recurse;
|
||||
std::queue<array> tape;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
||||
@@ -52,21 +52,57 @@ void eval(const std::vector<array>& outputs) {
|
||||
auto synchronizer =
|
||||
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
|
||||
|
||||
recurse = [&](const array& a) {
|
||||
recurse = [&](const array& a, bool largest_branch_first) {
|
||||
auto id = a.id();
|
||||
if (cache.find(id) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
for (auto in : a.inputs()) {
|
||||
recurse(in);
|
||||
// If one of the inputs is being computed on a different
|
||||
// stream, we need to manage the dependency.
|
||||
|
||||
// If the input is being computed on a different stream, we need to manage
|
||||
// the dependency.
|
||||
auto check_dependency = [&](const array& in) {
|
||||
if (!in.is_evaled()) {
|
||||
if (a.primitive().stream() != in.primitive().stream()) {
|
||||
deps.insert({in.primitive_id(), std::shared_future<void>{}});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Recurse to the largest or smallest branch first.
|
||||
size_t num_inputs = a.inputs().size();
|
||||
if (num_inputs == 1) {
|
||||
auto& in = a.inputs()[0];
|
||||
recurse(in, true);
|
||||
check_dependency(in);
|
||||
} else if (num_inputs == 2) {
|
||||
auto depth_1 = a.inputs()[0].graph_depth();
|
||||
auto depth_2 = a.inputs()[1].graph_depth();
|
||||
auto& in1 = a.inputs()[static_cast<int>(
|
||||
!((depth_1 > depth_2) == largest_branch_first))];
|
||||
auto& in2 = a.inputs()[static_cast<int>(
|
||||
((depth_1 > depth_2) == largest_branch_first))];
|
||||
recurse(in1, true);
|
||||
check_dependency(in1);
|
||||
recurse(in2, true);
|
||||
check_dependency(in2);
|
||||
} else if (num_inputs > 2) {
|
||||
std::vector<int> recursion_order(a.inputs().size());
|
||||
std::iota(recursion_order.begin(), recursion_order.end(), 0);
|
||||
// Order the input indices by graph depth: deepest input first when
// largest_branch_first is set, shallowest input first otherwise. This
// mirrors the selection done for the two-input case.
std::sort(
    recursion_order.begin(),
    recursion_order.end(),
    [&a, largest_branch_first](int i, int j) {
      auto depth_i = a.inputs()[i].graph_depth();
      auto depth_j = a.inputs()[j].graph_depth();
      // The two arms must be opposite strict orderings. Writing the
      // ascending arm as `depth_j < depth_i` would be the same predicate as
      // the descending arm and would make the flag a no-op.
      return largest_branch_first ? depth_i > depth_j : depth_i < depth_j;
    });
|
||||
for (int idx : recursion_order) {
|
||||
auto& in = a.inputs()[idx];
|
||||
recurse(in, true);
|
||||
check_dependency(in);
|
||||
}
|
||||
}
|
||||
|
||||
cache.insert(id);
|
||||
for (auto& s : a.siblings()) {
|
||||
cache.insert(s.id());
|
||||
@@ -80,7 +116,7 @@ void eval(const std::vector<array>& outputs) {
|
||||
}
|
||||
};
|
||||
|
||||
recurse(synchronizer);
|
||||
recurse(synchronizer, false);
|
||||
uintptr_t synch_id = synchronizer.primitive_id();
|
||||
deps.insert({synch_id, std::shared_future<void>{}});
|
||||
|
||||
@@ -713,4 +749,58 @@ std::function<array(const array&)> vmap(
|
||||
return [vfun](const array& a) { return vfun({a})[0]; };
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)> fun_vjp) {
|
||||
return [fun = std::move(fun),
|
||||
fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
|
||||
// Compute the outputs
|
||||
auto outputs = fun(args);
|
||||
for (auto& out : outputs) {
|
||||
out = stop_gradient(out);
|
||||
}
|
||||
|
||||
// Prepare the inputs to the primitive
|
||||
// We also add the outputs to the primitive so that it can "run" the forward
|
||||
// pass.
|
||||
std::vector<array> inputs = args;
|
||||
inputs.insert(inputs.end(), outputs.begin(), outputs.end());
|
||||
|
||||
// Compute the stream. Maybe do it in a smarter way at some point in the
|
||||
// future.
|
||||
Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream()
|
||||
: default_stream(default_device());
|
||||
|
||||
// Make the output info
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> dtypes;
|
||||
for (const auto& out : outputs) {
|
||||
shapes.emplace_back(out.shape());
|
||||
dtypes.emplace_back(out.dtype());
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
shapes,
|
||||
dtypes,
|
||||
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
|
||||
inputs);
|
||||
};
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun) {
|
||||
auto vjp_fun = [fun](
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<array>& outputs) -> std::vector<array> {
|
||||
auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents);
|
||||
return vjps;
|
||||
};
|
||||
|
||||
return custom_vjp(fun, vjp_fun);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user