14void set_unary_output_data(
const array& in, array& out) {
16 out.copy_shared_buffer(in);
18 auto size = in.data_size();
27template <
typename T,
typename U = T,
typename Op>
28void unary_op(
const T* a, U* out, Op
op,
size_t shape,
size_t stride) {
29 for (
size_t i = 0; i < shape; i += 1) {
35template <
typename T,
typename U = T,
typename Op>
36void unary_op(
const array& a, array& out, Op
op) {
37 const T* a_ptr = a.data<T>();
38 if (a.flags().contiguous) {
39 set_unary_output_data(a, out);
40 U* dst = out.data<U>();
41 for (
size_t i = 0; i < a.data_size(); ++i) {
42 dst[i] =
op(a_ptr[i]);
46 U* dst = out.data<U>();
47 size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
48 size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
50 unary_op(a_ptr, dst,
op, shape, stride);
53 ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
54 for (
size_t elem = 0; elem < a.size(); elem += shape) {
55 unary_op(a_ptr + it.loc, dst + elem,
op, shape, stride);
62void unary(
const array& a, array& out, Op
op) {
63 switch (out.dtype()) {
65 unary_op<bool>(a, out,
op);
68 unary_op<uint8_t>(a, out,
op);
71 unary_op<uint16_t>(a, out,
op);
74 unary_op<uint32_t>(a, out,
op);
77 unary_op<uint64_t>(a, out,
op);
80 unary_op<int8_t>(a, out,
op);
83 unary_op<int16_t>(a, out,
op);
86 unary_op<int32_t>(a, out,
op);
89 unary_op<int64_t>(a, out,
op);
92 unary_op<float16_t>(a, out,
op);
95 unary_op<float>(a, out,
op);
98 unary_op<bfloat16_t>(a, out,
op);
101 unary_op<complex64_t>(a, out,
op);
106template <
typename Op>
107void unary_fp(
const array& a, array& out, Op
op) {
108 switch (out.dtype()) {
110 unary_op<bfloat16_t>(a, out,
op);
113 unary_op<float16_t>(a, out,
op);
116 unary_op<float>(a, out,
op);
119 unary_op<complex64_t>(a, out,
op);
122 std::ostringstream err;
123 err <<
"[unary_fp] Does not support " << out.dtype();
124 throw std::runtime_error(err.str());
Op op
Definition binary.h:129
Buffer malloc_or_wait(size_t size)
constexpr Dtype bool_
Definition dtype.h:67
constexpr Dtype uint64
Definition dtype.h:72
constexpr Dtype uint16
Definition dtype.h:70
constexpr Dtype bfloat16
Definition dtype.h:81
constexpr Dtype int32
Definition dtype.h:76
constexpr Dtype float32
Definition dtype.h:80
constexpr Dtype int16
Definition dtype.h:75
constexpr Dtype int8
Definition dtype.h:74
constexpr Dtype int64
Definition dtype.h:77
constexpr Dtype uint8
Definition dtype.h:69
constexpr Dtype float16
Definition dtype.h:79
constexpr Dtype uint32
Definition dtype.h:71
bool is_donatable(const array &in, const array &out)
Definition utils.h:174
constexpr Dtype complex64
Definition dtype.h:82