12template <
typename T,
typename U,
typename Op,
int D>
24 auto stride_a = a_strides[axis];
25 auto stride_b = b_strides[axis];
26 auto stride_out = out_strides[axis];
29 for (
int i = 0; i <
N; i++) {
30 if constexpr (D > 1) {
43 std::tie(*out_a, *out_b) = op(*a, *b);
52template <
typename T,
typename U,
typename Op>
60 a.shape(), {a.strides(), b.strides(), out_a.strides()});
61 const auto& a_strides = strides[0];
62 const auto& b_strides = strides[1];
63 const auto& out_strides = strides[2];
64 const T* a_ptr = a.data<T>();
65 const T* b_ptr = b.data<T>();
66 U* out_a_ptr = out_a.data<U>();
67 U* out_b_ptr = out_b.data<U>();
69 int ndim = shape.size();
101 auto stride = out_strides[ndim - 3];
102 for (
size_t elem = 0; elem < a.size(); elem += stride) {
119template <
typename T,
typename U = T,
typename Op>
123 std::vector<array>& outputs,
126 auto& out_a = outputs[0];
127 auto& out_b = outputs[1];
137 auto a_ptr = a.data<T>();
138 auto b_ptr = b.data<T>();
139 auto out_a_ptr = out_a.data<U>();
140 auto out_b_ptr = out_b.data<U>();
142 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
144 for (
size_t i = 0; i < b.size(); ++i) {
145 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
151 for (
size_t i = 0; i < a.size(); ++i) {
152 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
158 for (
size_t i = 0; i < a.size(); ++i) {
159 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
168template <
typename Op>
172 std::vector<array>& outputs,
174 switch (outputs[0].dtype()) {
constexpr int N
Definition neon_fp16_simd.h:9
constexpr Dtype bool_
Definition dtype.h:68
constexpr Dtype uint64
Definition dtype.h:73
BinaryOpType get_binary_op_type(const array &a, const array &b)
Definition binary.h:19
constexpr Dtype uint16
Definition dtype.h:71
constexpr Dtype float64
Definition dtype.h:82
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
constexpr Dtype bfloat16
Definition dtype.h:83
@ General
Definition binary.h:16
@ ScalarScalar
Definition binary.h:12
@ VectorScalar
Definition binary.h:14
@ ScalarVector
Definition binary.h:13
constexpr Dtype int32
Definition dtype.h:77
void binary_op_dispatch_dims(const array &a, const array &b, array &out, Op op, int dim, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides)
Definition binary.h:122
constexpr Dtype float32
Definition dtype.h:81
std::vector< ShapeElem > Shape
Definition array.h:21
void set_binary_op_output_data(const array &a, const array &b, array &out, BinaryOpType bopt, bool donate_with_move=false)
Definition binary.h:37
constexpr Dtype int16
Definition dtype.h:76
std::vector< int64_t > Strides
Definition array.h:22
void binary_op_dims(const T *a, const T *b, U *out, Op op, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides, int axis)
Definition binary.h:89
constexpr Dtype int8
Definition dtype.h:75
constexpr Dtype int64
Definition dtype.h:78
constexpr Dtype uint8
Definition dtype.h:70
void binary_op(const array &a, const array &b, array &out, Op op)
Definition binary.h:194
constexpr Dtype float16
Definition dtype.h:80
constexpr Dtype uint32
Definition dtype.h:72
void binary(const array &a, const array &b, array &out, Op op)
Definition binary.h:326
constexpr Dtype complex64
Definition dtype.h:84