mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing * Add a dependency management primitive * Change the eval order to deep branches first * Add graph depth tracking to the array
This commit is contained in:
committed by
GitHub
parent
143e2690d5
commit
0de5988f92
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <functional>
|
||||
|
||||
@@ -97,11 +97,13 @@ void array::detach() {
|
||||
s.array_desc_->inputs.clear();
|
||||
s.array_desc_->siblings.clear();
|
||||
s.array_desc_->position = 0;
|
||||
s.array_desc_->depth = 0;
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
array_desc_->inputs.clear();
|
||||
array_desc_->siblings.clear();
|
||||
array_desc_->position = 0;
|
||||
array_desc_->depth = 0;
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
@@ -180,7 +182,9 @@ array::ArrayDesc::ArrayDesc(
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
}
|
||||
depth++;
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
@@ -195,7 +199,9 @@ array::ArrayDesc::ArrayDesc(
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
}
|
||||
depth++;
|
||||
}
|
||||
|
||||
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
|
||||
Reference in New Issue
Block a user