mlx/python
Angelos Katharopoulos 0de5988f92
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
mlx Custom VJP and checkpointing (#541) 2024-01-30 16:04:45 -08:00
src Custom VJP and checkpointing (#541) 2024-01-30 16:04:45 -08:00
tests Make shape a tuple (#591) 2024-01-30 13:11:01 -08:00