Adds an element-wise arctan2 op: docs entry, CPU / Accelerate / Metal kernels,
graph primitive with vmap and gradients, Python binding and a test table entry.
NOTE(review): line breaks, indentation and template arguments were rebuilt from
a collapsed copy of this patch, and ArcTan2::vjp/jvp were corrected, so the
post-image blob hashes on the `index` lines of the touched files are stale;
regenerate them before relying on a 3-way apply.

diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index abd5d1997..8cd648e31 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -19,6 +19,7 @@ Operations
    arcsin
    arcsinh
    arctan
+   arctan2
    arctanh
    argmax
    argmin
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 7b48e62f7..1fecd9ca6 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -193,6 +193,30 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
+// Accelerate fast path: float32 with both inputs row contiguous goes through
+// vvatan2f; every other case falls back to the common eval().
+void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  if (out.dtype() == float32 && a.flags().row_contiguous &&
+      b.flags().row_contiguous) {
+    // Share a donatable input's buffer with the output, else allocate.
+    if (a.is_donatable()) {
+      out.copy_shared_buffer(a);
+    } else if (b.is_donatable()) {
+      out.copy_shared_buffer(b);
+    } else {
+      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    }
+    int size = a.data_size();
+    // a supplies the y operand and b the x operand of atan2.
+    vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
+  } else {
+    eval(inputs, out);
+  }
+}
+
 void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp
index 060d7565c..eba96fc5d 100644
--- a/mlx/backend/common/binary.cpp
+++ b/mlx/backend/common/binary.cpp
@@ -293,4 +293,27 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
+// Dispatch on the output dtype: float32, float16 and bfloat16 are computed,
+// any other dtype throws std::invalid_argument.
+void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  const auto& a = inputs[0];
+  const auto& b = inputs[1];
+  if (out.dtype() == float32) {
+    binary_op<float>(a, b, out, detail::ArcTan2());
+  } else if (out.dtype() == float16) {
+    binary_op<float16_t>(a, b, out, detail::ArcTan2());
+  } else if (out.dtype() == bfloat16) {
+    binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
+  } else if (issubdtype(out.dtype(), inexact)) {
+    std::ostringstream err;
+    err << "[arctan2] Does not support " << out.dtype();
+    throw std::invalid_argument(err.str());
+  } else {
+    throw std::invalid_argument(
+        "[arctan2] Cannot compute inverse tangent for arrays"
+        " with non floating point type.");
+  }
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index d8ec303f1..a0695feeb 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -34,6 +34,7 @@ DEFAULT(ArcCosh)
 DEFAULT(ArcSin)
 DEFAULT(ArcSinh)
 DEFAULT(ArcTan)
+DEFAULT(ArcTan2)
 DEFAULT(ArcTanh)
 DEFAULT(ArgPartition)
 DEFAULT(ArgReduce)
diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h
index ae2e2d225..8733dbe74 100644
--- a/mlx/backend/common/ops.h
+++ b/mlx/backend/common/ops.h
@@ -161,6 +161,14 @@ struct ArcTan {
   };
 };
 
+// Element-wise atan2(y, x), forwarded to std::atan2.
+struct ArcTan2 {
+  template <typename T>
+  T operator()(T y, T x) {
+    return std::atan2(y, x);
+  };
+};
+
 struct ArcTanh {
   template <typename T>
   T operator()(T x) {
diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h
index e9606d4f5..9eea3c7b8 100644
--- a/mlx/backend/metal/kernels/binary.h
+++ b/mlx/backend/metal/kernels/binary.h
@@ -264,3 +264,11 @@ struct RightShift {
     return x >> y;
   };
 };
+
+// Element-wise atan2(y, x), forwarded to metal::precise::atan2.
+struct ArcTan2 {
+  template <typename T>
+  T operator()(T y, T x) {
+    return metal::precise::atan2(y, x);
+  }
+};
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index 8dba35958..67967f928 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -241,6 +241,7 @@ instantiate_binary_types(mul, Multiply)
 instantiate_binary_types(sub, Subtract)
 instantiate_binary_types(pow, Power)
 instantiate_binary_types(rem, Remainder)
+instantiate_binary_float(arctan2, ArcTan2)
 
 // NaNEqual only needed for floating point types with boolean output
 instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 364132eba..c4ed2618b 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -451,6 +451,10 @@ void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "arctan");
 }
 
+void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
+  binary_op(inputs, out, "arctan2");
+}
+
 void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "arctanh");
 }
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 2b10e416a..0b6a8c4ab 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -25,6 +25,7 @@ NO_GPU(ArcCosh)
 NO_GPU(ArcSin)
 NO_GPU(ArcSinh)
 NO_GPU(ArcTan)
+NO_GPU(ArcTan2)
 NO_GPU(ArcTanh)
 NO_GPU(ArgPartition)
 NO_GPU(ArgReduce)
diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index 55aafffaf..f002b20eb 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -45,7 +45,8 @@ bool is_binary(const Primitive& p) {
       typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
       typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
       typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
-      typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary));
+      typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) ||
+      typeid(p) == typeid(ArcTan2));
 }
 
 bool is_ternary(const Primitive& p) {
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 92c6137ef..1903a34ad 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2152,6 +2152,16 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
       a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
 }
 
+// Cast both inputs to their promoted floating point type and broadcast them
+// against each other before building the ArcTan2 node.
+array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
+  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto& shape = inputs[0].shape();
+  return array(
+      shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
+}
+
 array sinh(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
diff --git a/mlx/ops.h b/mlx/ops.h
index c75dc1846..454b22c0d 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -828,6 +828,9 @@ array arccos(const array& a, StreamOrDevice s = {});
 /** Arc Tangent of the elements of an array */
 array arctan(const array& a, StreamOrDevice s = {});
 
+/** Inverse tangent of the ratio of two arrays */
+array arctan2(const array& a, const array& b, StreamOrDevice s = {});
+
 /** Hyperbolic Sine of the elements of an array */
 array sinh(const array& a, StreamOrDevice s = {});
 
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 1f50d1e9c..6f8a0511a 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -402,6 +402,55 @@ std::pair<std::vector<array>, std::vector<int>> ArcTan::vmap(
   return {{arctan(inputs[0], stream())}, axes};
 }
 
+std::vector<array> ArcTan2::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  // For out = arctan2(a, b) with t = a^2 + b^2:
+  //   d(out)/da = b / t and d(out)/db = -a / t.
+  assert(primals.size() == 2);
+  const auto& s = stream();
+  auto t = add(square(primals[0], s), square(primals[1], s), s);
+  std::vector<array> vjps;
+  for (auto arg : argnums) {
+    auto partial = (arg == 0) ? divide(primals[1], t, s)
+                              : divide(negative(primals[0], s), t, s);
+    vjps.push_back(multiply(cotangents[0], partial, s));
+  }
+  return vjps;
+}
+
+std::vector<array> ArcTan2::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  // One output, so the per-argument contributions are summed.
+  assert(primals.size() == 2);
+  const auto& s = stream();
+  auto t = add(square(primals[0], s), square(primals[1], s), s);
+  auto jvp_fun = [&](int i) {
+    auto partial = (argnums[i] == 0)
+        ? divide(primals[1], t, s)
+        : divide(negative(primals[0], s), t, s);
+    return multiply(tangents[i], partial, s);
+  };
+  auto out = jvp_fun(0);
+  if (argnums.size() > 1) {
+    out = add(out, jvp_fun(1), s);
+  }
+  return {out};
+}
+
+std::pair<std::vector<array>, std::vector<int>> ArcTan2::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 2);
+  assert(axes.size() == 2);
+  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
+  return {{arctan2(a, b, stream())}, {to_ax}};
+}
+
 std::vector<array> ArcTanh::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 7d0aca52b..316ae13d5 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -314,6 +314,23 @@ class ArcTan : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class ArcTan2 : public UnaryPrimitive {
+ public:
+  explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(ArcTan2)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+  DEFINE_INPUT_OUTPUT_SHAPE()
+
+ private:
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class ArcTanh : public UnaryPrimitive {
  public:
   explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {};
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index a33ed822d..4857f813d 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -930,6 +930,25 @@ void init_ops(nb::module_& m) {
         Returns:
             array: The inverse tangent of ``a``.
       )pbdoc");
+  m.def(
+      "arctan2",
+      &mlx::core::arctan2,
+      nb::arg(),
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def arctan2(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise inverse tangent of the ratio of two arrays.
+
+        Args:
+            a (array): Input array.
+            b (array): Input array.
+
+        Returns:
+            array: The inverse tangent of the ratio of ``a`` and ``b``.
+      )pbdoc");
   m.def(
       "sinh",
       &mlx::core::sinh,
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index af22bd14e..76a77ccd2 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1273,6 +1273,7 @@ class TestOps(mlx_tests.MLXTestCase):
             "arcsin": lambda primal, cotan: cotan / np.sqrt(1.0 - primal**2),
             "arccos": lambda primal, cotan: -cotan / np.sqrt(1.0 - primal**2),
             "arctan": lambda primal, cotan: cotan / (1.0 + primal**2),
+            "arctan2": lambda primal, cotan: cotan / (1.0 + primal**2),
             "arcsinh": lambda primal, cotan: cotan / np.sqrt(primal**2 + 1),
             "arccosh": lambda primal, cotan: cotan / np.sqrt(primal**2 - 1),
             "arctanh": lambda primal, cotan: cotan / (1.0 - primal**2),