12template <
typename T,
typename U,
typename Op>
19 const T* a_ptr = a.data<T>();
20 const T* b_ptr = b.data<T>();
21 U* dst_a = out_a.data<U>();
22 U* dst_b = out_b.data<U>();
25 for (
size_t i = 0; i < out_a.size(); ++i) {
26 auto dst =
op(a_ptr[a_idx], b_ptr[b_idx]);
28 dst_b[i] = dst.second;
29 a_idx += a.strides()[0];
30 b_idx += b.strides()[0];
34template <
typename T,
typename U,
typename Op>
42 const T* a_ptr = a.data<T>();
43 const T* b_ptr = b.data<T>();
44 U* dst_a = out_a.data<U>();
45 U* dst_b = out_b.data<U>();
48 for (
size_t i = 0; i < a.shape()[0]; i++) {
49 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
50 a_idx += a.strides()[0];
51 b_idx += b.strides()[0];
57template <
typename T,
typename U,
typename Op>
64 const T* a_ptr = a.data<T>();
65 const T* b_ptr = b.data<T>();
66 U* dst_a = out_a.data<U>();
67 U* dst_b = out_b.data<U>();
71 for (
size_t i = 0; i < a.shape()[0]; ++i) {
72 for (
size_t j = 0; j < a.shape()[1]; ++j) {
73 auto dst =
op(a_ptr[a_idx], b_ptr[b_idx]);
74 dst_a[out_idx] = dst.first;
75 dst_b[out_idx++] = dst.second;
76 a_idx += a.strides()[1];
77 b_idx += b.strides()[1];
79 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
80 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
84template <
typename T,
typename U,
typename Op>
92 const T* a_ptr = a.data<T>();
93 const T* b_ptr = b.data<T>();
94 U* dst_a = out_a.data<U>();
95 U* dst_b = out_b.data<U>();
98 for (
size_t i = 0; i < a.shape()[0]; ++i) {
99 for (
size_t j = 0; j < a.shape()[1]; ++j) {
100 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
101 a_idx += a.strides()[1];
102 b_idx += b.strides()[1];
106 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
107 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
111template <
typename T,
typename U,
typename Op>
118 const T* a_ptr = a.data<T>();
119 const T* b_ptr = b.data<T>();
120 U* dst_a = out_a.data<U>();
121 U* dst_b = out_b.data<U>();
125 for (
size_t i = 0; i < a.shape()[0]; ++i) {
126 for (
size_t j = 0; j < a.shape()[1]; ++j) {
127 for (
size_t k = 0; k < a.shape()[2]; ++k) {
128 auto dst =
op(a_ptr[a_idx], b_ptr[b_idx]);
129 dst_a[out_idx] = dst.first;
130 dst_b[out_idx++] = dst.second;
131 a_idx += a.strides()[2];
132 b_idx += b.strides()[2];
134 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
135 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
137 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
138 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
142template <
typename T,
typename U,
typename Op>
149 const T* a_ptr = a.data<T>();
150 const T* b_ptr = b.data<T>();
151 U* dst_a = out_a.data<U>();
152 U* dst_b = out_b.data<U>();
156 for (
size_t i = 0; i < a.shape()[0]; ++i) {
157 for (
size_t j = 0; j < a.shape()[1]; ++j) {
158 for (
size_t k = 0; k < a.shape()[2]; ++k) {
159 for (
size_t ii = 0; ii < a.shape()[3]; ++ii) {
160 auto dst =
op(a_ptr[a_idx], b_ptr[b_idx]);
161 dst_a[out_idx] = dst.first;
162 dst_b[out_idx++] = dst.second;
163 a_idx += a.strides()[3];
164 b_idx += b.strides()[3];
166 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
167 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
169 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
170 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
172 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
173 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
177template <
typename T,
typename U,
typename Op>
178void binary_op_dispatch_dims(
184 switch (out_a.ndim()) {
186 binary_op_dims1<T, U, Op>(a, b, out_a, out_b,
op);
189 binary_op_dims2<T, U, Op>(a, b, out_a, out_b,
op);
192 binary_op_dims3<T, U, Op>(a, b, out_a, out_b,
op);
195 binary_op_dims4<T, U, Op>(a, b, out_a, out_b,
op);
199 const T* a_ptr = a.data<T>();
200 const T* b_ptr = b.data<T>();
201 U* dst_a = out_a.data<U>();
202 U* dst_b = out_b.data<U>();
203 for (
size_t i = 0; i < out_a.size(); i++) {
204 int a_idx =
elem_to_loc(i, a.shape(), a.strides());
205 int b_idx =
elem_to_loc(i, b.shape(), b.strides());
206 std::tie(dst_a[i], dst_b[i]) =
op(a_ptr[a_idx], b_ptr[b_idx]);
210template <
typename T,
typename U,
typename Op>
211void binary_op_dispatch_dims(
222 binary_op_dims1<T, U, Op>(a, b, out_a, out_b,
op, stride);
225 binary_op_dims2<T, U, Op>(a, b, out_a, out_b,
op, stride);
229 const T* a_ptr = a.data<T>();
230 const T* b_ptr = b.data<T>();
231 U* dst_a = out_a.data<U>();
232 U* dst_b = out_b.data<U>();
233 for (
size_t i = 0; i < out_a.size(); i += stride) {
234 int a_idx =
elem_to_loc(i, a.shape(), a.strides());
235 int b_idx =
elem_to_loc(i, b.shape(), b.strides());
236 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
258 auto bopt = get_binary_op_type(a, b);
259 set_binary_op_output_data(a, b, out_a, bopt);
260 set_binary_op_output_data(a, b, out_b, bopt);
263 if (bopt == BinaryOpType::ScalarScalar) {
264 std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
265 op(*a.data<T>(), *b.data<T>());
270 if (bopt == BinaryOpType::ScalarVector) {
281 if (bopt == BinaryOpType::VectorScalar) {
292 if (bopt == BinaryOpType::VectorVector) {
305 auto& strides = out_a.strides();
306 auto leftmost_rc_dim = [&strides](
const array& arr) {
307 int d = arr.ndim() - 1;
308 for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
312 auto a_rc_dim = leftmost_rc_dim(a);
313 auto b_rc_dim = leftmost_rc_dim(b);
316 auto leftmost_s_dim = [](
const array& arr) {
317 int d = arr.ndim() - 1;
318 for (; d >= 0 && arr.strides()[d] == 0; d--) {
322 auto a_s_dim = leftmost_s_dim(a);
323 auto b_s_dim = leftmost_s_dim(b);
325 auto ndim = out_a.ndim();
329 if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
330 bopt = BinaryOpType::VectorVector;
334 }
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
335 bopt = BinaryOpType::VectorScalar;
339 }
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
340 bopt = BinaryOpType::ScalarVector;
348 if (dim == 0 || strides[dim - 1] < 16) {
350 bopt = BinaryOpType::General;
353 stride = strides[dim - 1];
357 case BinaryOpType::VectorVector:
358 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
360 case BinaryOpType::VectorScalar:
361 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
363 case BinaryOpType::ScalarVector:
364 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
367 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b,
op);
372template <
typename T,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
376 std::vector<array>& outputs,
384 if (std::is_same<
decltype(opsv), UseDefaultBinaryOp>::value) {
385 if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
386 if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
394 DefaultScalarVector<T, T, Op>(
op),
395 DefaultVectorScalar<T, T, Op>(
op),
396 DefaultVectorVector<T, T, Op>(
op));
405 DefaultScalarVector<T, T, Op>(
op),
406 DefaultVectorScalar<T, T, Op>(
op),
409 }
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
417 DefaultScalarVector<T, T, Op>(
op),
419 DefaultVectorVector<T, T, Op>(
op));
428 DefaultScalarVector<T, T, Op>(
op),
432 }
else if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
433 if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
442 DefaultVectorScalar<T, T, Op>(
op),
443 DefaultVectorVector<T, T, Op>(
op));
453 DefaultVectorScalar<T, T, Op>(
op),
456 }
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
466 DefaultVectorVector<T, T, Op>(
op));
469 binary_op<T, T>(a, b, outputs[0], outputs[1],
op, opsv, opvs, opvv);
473template <
typename T,
typename Op>
477 std::vector<array>& outputs,
479 DefaultScalarVector<T, T, Op> opsv(
op);
480 DefaultVectorScalar<T, T, Op> opvs(
op);
481 DefaultVectorVector<T, T, Op> opvv(
op);
482 binary_op<T, T>(a, b, outputs[0], outputs[1],
op, opsv, opvs, opvv);
485template <
typename... Ops>
489 std::vector<array>& outputs,
491 switch (outputs[0].dtype()) {
493 binary_op<bool>(a, b, outputs, ops...);
496 binary_op<uint8_t>(a, b, outputs, ops...);
499 binary_op<uint16_t>(a, b, outputs, ops...);
502 binary_op<uint32_t>(a, b, outputs, ops...);
505 binary_op<uint64_t>(a, b, outputs, ops...);
508 binary_op<int8_t>(a, b, outputs, ops...);
511 binary_op<int16_t>(a, b, outputs, ops...);
514 binary_op<int32_t>(a, b, outputs, ops...);
517 binary_op<int64_t>(a, b, outputs, ops...);
520 binary_op<float16_t>(a, b, outputs, ops...);
523 binary_op<float>(a, b, outputs, ops...);
526 binary_op<bfloat16_t>(a, b, outputs, ops...);
529 binary_op<complex64_t>(a, b, outputs, ops...);
Op op
Definition binary.h:141
constexpr Dtype bool_
Definition dtype.h:60
constexpr Dtype uint64
Definition dtype.h:65
constexpr Dtype uint16
Definition dtype.h:63
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
constexpr Dtype bfloat16
Definition dtype.h:74
constexpr Dtype int32
Definition dtype.h:69
constexpr Dtype float32
Definition dtype.h:73
constexpr Dtype int16
Definition dtype.h:68
constexpr Dtype int8
Definition dtype.h:67
constexpr Dtype int64
Definition dtype.h:70
constexpr Dtype uint8
Definition dtype.h:62
constexpr Dtype float16
Definition dtype.h:72
constexpr Dtype uint32
Definition dtype.h:64
constexpr Dtype complex64
Definition dtype.h:75