12template <
typename T,
typename U,
typename Op,
int D>
19 const std::vector<int>& shape,
20 const std::vector<size_t>& a_strides,
21 const std::vector<size_t>& b_strides,
22 const std::vector<size_t>& out_strides,
24 auto stride_a = a_strides[axis];
25 auto stride_b = b_strides[axis];
26 auto stride_out = out_strides[axis];
29 for (
int i = 0; i < N; i++) {
30 if constexpr (D > 1) {
31 binary_op_dims<T, U, Op, D - 1>(
43 std::tie(*out_a, *out_b) =
op(*a, *b);
52template <
typename T,
typename U,
typename Op>
53void binary_op_dispatch_dims(
60 a.shape(), {a.strides(), b.strides(), out_a.strides()});
61 const auto& a_strides = strides[0];
62 const auto& b_strides = strides[1];
63 const auto& out_strides = strides[2];
64 const T* a_ptr = a.data<T>();
65 const T* b_ptr = b.data<T>();
66 U* out_a_ptr = out_a.data<U>();
67 U* out_b_ptr = out_b.data<U>();
69 int ndim = shape.size();
72 binary_op_dims<T, U, Op, 1>(
85 binary_op_dims<T, U, Op, 2>(
99 ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
100 ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
101 size_t stride = out_strides[ndim - 3];
102 for (
size_t elem = 0; elem < a.size(); elem += stride) {
103 binary_op_dims<T, U, Op, 2>(
119template <
typename T,
typename U = T,
typename Op>
123 std::vector<array>& outputs,
125 auto bopt = get_binary_op_type(a, b);
126 auto& out_a = outputs[0];
127 auto& out_b = outputs[1];
128 set_binary_op_output_data(a, b, out_a, bopt);
129 set_binary_op_output_data(a, b, out_b, bopt);
132 if (bopt == BinaryOpType::General) {
133 binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b,
op);
137 auto a_ptr = a.data<T>();
138 auto b_ptr = b.data<T>();
139 auto out_a_ptr = out_a.data<U>();
140 auto out_b_ptr = out_b.data<U>();
141 if (bopt == BinaryOpType::ScalarScalar) {
142 std::tie(*out_a_ptr, *out_b_ptr) =
op(*a_ptr, *b_ptr);
143 }
else if (bopt == BinaryOpType::ScalarVector) {
144 for (
size_t i = 0; i < b.size(); ++i) {
145 std::tie(*out_a_ptr, *out_b_ptr) =
op(*a_ptr, *b_ptr);
150 }
else if (bopt == BinaryOpType::VectorScalar) {
151 for (
size_t i = 0; i < a.size(); ++i) {
152 std::tie(*out_a_ptr, *out_b_ptr) =
op(*a_ptr, *b_ptr);
158 for (
size_t i = 0; i < a.size(); ++i) {
159 std::tie(*out_a_ptr, *out_b_ptr) =
op(*a_ptr, *b_ptr);
168template <
typename Op>
172 std::vector<array>& outputs,
174 switch (outputs[0].dtype()) {
176 binary_op<bool>(a, b, outputs,
op);
179 binary_op<uint8_t>(a, b, outputs,
op);
182 binary_op<uint16_t>(a, b, outputs,
op);
185 binary_op<uint32_t>(a, b, outputs,
op);
188 binary_op<uint64_t>(a, b, outputs,
op);
191 binary_op<int8_t>(a, b, outputs,
op);
194 binary_op<int16_t>(a, b, outputs,
op);
197 binary_op<int32_t>(a, b, outputs,
op);
200 binary_op<int64_t>(a, b, outputs,
op);
203 binary_op<float16_t>(a, b, outputs,
op);
206 binary_op<float>(a, b, outputs,
op);
209 binary_op<bfloat16_t>(a, b, outputs,
op);
212 binary_op<complex64_t>(a, b, outputs,
op);
Op op
Definition binary.h:129
constexpr Dtype bool_
Definition dtype.h:67
constexpr Dtype uint64
Definition dtype.h:72
constexpr Dtype uint16
Definition dtype.h:70
std::tuple< std::vector< int >, std::vector< std::vector< int64_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
constexpr Dtype bfloat16
Definition dtype.h:81
constexpr Dtype int32
Definition dtype.h:76
constexpr Dtype float32
Definition dtype.h:80
constexpr Dtype int16
Definition dtype.h:75
constexpr Dtype int8
Definition dtype.h:74
constexpr Dtype int64
Definition dtype.h:77
constexpr Dtype uint8
Definition dtype.h:69
constexpr Dtype float16
Definition dtype.h:79
constexpr Dtype uint32
Definition dtype.h:71
constexpr Dtype complex64
Definition dtype.h:82