Fix arctan2 grads (#2453)

This commit is contained in:
Angelos Katharopoulos 2025-08-01 21:06:04 -07:00 committed by GitHub
parent be9bc96da4
commit 8831064493
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 6 deletions

View File

@ -510,7 +510,27 @@ std::vector<array> ArcTan2::vjp(
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
assert(primals.size() == 2);
assert(argnums.size() == 2);
const auto& s = stream();
const array& x1 = primals[0];
const array& x2 = primals[1];
const array& dy = cotangents[0];
std::vector<array> grads;
array dy_over_x1_x2_squared =
divide(dy, add(square(x1, s), square(x2, s)), s);
for (auto arg : argnums) {
if (arg == 0) {
grads.emplace_back(multiply(x2, dy_over_x1_x2_squared, s));
} else {
grads.emplace_back(multiply(negative(x1, s), dy_over_x1_x2_squared, s));
}
}
return grads;
}
std::vector<array> ArcTan2::jvp(
@ -519,11 +539,17 @@ std::vector<array> ArcTan2::jvp(
const std::vector<int>& argnums) {
assert(primals.size() == 2);
assert(argnums.size() == 2);
array t =
add(square(primals[0], stream()), square(primals[1], stream()), stream());
return {
divide(tangents[0], t, stream()),
divide(negative(tangents[1], stream()), t, stream())};
const auto& s = stream();
const array& x1 = primals[0];
const array& x2 = primals[1];
const array& dx1 = tangents[0];
const array& dx2 = tangents[1];
return {divide(
subtract(multiply(x2, dx1, s), multiply(x1, dx2, s), s),
add(square(x1, s), square(x2, s), s),
s)};
}
std::pair<std::vector<array>, std::vector<int>> ArcTan2::vmap(

View File

@ -413,6 +413,25 @@ TEST_CASE("test op vjps") {
CHECK(out.second.item<float>() == doctest::Approx(-std::sin(1.0f)));
}
// Test arctan
{
auto out = vjp(
[](array input) { return arctan(input); }, array(2.0f), array(1.0f));
CHECK(out.second.item<float>() == doctest::Approx(0.2f));
}
// Test arctan2
{
auto out = vjp(
[](const std::vector<array>& xs) {
return std::vector<array>{arctan2(xs[0], xs[1])};
},
{array(2.0f), array(3.0f)},
{array(1.0f)});
CHECK(out.second[0].item<float>() == doctest::Approx(3.0f / 13.0f));
CHECK(out.second[1].item<float>() == doctest::Approx(-2.0f / 13.0f));
}
// Test log
{
auto out = vjp([](array in) { return log(in); }, array(2.0f), array(1.0f));