Adds round op and primitive (#203)

This commit is contained in:
Angelos Katharopoulos
2023-12-18 11:32:48 -08:00
committed by GitHub
parent 477397bc98
commit 4d4af12c6f
17 changed files with 187 additions and 2 deletions

View File

@@ -47,6 +47,7 @@ DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Sigmoid)
DEFAULT(Sign)

View File

@@ -65,6 +65,7 @@ DEFAULT(Power)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Sigmoid)

View File

@@ -466,6 +466,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
}
}
void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_fp(in, out, RoundOp());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@@ -53,6 +53,17 @@ struct SignOp {
}
};
struct RoundOp {
template <typename T>
T operator()(T x) {
return std::round(x);
}
complex64_t operator()(complex64_t x) {
return {std::round(x.real()), std::round(x.imag())};
}
};
template <typename T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();

View File

@@ -133,6 +133,11 @@ struct Negative {
template <typename T> T operator()(T x) { return -x; };
};
struct Round {
template <typename T> T operator()(T x) { return metal::round(x); };
template <> complex64_t operator()(complex64_t x) { return {metal::round(x.real), metal::round(x.imag)}; };
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
@@ -300,6 +305,7 @@ instantiate_unary_float(sqrt, Sqrt)
instantiate_unary_float(rsqrt, Rsqrt)
instantiate_unary_float(tan, Tan)
instantiate_unary_float(tanh, Tanh)
instantiate_unary_float(round, Round)
instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
@@ -310,5 +316,6 @@ instantiate_unary_all(sin, complex64, complex64_t, Sin)
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
instantiate_unary_all(tan, complex64, complex64_t, Tan)
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
instantiate_unary_all(round, complex64, complex64_t, Round)
instantiate_unary_all(lnot, bool_, bool, LogicalNot)

View File

@@ -563,6 +563,17 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_op(inputs, out, "round");
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sigmoid");
}

View File

@@ -61,6 +61,7 @@ NO_GPU(Power)
NO_GPU(RandomBits)
NO_GPU(Reduce)
NO_GPU(Reshape)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(Sigmoid)