mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-08 18:16:41 +08:00
Fix arctan2 grads (#2453)
This commit is contained in:
parent
be9bc96da4
commit
8831064493
@ -510,7 +510,27 @@ std::vector<array> ArcTan2::vjp(
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() == 2);
|
||||
|
||||
const auto& s = stream();
|
||||
const array& x1 = primals[0];
|
||||
const array& x2 = primals[1];
|
||||
const array& dy = cotangents[0];
|
||||
|
||||
std::vector<array> grads;
|
||||
array dy_over_x1_x2_squared =
|
||||
divide(dy, add(square(x1, s), square(x2, s)), s);
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
grads.emplace_back(multiply(x2, dy_over_x1_x2_squared, s));
|
||||
} else {
|
||||
grads.emplace_back(multiply(negative(x1, s), dy_over_x1_x2_squared, s));
|
||||
}
|
||||
}
|
||||
|
||||
return grads;
|
||||
}
|
||||
|
||||
std::vector<array> ArcTan2::jvp(
|
||||
@ -519,11 +539,17 @@ std::vector<array> ArcTan2::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() == 2);
|
||||
array t =
|
||||
add(square(primals[0], stream()), square(primals[1], stream()), stream());
|
||||
return {
|
||||
divide(tangents[0], t, stream()),
|
||||
divide(negative(tangents[1], stream()), t, stream())};
|
||||
|
||||
const auto& s = stream();
|
||||
const array& x1 = primals[0];
|
||||
const array& x2 = primals[1];
|
||||
const array& dx1 = tangents[0];
|
||||
const array& dx2 = tangents[1];
|
||||
|
||||
return {divide(
|
||||
subtract(multiply(x2, dx1, s), multiply(x1, dx2, s), s),
|
||||
add(square(x1, s), square(x2, s), s),
|
||||
s)};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> ArcTan2::vmap(
|
||||
|
@ -413,6 +413,25 @@ TEST_CASE("test op vjps") {
|
||||
CHECK(out.second.item<float>() == doctest::Approx(-std::sin(1.0f)));
|
||||
}
|
||||
|
||||
// Test arctan
|
||||
{
|
||||
auto out = vjp(
|
||||
[](array input) { return arctan(input); }, array(2.0f), array(1.0f));
|
||||
CHECK(out.second.item<float>() == doctest::Approx(0.2f));
|
||||
}
|
||||
|
||||
// Test arctan2
|
||||
{
|
||||
auto out = vjp(
|
||||
[](const std::vector<array>& xs) {
|
||||
return std::vector<array>{arctan2(xs[0], xs[1])};
|
||||
},
|
||||
{array(2.0f), array(3.0f)},
|
||||
{array(1.0f)});
|
||||
CHECK(out.second[0].item<float>() == doctest::Approx(3.0f / 13.0f));
|
||||
CHECK(out.second[1].item<float>() == doctest::Approx(-2.0f / 13.0f));
|
||||
}
|
||||
|
||||
// Test log
|
||||
{
|
||||
auto out = vjp([](array in) { return log(in); }, array(2.0f), array(1.0f));
|
||||
|
Loading…
Reference in New Issue
Block a user