diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 027f6bb48b..a4fa011d51 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -510,7 +510,27 @@ std::vector ArcTan2::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector&) { - return jvp(primals, cotangents, argnums); + assert(primals.size() == 2); + assert(argnums.size() == 2); + + const auto& s = stream(); + const array& x1 = primals[0]; + const array& x2 = primals[1]; + const array& dy = cotangents[0]; + + std::vector grads; + array dy_over_x1_x2_squared = + divide(dy, add(square(x1, s), square(x2, s)), s); + + for (auto arg : argnums) { + if (arg == 0) { + grads.emplace_back(multiply(x2, dy_over_x1_x2_squared, s)); + } else { + grads.emplace_back(multiply(negative(x1, s), dy_over_x1_x2_squared, s)); + } + } + + return grads; } std::vector ArcTan2::jvp( @@ -519,11 +539,17 @@ std::vector ArcTan2::jvp( const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() == 2); - array t = - add(square(primals[0], stream()), square(primals[1], stream()), stream()); - return { - divide(tangents[0], t, stream()), - divide(negative(tangents[1], stream()), t, stream())}; + + const auto& s = stream(); + const array& x1 = primals[0]; + const array& x2 = primals[1]; + const array& dx1 = tangents[0]; + const array& dx2 = tangents[1]; + + return {divide( + subtract(multiply(x2, dx1, s), multiply(x1, dx2, s), s), + add(square(x1, s), square(x2, s), s), + s)}; } std::pair, std::vector> ArcTan2::vmap( diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index 5b3454bfce..3a373fb18b 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -413,6 +413,25 @@ TEST_CASE("test op vjps") { CHECK(out.second.item() == doctest::Approx(-std::sin(1.0f))); } + // Test arctan + { + auto out = vjp( + [](array input) { return arctan(input); }, array(2.0f), array(1.0f)); + CHECK(out.second.item() == doctest::Approx(0.2f)); + } + + // Test arctan2 + { + auto out = vjp( + [](const std::vector& xs) { + return std::vector{arctan2(xs[0], xs[1])}; + }, + {array(2.0f), array(3.0f)}, + {array(1.0f)}); + CHECK(out.second[0].item() == doctest::Approx(3.0f / 13.0f)); + CHECK(out.second[1].item() == doctest::Approx(-2.0f / 13.0f)); + } + // Test log { auto out = vjp([](array in) { return log(in); }, array(2.0f), array(1.0f));