mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Added ArcTan2 operation (#1079)
* Added ArcTan2 operation * Cleanup, bug fixes from code review * Minor cleanup, fixed Linux tests
This commit is contained in:
parent
fe96ceee66
commit
cc05a281c4
@ -19,6 +19,7 @@ Operations
|
|||||||
arcsin
|
arcsin
|
||||||
arcsinh
|
arcsinh
|
||||||
arctan
|
arctan
|
||||||
|
arctan2
|
||||||
arctanh
|
arctanh
|
||||||
argmax
|
argmax
|
||||||
argmin
|
argmin
|
||||||
|
@ -193,6 +193,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||||
|
b.flags().row_contiguous) {
|
||||||
|
if (a.is_donatable()) {
|
||||||
|
out.copy_shared_buffer(a);
|
||||||
|
} else if (b.is_donatable()) {
|
||||||
|
out.copy_shared_buffer(b);
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
}
|
||||||
|
int size = a.data_size();
|
||||||
|
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||||
|
} else {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
|
@ -293,4 +293,25 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
const auto& a = inputs[0];
|
||||||
|
const auto& b = inputs[1];
|
||||||
|
if (out.dtype() == float32) {
|
||||||
|
binary_op<float>(a, b, out, detail::ArcTan2());
|
||||||
|
} else if (out.dtype() == float16) {
|
||||||
|
binary_op<float16_t>(a, b, out, detail::ArcTan2());
|
||||||
|
} else if (out.dtype() == bfloat16) {
|
||||||
|
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
|
||||||
|
} else if (issubdtype(out.dtype(), inexact)) {
|
||||||
|
std::ostringstream err;
|
||||||
|
err << "[arctan2] Does not support " << out.dtype();
|
||||||
|
throw std::invalid_argument(err.str());
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[arctan2] Cannot compute inverse tangent for arrays"
|
||||||
|
" with non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -34,6 +34,7 @@ DEFAULT(ArcCosh)
|
|||||||
DEFAULT(ArcSin)
|
DEFAULT(ArcSin)
|
||||||
DEFAULT(ArcSinh)
|
DEFAULT(ArcSinh)
|
||||||
DEFAULT(ArcTan)
|
DEFAULT(ArcTan)
|
||||||
|
DEFAULT(ArcTan2)
|
||||||
DEFAULT(ArcTanh)
|
DEFAULT(ArcTanh)
|
||||||
DEFAULT(ArgPartition)
|
DEFAULT(ArgPartition)
|
||||||
DEFAULT(ArgReduce)
|
DEFAULT(ArgReduce)
|
||||||
|
@ -161,6 +161,13 @@ struct ArcTan {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ArcTan2 {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T y, T x) {
|
||||||
|
return std::atan2(y, x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
struct ArcTanh {
|
struct ArcTanh {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
|
@ -264,3 +264,10 @@ struct RightShift {
|
|||||||
return x >> y;
|
return x >> y;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ArcTan2 {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T y, T x) {
|
||||||
|
return metal::precise::atan2(y, x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
@ -241,6 +241,7 @@ instantiate_binary_types(mul, Multiply)
|
|||||||
instantiate_binary_types(sub, Subtract)
|
instantiate_binary_types(sub, Subtract)
|
||||||
instantiate_binary_types(pow, Power)
|
instantiate_binary_types(pow, Power)
|
||||||
instantiate_binary_types(rem, Remainder)
|
instantiate_binary_types(rem, Remainder)
|
||||||
|
instantiate_binary_float(arctan2, ArcTan2)
|
||||||
|
|
||||||
// NaNEqual only needed for floating point types with boolean output
|
// NaNEqual only needed for floating point types with boolean output
|
||||||
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||||
|
@ -451,6 +451,10 @@ void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
unary_op(inputs, out, "arctan");
|
unary_op(inputs, out, "arctan");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(inputs, out, "arctan2");
|
||||||
|
}
|
||||||
|
|
||||||
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
unary_op(inputs, out, "arctanh");
|
unary_op(inputs, out, "arctanh");
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,7 @@ NO_GPU(ArcCosh)
|
|||||||
NO_GPU(ArcSin)
|
NO_GPU(ArcSin)
|
||||||
NO_GPU(ArcSinh)
|
NO_GPU(ArcSinh)
|
||||||
NO_GPU(ArcTan)
|
NO_GPU(ArcTan)
|
||||||
|
NO_GPU(ArcTan2)
|
||||||
NO_GPU(ArcTanh)
|
NO_GPU(ArcTanh)
|
||||||
NO_GPU(ArgPartition)
|
NO_GPU(ArgPartition)
|
||||||
NO_GPU(ArgReduce)
|
NO_GPU(ArgReduce)
|
||||||
|
@ -45,7 +45,8 @@ bool is_binary(const Primitive& p) {
|
|||||||
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
|
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
|
||||||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
|
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
|
||||||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
|
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
|
||||||
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary));
|
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) ||
|
||||||
|
typeid(p) == typeid(ArcTan2));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_ternary(const Primitive& p) {
|
bool is_ternary(const Primitive& p) {
|
||||||
|
@ -2152,6 +2152,14 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
|
a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
|
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
|
||||||
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
|
return array(
|
||||||
|
shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
|
||||||
|
}
|
||||||
|
|
||||||
array sinh(const array& a, StreamOrDevice s /* = {} */) {
|
array sinh(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
|
@ -828,6 +828,9 @@ array arccos(const array& a, StreamOrDevice s = {});
|
|||||||
/** Arc Tangent of the elements of an array */
|
/** Arc Tangent of the elements of an array */
|
||||||
array arctan(const array& a, StreamOrDevice s = {});
|
array arctan(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Inverse tangent of the ratio of two arrays */
|
||||||
|
array arctan2(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Hyperbolic Sine of the elements of an array */
|
/** Hyperbolic Sine of the elements of an array */
|
||||||
array sinh(const array& a, StreamOrDevice s = {});
|
array sinh(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -402,6 +402,36 @@ std::pair<std::vector<array>, std::vector<int>> ArcTan::vmap(
|
|||||||
return {{arctan(inputs[0], stream())}, axes};
|
return {{arctan(inputs[0], stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> ArcTan2::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotangents,
|
||||||
|
const std::vector<int>& argnums,
|
||||||
|
const std::vector<array>&) {
|
||||||
|
return jvp(primals, cotangents, argnums);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> ArcTan2::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
assert(primals.size() == 2);
|
||||||
|
assert(argnums.size() == 2);
|
||||||
|
array t =
|
||||||
|
add(square(primals[0], stream()), square(primals[1], stream()), stream());
|
||||||
|
return {
|
||||||
|
divide(tangents[0], t, stream()),
|
||||||
|
divide(negative(tangents[1], stream()), t, stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> ArcTan2::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
assert(axes.size() == 2);
|
||||||
|
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||||
|
return {{arctan2(a, b, stream())}, {to_ax}};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> ArcTanh::vjp(
|
std::vector<array> ArcTanh::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
|
@ -314,6 +314,23 @@ class ArcTan : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ArcTan2 : public UnaryPrimitive {
|
||||||
|
public:
|
||||||
|
explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(ArcTan2)
|
||||||
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
|
private:
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
class ArcTanh : public UnaryPrimitive {
|
class ArcTanh : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {};
|
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
@ -930,6 +930,25 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The inverse tangent of ``a``.
|
array: The inverse tangent of ``a``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"arctan2",
|
||||||
|
&mlx::core::arctan2,
|
||||||
|
nb::arg(),
|
||||||
|
nb::arg(),
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def arctan2(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Element-wise inverse tangent of the ratio of two arrays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
b (array): Input array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The inverse tangent of the ratio of ``a`` and ``b``.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"sinh",
|
"sinh",
|
||||||
&mlx::core::sinh,
|
&mlx::core::sinh,
|
||||||
|
@ -1273,6 +1273,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
"arcsin": lambda primal, cotan: cotan / np.sqrt(1.0 - primal**2),
|
"arcsin": lambda primal, cotan: cotan / np.sqrt(1.0 - primal**2),
|
||||||
"arccos": lambda primal, cotan: -cotan / np.sqrt(1.0 - primal**2),
|
"arccos": lambda primal, cotan: -cotan / np.sqrt(1.0 - primal**2),
|
||||||
"arctan": lambda primal, cotan: cotan / (1.0 + primal**2),
|
"arctan": lambda primal, cotan: cotan / (1.0 + primal**2),
|
||||||
|
"arctan2": lambda primal, cotan: cotan / (1.0 + primal**2),
|
||||||
"arcsinh": lambda primal, cotan: cotan / np.sqrt(primal**2 + 1),
|
"arcsinh": lambda primal, cotan: cotan / np.sqrt(primal**2 + 1),
|
||||||
"arccosh": lambda primal, cotan: cotan / np.sqrt(primal**2 - 1),
|
"arccosh": lambda primal, cotan: cotan / np.sqrt(primal**2 - 1),
|
||||||
"arctanh": lambda primal, cotan: cotan / (1.0 - primal**2),
|
"arctanh": lambda primal, cotan: cotan / (1.0 - primal**2),
|
||||||
|
Loading…
Reference in New Issue
Block a user