mlx/tests/custom_vjp_tests.cpp
Angelos Katharopoulos 0de5988f92
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00

58 lines
1.5 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test simple custom vjp") {
  // `one` is used both as the cotangent seed for every output and as the
  // hand-written gradient returned by the overriding VJP below.
  auto one = array(1.0);
  auto x = array(2.0);
  auto y = array(3.0);

  // f(a, b) = (a * b, a + b). With unit cotangents its true VJP is
  // (b + 1, a + 1), i.e. (4, 3) at (a, b) = (2, 3).
  auto fn = [](const std::vector<array>& args) {
    const auto& a = args[0];
    const auto& b = args[1];
    return std::vector<array>{a * b, a + b};
  };

  // Replacement VJP: ignores all three of its arguments and always reports a
  // gradient of (1, 1), so the test can tell it apart from the real one.
  auto constant_vjp = [&one](
                          const std::vector<array>&,
                          const std::vector<array>&,
                          const std::vector<array>&) {
    return std::vector<array>{one, one};
  };
  auto transformed_fn = custom_vjp(fn, constant_vjp);

  // Baseline: the ordinary VJP of fn.
  auto [plain_out, plain_grad] = vjp(fn, {x, y}, {one, one});
  CHECK_EQ(plain_out[0].item<float>(), 6.0f);
  CHECK_EQ(plain_out[1].item<float>(), 5.0f);
  CHECK_EQ(plain_grad[0].item<float>(), 4.0f);
  CHECK_EQ(plain_grad[1].item<float>(), 3.0f);

  // Forward outputs must be unchanged; only the gradients are overridden.
  auto [custom_out, custom_grad] = vjp(transformed_fn, {x, y}, {one, one});
  CHECK_EQ(custom_out[0].item<float>(), 6.0f);
  CHECK_EQ(custom_out[1].item<float>(), 5.0f);
  CHECK_EQ(custom_grad[0].item<float>(), 1.0f);
  CHECK_EQ(custom_grad[1].item<float>(), 1.0f);
}
TEST_CASE("test checkpointing") {
  auto one = array(1.0);
  auto x = array(2.0);
  auto y = array(3.0);

  // Number of times the body of fn has been invoked.
  int cnt = 0;

  // f(a, b) = (a * b + a + b)^2. At (a, b) = (2, 3): f = 121,
  // df/da = 2 * 11 * (b + 1) = 88, df/db = 2 * 11 * (a + 1) = 66.
  auto fn = [&cnt](const std::vector<array>& args) {
    ++cnt;
    auto product = args[0] * args[1];
    auto total = args[0] + args[1];
    return std::vector<array>{square(product + total)};
  };

  auto [out, grads] = vjp(checkpoint(fn), {x, y}, {one});
  CHECK_EQ(out[0].item<float>(), 121.0f);
  CHECK_EQ(grads[0].item<float>(), 88.0f);
  CHECK_EQ(grads[1].item<float>(), 66.0f);

  // The test expects fn to have run exactly twice: presumably once for the
  // forward pass and once when checkpoint recomputes it for the backward pass.
  CHECK_EQ(cnt, 2);
}