mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 23:15:09 +08:00
Fix arctan2 grads (#2453)
This commit is contained in:
committed by
GitHub
parent
be9bc96da4
commit
8831064493
@@ -413,6 +413,25 @@ TEST_CASE("test op vjps") {
|
||||
CHECK(out.second.item<float>() == doctest::Approx(-std::sin(1.0f)));
|
||||
}
|
||||
|
||||
// Test arctan
|
||||
{
|
||||
auto out = vjp(
|
||||
[](array input) { return arctan(input); }, array(2.0f), array(1.0f));
|
||||
CHECK(out.second.item<float>() == doctest::Approx(0.2f));
|
||||
}
|
||||
|
||||
// Test arctan2
|
||||
{
|
||||
auto out = vjp(
|
||||
[](const std::vector<array>& xs) {
|
||||
return std::vector<array>{arctan2(xs[0], xs[1])};
|
||||
},
|
||||
{array(2.0f), array(3.0f)},
|
||||
{array(1.0f)});
|
||||
CHECK(out.second[0].item<float>() == doctest::Approx(3.0f / 13.0f));
|
||||
CHECK(out.second[1].item<float>() == doctest::Approx(-2.0f / 13.0f));
|
||||
}
|
||||
|
||||
// Test log
|
||||
{
|
||||
auto out = vjp([](array in) { return log(in); }, array(2.0f), array(1.0f));
|
||||
|
||||
Reference in New Issue
Block a user