14void set_unary_output_data(
const array& in, array& out) {
15 if (in.is_donatable() && in.itemsize() == out.itemsize()) {
16 out.copy_shared_buffer(in);
18 auto size = in.data_size();
27template <
typename T,
typename Op>
28void unary_op(
const array& a, array& out, Op
op) {
29 const T* a_ptr = a.data<T>();
30 if (a.flags().contiguous) {
31 set_unary_output_data(a, out);
32 T* dst = out.data<T>();
33 for (
size_t i = 0; i < a.data_size(); ++i) {
34 dst[i] =
op(a_ptr[i]);
38 T* dst = out.data<T>();
39 for (
size_t i = 0; i < out.size(); ++i) {
42 dst[i] =
op(a_ptr[a_idx]);
48void unary(
const array& a, array& out, Op
op) {
49 switch (out.dtype()) {
51 unary_op<bool>(a, out,
op);
54 unary_op<uint8_t>(a, out,
op);
57 unary_op<uint16_t>(a, out,
op);
60 unary_op<uint32_t>(a, out,
op);
63 unary_op<uint64_t>(a, out,
op);
66 unary_op<int8_t>(a, out,
op);
69 unary_op<int16_t>(a, out,
op);
72 unary_op<int32_t>(a, out,
op);
75 unary_op<int64_t>(a, out,
op);
78 unary_op<float16_t>(a, out,
op);
81 unary_op<float>(a, out,
op);
84 unary_op<bfloat16_t>(a, out,
op);
87 unary_op<complex64_t>(a, out,
op);
93void unary_fp(
const array& a, array& out, Op
op) {
94 switch (out.dtype()) {
96 unary_op<bfloat16_t>(a, out,
op);
99 unary_op<float16_t>(a, out,
op);
102 unary_op<float>(a, out,
op);
105 unary_op<complex64_t>(a, out,
op);
108 std::ostringstream err;
109 err <<
"[unary_fp] Does not support " << out.dtype();
110 throw std::runtime_error(err.str());
Op op
Definition binary.h:141
Buffer malloc_or_wait(size_t size)
constexpr Dtype bool_
Definition dtype.h:58
constexpr Dtype uint64
Definition dtype.h:63
constexpr Dtype uint16
Definition dtype.h:61
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
constexpr Dtype bfloat16
Definition dtype.h:72
constexpr Dtype int32
Definition dtype.h:67
constexpr Dtype float32
Definition dtype.h:71
constexpr Dtype int16
Definition dtype.h:66
constexpr Dtype int8
Definition dtype.h:65
constexpr Dtype int64
Definition dtype.h:68
constexpr Dtype uint8
Definition dtype.h:60
constexpr Dtype float16
Definition dtype.h:70
constexpr Dtype uint32
Definition dtype.h:62
constexpr Dtype complex64
Definition dtype.h:73