// Copyright © 2023-2024 Apple Inc. #include "doctest/doctest.h" #include "mlx/mlx.h" using namespace mlx::core; TEST_CASE("test simple custom vjp") { auto one = array(1.0); auto x = array(2.0); auto y = array(3.0); auto fn = [](const std::vector& inputs) { return std::vector{inputs[0] * inputs[1], inputs[0] + inputs[1]}; }; auto transformed_fn = custom_vjp( fn, [&](const std::vector&, const std::vector&, const std::vector&) { return std::vector{one, one}; }); auto [z, g] = vjp(fn, {x, y}, {one, one}); CHECK_EQ(z[0].item(), 6.0f); CHECK_EQ(z[1].item(), 5.0f); CHECK_EQ(g[0].item(), 4.0f); CHECK_EQ(g[1].item(), 3.0f); std::tie(z, g) = vjp(transformed_fn, {x, y}, {one, one}); CHECK_EQ(z[0].item(), 6.0f); CHECK_EQ(z[1].item(), 5.0f); CHECK_EQ(g[0].item(), 1.0f); CHECK_EQ(g[1].item(), 1.0f); } TEST_CASE("test checkpointing") { auto one = array(1.0); auto x = array(2.0); auto y = array(3.0); int cnt = 0; auto fn = [&cnt](const std::vector& inputs) { cnt++; auto x = inputs[0] * inputs[1]; auto y = inputs[0] + inputs[1]; return std::vector{square(x + y)}; }; auto checkpointed_fn = checkpoint(fn); auto [z, g] = vjp(checkpointed_fn, {x, y}, {one}); CHECK_EQ(z[0].item(), 121.0f); CHECK_EQ(g[0].item(), 88.0f); CHECK_EQ(g[1].item(), 66.0f); CHECK_EQ(cnt, 2); }