Custom VJP and checkpointing (#541)

* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
This commit is contained in:
Angelos Katharopoulos
2024-01-30 16:04:45 -08:00
committed by GitHub
parent 143e2690d5
commit 0de5988f92
22 changed files with 527 additions and 37 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <functional>
@@ -97,11 +97,13 @@ void array::detach() {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr;
}
@@ -180,7 +182,9 @@ array::ArrayDesc::ArrayDesc(
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayDesc::ArrayDesc(
@@ -195,7 +199,9 @@ array::ArrayDesc::ArrayDesc(
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)