Fix arctan2 grads (#2453)

This commit is contained in:
Angelos Katharopoulos
2025-08-01 21:06:04 -07:00
committed by GitHub
parent be9bc96da4
commit 8831064493
2 changed files with 51 additions and 6 deletions

View File

@@ -413,6 +413,25 @@ TEST_CASE("test op vjps") {
CHECK(out.second.item<float>() == doctest::Approx(-std::sin(1.0f)));
}
// Test arctan
{
auto out = vjp(
[](array input) { return arctan(input); }, array(2.0f), array(1.0f));
CHECK(out.second.item<float>() == doctest::Approx(0.2f));
}
// Test arctan2
{
auto out = vjp(
[](const std::vector<array>& xs) {
return std::vector<array>{arctan2(xs[0], xs[1])};
},
{array(2.0f), array(3.0f)},
{array(1.0f)});
CHECK(out.second[0].item<float>() == doctest::Approx(3.0f / 13.0f));
CHECK(out.second[1].item<float>() == doctest::Approx(-2.0f / 13.0f));
}
// Test log
{
auto out = vjp([](array in) { return log(in); }, array(2.0f), array(1.0f));