mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Added ArcTan2 operation (#1079)
* Added ArcTan2 operation * Cleanup, bug fixes from code review * Minor cleanup, fixed Linux tests
This commit is contained in:
@@ -193,6 +193,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
if (a.is_donatable()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
int size = a.data_size();
|
||||
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
@@ -293,4 +293,25 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::ArcTan2());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::ArcTan2());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
std::ostringstream err;
|
||||
err << "[arctan2] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctan2] Cannot compute inverse tangent for arrays"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -34,6 +34,7 @@ DEFAULT(ArcCosh)
|
||||
DEFAULT(ArcSin)
|
||||
DEFAULT(ArcSinh)
|
||||
DEFAULT(ArcTan)
|
||||
DEFAULT(ArcTan2)
|
||||
DEFAULT(ArcTanh)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
|
@@ -161,6 +161,13 @@ struct ArcTan {
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return std::atan2(y, x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@@ -264,3 +264,10 @@ struct RightShift {
|
||||
return x >> y;
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return metal::precise::atan2(y, x);
|
||||
}
|
||||
};
|
||||
|
@@ -241,6 +241,7 @@ instantiate_binary_types(mul, Multiply)
|
||||
instantiate_binary_types(sub, Subtract)
|
||||
instantiate_binary_types(pow, Power)
|
||||
instantiate_binary_types(rem, Remainder)
|
||||
instantiate_binary_float(arctan2, ArcTan2)
|
||||
|
||||
// NaNEqual only needed for floating point types with boolean output
|
||||
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||
|
@@ -451,6 +451,10 @@ void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arctan");
|
||||
}
|
||||
|
||||
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "arctan2");
|
||||
}
|
||||
|
||||
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arctanh");
|
||||
}
|
||||
|
@@ -25,6 +25,7 @@ NO_GPU(ArcCosh)
|
||||
NO_GPU(ArcSin)
|
||||
NO_GPU(ArcSinh)
|
||||
NO_GPU(ArcTan)
|
||||
NO_GPU(ArcTan2)
|
||||
NO_GPU(ArcTanh)
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
|
@@ -45,7 +45,8 @@ bool is_binary(const Primitive& p) {
|
||||
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
|
||||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
|
||||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
|
||||
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary));
|
||||
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) ||
|
||||
typeid(p) == typeid(ArcTan2));
|
||||
}
|
||||
|
||||
bool is_ternary(const Primitive& p) {
|
||||
|
@@ -2152,6 +2152,14 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
|
||||
a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
|
||||
}
|
||||
|
||||
array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
|
||||
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
||||
auto& shape = inputs[0].shape();
|
||||
return array(
|
||||
shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
|
||||
}
|
||||
|
||||
array sinh(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto input = astype(a, dtype, s);
|
||||
|
@@ -828,6 +828,9 @@ array arccos(const array& a, StreamOrDevice s = {});
|
||||
/** Arc Tangent of the elements of an array */
|
||||
array arctan(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Inverse tangent of the ratio of two arrays */
|
||||
array arctan2(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Hyperbolic Sine of the elements of an array */
|
||||
array sinh(const array& a, StreamOrDevice s = {});
|
||||
|
||||
|
@@ -402,6 +402,36 @@ std::pair<std::vector<array>, std::vector<int>> ArcTan::vmap(
|
||||
return {{arctan(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> ArcTan2::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
std::vector<array> ArcTan2::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() == 2);
|
||||
array t =
|
||||
add(square(primals[0], stream()), square(primals[1], stream()), stream());
|
||||
return {
|
||||
divide(tangents[0], t, stream()),
|
||||
divide(negative(tangents[1], stream()), t, stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> ArcTan2::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 2);
|
||||
assert(axes.size() == 2);
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {{arctan2(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> ArcTanh::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@@ -314,6 +314,23 @@ class ArcTan : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcTan2 : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ArcTan2)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcTanh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
Reference in New Issue
Block a user