Added ArcTan2 operation (#1079)

* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
This commit is contained in:
Rahul Yedida
2024-05-08 11:35:15 -04:00
committed by GitHub
parent fe96ceee66
commit cc05a281c4
16 changed files with 143 additions and 1 deletions

View File

@@ -264,3 +264,10 @@ struct RightShift {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return metal::precise::atan2(y, x);
}
};

View File

@@ -241,6 +241,7 @@ instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
instantiate_binary_types(rem, Remainder)
instantiate_binary_float(arctan2, ArcTan2)
// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)

View File

@@ -451,6 +451,10 @@ void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctan");
}
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "arctan2");
}
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctanh");
}