Fix arctan2 grads (#2453)

This commit is contained in:
Angelos Katharopoulos
2025-08-01 21:06:04 -07:00
committed by GitHub
parent be9bc96da4
commit 8831064493
2 changed files with 51 additions and 6 deletions

View File

@@ -510,7 +510,27 @@ std::vector<array> ArcTan2::vjp(
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
assert(primals.size() == 2);
assert(argnums.size() == 2);
const auto& s = stream();
const array& x1 = primals[0];
const array& x2 = primals[1];
const array& dy = cotangents[0];
std::vector<array> grads;
array dy_over_x1_x2_squared =
divide(dy, add(square(x1, s), square(x2, s)), s);
for (auto arg : argnums) {
if (arg == 0) {
grads.emplace_back(multiply(x2, dy_over_x1_x2_squared, s));
} else {
grads.emplace_back(multiply(negative(x1, s), dy_over_x1_x2_squared, s));
}
}
return grads;
}
std::vector<array> ArcTan2::jvp(
@@ -519,11 +539,17 @@ std::vector<array> ArcTan2::jvp(
const std::vector<int>& argnums) {
assert(primals.size() == 2);
assert(argnums.size() == 2);
array t =
add(square(primals[0], stream()), square(primals[1], stream()), stream());
return {
divide(tangents[0], t, stream()),
divide(negative(tangents[1], stream()), t, stream())};
const auto& s = stream();
const array& x1 = primals[0];
const array& x2 = primals[1];
const array& dx1 = tangents[0];
const array& dx2 = tangents[1];
return {divide(
subtract(multiply(x2, dx1, s), multiply(x1, dx2, s), s),
add(square(x1, s), square(x2, s), s),
s)};
}
std::pair<std::vector<array>, std::vector<int>> ArcTan2::vmap(