mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Adds round op and primitive (#203)
This commit is contained in:

committed by
GitHub

parent
477397bc98
commit
4d4af12c6f
@@ -133,6 +133,11 @@ struct Negative {
|
||||
template <typename T> T operator()(T x) { return -x; };
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T> T operator()(T x) { return metal::round(x); };
|
||||
template <> complex64_t operator()(complex64_t x) { return {metal::round(x.real), metal::round(x.imag)}; };
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
@@ -300,6 +305,7 @@ instantiate_unary_float(sqrt, Sqrt)
|
||||
instantiate_unary_float(rsqrt, Rsqrt)
|
||||
instantiate_unary_float(tan, Tan)
|
||||
instantiate_unary_float(tanh, Tanh)
|
||||
instantiate_unary_float(round, Round)
|
||||
|
||||
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
||||
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
||||
@@ -310,5 +316,6 @@ instantiate_unary_all(sin, complex64, complex64_t, Sin)
|
||||
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
|
||||
instantiate_unary_all(tan, complex64, complex64_t, Tan)
|
||||
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
|
||||
instantiate_unary_all(round, complex64, complex64_t, Round)
|
||||
|
||||
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
|
||||
|
@@ -563,6 +563,17 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_op(inputs, out, "round");
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "sigmoid");
|
||||
}
|
||||
|
Reference in New Issue
Block a user