mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 14:58:11 +08:00
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing * Add a dependency management primitive * Change the eval order to deep branches first * Add graph depth tracking to the array
This commit is contained in:
committed by
GitHub
parent
143e2690d5
commit
0de5988f92
@@ -21,6 +21,7 @@ target_sources(tests PRIVATE
|
||||
autograd_tests.cpp
|
||||
blas_tests.cpp
|
||||
compile_tests.cpp
|
||||
custom_vjp_tests.cpp
|
||||
creations_tests.cpp
|
||||
device_tests.cpp
|
||||
eval_tests.cpp
|
||||
|
||||
57
tests/custom_vjp_tests.cpp
Normal file
57
tests/custom_vjp_tests.cpp
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test simple custom vjp") {
|
||||
auto one = array(1.0);
|
||||
auto x = array(2.0);
|
||||
auto y = array(3.0);
|
||||
|
||||
auto fn = [](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{inputs[0] * inputs[1], inputs[0] + inputs[1]};
|
||||
};
|
||||
auto transformed_fn = custom_vjp(
|
||||
fn,
|
||||
[&](const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&) {
|
||||
return std::vector<array>{one, one};
|
||||
});
|
||||
|
||||
auto [z, g] = vjp(fn, {x, y}, {one, one});
|
||||
CHECK_EQ(z[0].item<float>(), 6.0f);
|
||||
CHECK_EQ(z[1].item<float>(), 5.0f);
|
||||
CHECK_EQ(g[0].item<float>(), 4.0f);
|
||||
CHECK_EQ(g[1].item<float>(), 3.0f);
|
||||
|
||||
std::tie(z, g) = vjp(transformed_fn, {x, y}, {one, one});
|
||||
CHECK_EQ(z[0].item<float>(), 6.0f);
|
||||
CHECK_EQ(z[1].item<float>(), 5.0f);
|
||||
CHECK_EQ(g[0].item<float>(), 1.0f);
|
||||
CHECK_EQ(g[1].item<float>(), 1.0f);
|
||||
}
|
||||
|
||||
TEST_CASE("test checkpointing") {
|
||||
auto one = array(1.0);
|
||||
auto x = array(2.0);
|
||||
auto y = array(3.0);
|
||||
|
||||
int cnt = 0;
|
||||
auto fn = [&cnt](const std::vector<array>& inputs) {
|
||||
cnt++;
|
||||
auto x = inputs[0] * inputs[1];
|
||||
auto y = inputs[0] + inputs[1];
|
||||
return std::vector<array>{square(x + y)};
|
||||
};
|
||||
auto checkpointed_fn = checkpoint(fn);
|
||||
|
||||
auto [z, g] = vjp(checkpointed_fn, {x, y}, {one});
|
||||
CHECK_EQ(z[0].item<float>(), 121.0f);
|
||||
CHECK_EQ(g[0].item<float>(), 88.0f);
|
||||
CHECK_EQ(g[1].item<float>(), 66.0f);
|
||||
CHECK_EQ(cnt, 2);
|
||||
}
|
||||
Reference in New Issue
Block a user