mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
Fix arctan2 grads (#2453)
This commit is contained in:

committed by
GitHub

parent
be9bc96da4
commit
8831064493
@@ -510,7 +510,27 @@ std::vector<array> ArcTan2::vjp(
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() == 2);
|
||||
|
||||
const auto& s = stream();
|
||||
const array& x1 = primals[0];
|
||||
const array& x2 = primals[1];
|
||||
const array& dy = cotangents[0];
|
||||
|
||||
std::vector<array> grads;
|
||||
array dy_over_x1_x2_squared =
|
||||
divide(dy, add(square(x1, s), square(x2, s)), s);
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
grads.emplace_back(multiply(x2, dy_over_x1_x2_squared, s));
|
||||
} else {
|
||||
grads.emplace_back(multiply(negative(x1, s), dy_over_x1_x2_squared, s));
|
||||
}
|
||||
}
|
||||
|
||||
return grads;
|
||||
}
|
||||
|
||||
std::vector<array> ArcTan2::jvp(
|
||||
@@ -519,11 +539,17 @@ std::vector<array> ArcTan2::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() == 2);
|
||||
array t =
|
||||
add(square(primals[0], stream()), square(primals[1], stream()), stream());
|
||||
return {
|
||||
divide(tangents[0], t, stream()),
|
||||
divide(negative(tangents[1], stream()), t, stream())};
|
||||
|
||||
const auto& s = stream();
|
||||
const array& x1 = primals[0];
|
||||
const array& x2 = primals[1];
|
||||
const array& dx1 = tangents[0];
|
||||
const array& dx2 = tangents[1];
|
||||
|
||||
return {divide(
|
||||
subtract(multiply(x2, dx1, s), multiply(x1, dx2, s), s),
|
||||
add(square(x1, s), square(x2, s), s),
|
||||
s)};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> ArcTan2::vmap(
|
||||
|
Reference in New Issue
Block a user