Custom VJP and checkpointing (#541)

* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
This commit is contained in:
Angelos Katharopoulos
2024-01-30 16:04:45 -08:00
committed by GitHub
parent 143e2690d5
commit 0de5988f92
22 changed files with 527 additions and 37 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -615,7 +615,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
assert(inputs.size() == 3);
if (!is_floating_point(out.dtype())) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");