mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Added ArcTan2 operation (#1079)
* Added ArcTan2 operation * Cleanup, bug fixes from code review * Minor cleanup, fixed Linux tests
This commit is contained in:
parent
fe96ceee66
commit
cc05a281c4
@ -19,6 +19,7 @@ Operations
|
||||
arcsin
|
||||
arcsinh
|
||||
arctan
|
||||
arctan2
|
||||
arctanh
|
||||
argmax
|
||||
argmin
|
||||
|
@ -193,6 +193,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
if (a.is_donatable()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
int size = a.data_size();
|
||||
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
@ -293,4 +293,25 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::ArcTan2());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::ArcTan2());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
std::ostringstream err;
|
||||
err << "[arctan2] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctan2] Cannot compute inverse tangent for arrays"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -34,6 +34,7 @@ DEFAULT(ArcCosh)
|
||||
DEFAULT(ArcSin)
|
||||
DEFAULT(ArcSinh)
|
||||
DEFAULT(ArcTan)
|
||||
DEFAULT(ArcTan2)
|
||||
DEFAULT(ArcTanh)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
|
@ -161,6 +161,13 @@ struct ArcTan {
|
||||
};
|
||||
};
|
||||
|
||||
// Element-wise functor for the two-argument inverse tangent.
//
// operator()(y, x) returns std::atan2(y, x) converted back to T, i.e. the
// angle of the point (x, y) with the quadrant chosen from the signs of both
// arguments.
struct ArcTan2 {
  // const-qualified: the functor holds no state, so it must also be callable
  // through a const instance or a const reference.
  template <typename T>
  T operator()(T y, T x) const {
    return std::atan2(y, x);
  }
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@ -264,3 +264,10 @@ struct RightShift {
|
||||
return x >> y;
|
||||
};
|
||||
};
|
||||
|
||||
// Binary functor for the two-argument inverse tangent: evaluates
// metal::precise::atan2(y, x) and returns the result as T.
struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    T angle = metal::precise::atan2(y, x);
    return angle;
  }
};
|
||||
|
@ -241,6 +241,7 @@ instantiate_binary_types(mul, Multiply)
|
||||
instantiate_binary_types(sub, Subtract)
|
||||
instantiate_binary_types(pow, Power)
|
||||
instantiate_binary_types(rem, Remainder)
|
||||
instantiate_binary_float(arctan2, ArcTan2)
|
||||
|
||||
// NaNEqual only needed for floating point types with boolean output
|
||||
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||
|
@ -451,6 +451,10 @@ void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arctan");
|
||||
}
|
||||
|
||||
// GPU evaluation of arctan2: forwards the inputs and output to the shared
// binary_op helper under the name "arctan2".
// NOTE(review): presumably this name selects the kernels created by
// instantiate_binary_float(arctan2, ArcTan2) - confirm against binary_op.
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "arctan2");
}
|
||||
|
||||
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "arctanh");
|
||||
}
|
||||
|
@ -25,6 +25,7 @@ NO_GPU(ArcCosh)
|
||||
NO_GPU(ArcSin)
|
||||
NO_GPU(ArcSinh)
|
||||
NO_GPU(ArcTan)
|
||||
NO_GPU(ArcTan2)
|
||||
NO_GPU(ArcTanh)
|
||||
NO_GPU(ArgPartition)
|
||||
NO_GPU(ArgReduce)
|
||||
|
@ -45,7 +45,8 @@ bool is_binary(const Primitive& p) {
|
||||
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
|
||||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
|
||||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
|
||||
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary));
|
||||
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) ||
|
||||
typeid(p) == typeid(ArcTan2));
|
||||
}
|
||||
|
||||
bool is_ternary(const Primitive& p) {
|
||||
|
@ -2152,6 +2152,14 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
|
||||
a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
|
||||
}
|
||||
|
||||
// Builds the lazy arctan2(a, b) node.
//
// Both operands are cast to at_least_float(promote_types(a, b)), broadcast
// against each other, and handed to an ArcTan2 primitive on stream `s`. The
// result takes its shape from the first broadcast operand.
array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
  auto operands =
      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
  auto& out_shape = operands[0].shape();
  return array(
      out_shape,
      out_type,
      std::make_shared<ArcTan2>(to_stream(s)),
      std::move(operands));
}
|
||||
|
||||
array sinh(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto input = astype(a, dtype, s);
|
||||
|
@ -828,6 +828,9 @@ array arccos(const array& a, StreamOrDevice s = {});
|
||||
/** Arc Tangent of the elements of an array */
|
||||
array arctan(const array& a, StreamOrDevice s = {});
|
||||
|
||||
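/** Element-wise two-argument inverse tangent atan2(a, b): the angle whose tangent is a / b, with the quadrant chosen from the signs of both arrays */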
|
||||
array arctan2(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Hyperbolic Sine of the elements of an array */
|
||||
array sinh(const array& a, StreamOrDevice s = {});
|
||||
|
||||
|
@ -402,6 +402,36 @@ std::pair<std::vector<array>, std::vector<int>> ArcTan::vmap(
|
||||
return {{arctan(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> ArcTan2::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
std::vector<array> ArcTan2::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() == 2);
|
||||
array t =
|
||||
add(square(primals[0], stream()), square(primals[1], stream()), stream());
|
||||
return {
|
||||
divide(tangents[0], t, stream()),
|
||||
divide(negative(tangents[1], stream()), t, stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> ArcTan2::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 2);
|
||||
assert(axes.size() == 2);
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {{arctan2(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> ArcTanh::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@ -314,6 +314,23 @@ class ArcTan : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcTan2 : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(ArcTan2)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class ArcTanh : public UnaryPrimitive {
|
||||
public:
|
||||
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
@ -930,6 +930,25 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The inverse tangent of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"arctan2",
|
||||
&mlx::core::arctan2,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def arctan2(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise inverse tangent of the ratio of two arrays.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
b (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The inverse tangent of the ratio of ``a`` and ``b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"sinh",
|
||||
&mlx::core::sinh,
|
||||
|
@ -1273,6 +1273,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"arcsin": lambda primal, cotan: cotan / np.sqrt(1.0 - primal**2),
|
||||
"arccos": lambda primal, cotan: -cotan / np.sqrt(1.0 - primal**2),
|
||||
"arctan": lambda primal, cotan: cotan / (1.0 + primal**2),
|
||||
"arctan2": lambda primal, cotan: cotan / (1.0 + primal**2),
|
||||
"arcsinh": lambda primal, cotan: cotan / np.sqrt(primal**2 + 1),
|
||||
"arccosh": lambda primal, cotan: cotan / np.sqrt(primal**2 - 1),
|
||||
"arctanh": lambda primal, cotan: cotan / (1.0 - primal**2),
|
||||
|
Loading…
Reference in New Issue
Block a user