Custom VJP and checkpointing (#541)

* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
This commit is contained in:
Angelos Katharopoulos
2024-01-30 16:04:45 -08:00
committed by GitHub
parent 143e2690d5
commit 0de5988f92
22 changed files with 527 additions and 37 deletions

View File

@@ -21,6 +21,7 @@ target_sources(tests PRIVATE
autograd_tests.cpp
blas_tests.cpp
compile_tests.cpp
custom_vjp_tests.cpp
creations_tests.cpp
device_tests.cpp
eval_tests.cpp