14enum class BinaryOpType {
22BinaryOpType get_binary_op_type(
const array& a,
const array& b) {
24 if (a.data_size() == 1 && b.data_size() == 1) {
25 bopt = BinaryOpType::ScalarScalar;
26 }
else if (a.data_size() == 1 && b.flags().contiguous) {
27 bopt = BinaryOpType::ScalarVector;
28 }
else if (b.data_size() == 1 && a.flags().contiguous) {
29 bopt = BinaryOpType::VectorScalar;
31 a.flags().row_contiguous && b.flags().row_contiguous ||
32 a.flags().col_contiguous && b.flags().col_contiguous) {
33 bopt = BinaryOpType::VectorVector;
35 bopt = BinaryOpType::General;
40void set_binary_op_output_data(
45 bool donate_with_move =
false) {
49 case BinaryOpType::ScalarScalar:
53 case BinaryOpType::ScalarVector:
55 if (donate_with_move) {
56 out.move_shared_buffer(b);
58 out.copy_shared_buffer(b);
68 case BinaryOpType::VectorScalar:
70 if (donate_with_move) {
71 out.move_shared_buffer(a);
73 out.copy_shared_buffer(a);
83 case BinaryOpType::VectorVector:
85 if (donate_with_move) {
86 out.move_shared_buffer(a);
88 out.copy_shared_buffer(a);
90 }
else if (b_donatable) {
91 if (donate_with_move) {
92 out.move_shared_buffer(b);
94 out.copy_shared_buffer(b);
104 case BinaryOpType::General:
105 if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
106 if (donate_with_move) {
107 out.move_shared_buffer(a);
109 out.copy_shared_buffer(a);
112 b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
113 if (donate_with_move) {
114 out.move_shared_buffer(b);
116 out.copy_shared_buffer(b);
125struct UseDefaultBinaryOp {
126 template <
typename T,
typename U>
127 void operator()(
const T* a,
const T* b, U* dst,
int size) {
132 template <
typename T,
typename U>
133 void operator()(
const T* a,
const T* b, U* dst_a, U* dst_b,
int size) {
139template <
typename T,
typename U,
typename Op>
140struct DefaultVectorScalar {
143 DefaultVectorScalar(Op op_) :
op(op_) {}
145 void operator()(
const T* a,
const T* b, U* dst,
int size) {
148 *dst =
op(*a, scalar);
154 void operator()(
const T* a,
const T* b, U* dst_a, U* dst_b,
int size) {
157 auto dst =
op(*a, scalar);
167template <
typename T,
typename U,
typename Op>
168struct DefaultScalarVector {
171 DefaultScalarVector(Op op_) :
op(op_) {}
173 void operator()(
const T* a,
const T* b, U* dst,
int size) {
176 *dst =
op(scalar, *b);
182 void operator()(
const T* a,
const T* b, U* dst_a, U* dst_b,
int size) {
185 auto dst =
op(scalar, *b);
195template <
typename T,
typename U,
typename Op>
196struct DefaultVectorVector {
199 DefaultVectorVector(Op op_) :
op(op_) {}
201 void operator()(
const T* a,
const T* b, U* dst,
int size) {
210 void operator()(
const T* a,
const T* b, U* dst_a, U* dst_b,
int size) {
212 auto dst =
op(*a, *b);
223template <
typename T,
typename U,
typename Op>
224void binary_op_dims1(
const array& a,
const array& b, array& out, Op
op) {
225 const T* a_ptr = a.data<T>();
226 const T* b_ptr = b.data<T>();
227 U* dst = out.data<U>();
230 for (
size_t i = 0; i < out.size(); ++i) {
231 dst[i] =
op(a_ptr[a_idx], b_ptr[b_idx]);
232 a_idx += a.strides()[0];
233 b_idx += b.strides()[0];
237template <
typename T,
typename U,
typename Op>
244 const T* a_ptr = a.data<T>();
245 const T* b_ptr = b.data<T>();
246 U* dst = out.data<U>();
249 for (
size_t i = 0; i < a.shape()[0]; i++) {
250 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
251 a_idx += a.strides()[0];
252 b_idx += b.strides()[0];
257template <
typename T,
typename U,
typename Op>
258void binary_op_dims2(
const array& a,
const array& b, array& out, Op
op) {
259 const T* a_ptr = a.data<T>();
260 const T* b_ptr = b.data<T>();
261 U* dst = out.data<U>();
265 for (
size_t i = 0; i < a.shape()[0]; ++i) {
266 for (
size_t j = 0; j < a.shape()[1]; ++j) {
267 dst[out_idx++] =
op(a_ptr[a_idx], b_ptr[b_idx]);
268 a_idx += a.strides()[1];
269 b_idx += b.strides()[1];
271 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
272 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
276template <
typename T,
typename U,
typename Op>
283 const T* a_ptr = a.data<T>();
284 const T* b_ptr = b.data<T>();
285 U* dst = out.data<U>();
288 for (
size_t i = 0; i < a.shape()[0]; ++i) {
289 for (
size_t j = 0; j < a.shape()[1]; ++j) {
290 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
291 a_idx += a.strides()[1];
292 b_idx += b.strides()[1];
295 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
296 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
300template <
typename T,
typename U,
typename Op>
301void binary_op_dims3(
const array& a,
const array& b, array& out, Op
op) {
302 const T* a_ptr = a.data<T>();
303 const T* b_ptr = b.data<T>();
304 U* dst = out.data<U>();
308 for (
size_t i = 0; i < a.shape()[0]; ++i) {
309 for (
size_t j = 0; j < a.shape()[1]; ++j) {
310 for (
size_t k = 0; k < a.shape()[2]; ++k) {
311 dst[out_idx++] =
op(a_ptr[a_idx], b_ptr[b_idx]);
312 a_idx += a.strides()[2];
313 b_idx += b.strides()[2];
315 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
316 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
318 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
319 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
323template <
typename T,
typename U,
typename Op>
324void binary_op_dims4(
const array& a,
const array& b, array& out, Op
op) {
325 const T* a_ptr = a.data<T>();
326 const T* b_ptr = b.data<T>();
327 U* dst = out.data<U>();
331 for (
size_t i = 0; i < a.shape()[0]; ++i) {
332 for (
size_t j = 0; j < a.shape()[1]; ++j) {
333 for (
size_t k = 0; k < a.shape()[2]; ++k) {
334 for (
size_t ii = 0; ii < a.shape()[3]; ++ii) {
335 dst[out_idx++] =
op(a_ptr[a_idx], b_ptr[b_idx]);
336 a_idx += a.strides()[3];
337 b_idx += b.strides()[3];
339 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
340 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
342 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
343 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
345 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
346 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
350template <
typename T,
typename U,
typename Op>
351void binary_op_dispatch_dims(
356 switch (out.ndim()) {
358 binary_op_dims1<T, U, Op>(a, b, out,
op);
361 binary_op_dims2<T, U, Op>(a, b, out,
op);
364 binary_op_dims3<T, U, Op>(a, b, out,
op);
367 binary_op_dims4<T, U, Op>(a, b, out,
op);
371 const T* a_ptr = a.data<T>();
372 const T* b_ptr = b.data<T>();
373 U* dst = out.data<U>();
374 for (
size_t i = 0; i < out.size(); i++) {
375 int a_idx =
elem_to_loc(i, a.shape(), a.strides());
376 int b_idx =
elem_to_loc(i, b.shape(), b.strides());
377 dst[i] =
op(a_ptr[a_idx], b_ptr[b_idx]);
381template <
typename T,
typename U,
typename Op>
382void binary_op_dispatch_dims(
392 binary_op_dims1<T, U, Op>(a, b, out,
op, stride);
395 binary_op_dims2<T, U, Op>(a, b, out,
op, stride);
399 const T* a_ptr = a.data<T>();
400 const T* b_ptr = b.data<T>();
401 U* dst = out.data<U>();
402 for (
size_t i = 0; i < out.size(); i += stride) {
403 int a_idx =
elem_to_loc(i, a.shape(), a.strides());
404 int b_idx =
elem_to_loc(i, b.shape(), b.strides());
405 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
425 auto bopt = get_binary_op_type(a, b);
426 set_binary_op_output_data(a, b, out, bopt);
429 if (bopt == BinaryOpType::ScalarScalar) {
430 *(out.data<U>()) =
op(*a.data<T>(), *b.data<T>());
435 if (bopt == BinaryOpType::ScalarVector) {
436 opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
441 if (bopt == BinaryOpType::VectorScalar) {
442 opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
447 if (bopt == BinaryOpType::VectorVector) {
448 opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
455 auto& strides = out.strides();
456 auto leftmost_rc_dim = [&strides](
const array& arr) {
457 int d = arr.ndim() - 1;
458 for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
462 auto a_rc_dim = leftmost_rc_dim(a);
463 auto b_rc_dim = leftmost_rc_dim(b);
466 auto leftmost_s_dim = [](
const array& arr) {
467 int d = arr.ndim() - 1;
468 for (; d >= 0 && arr.strides()[d] == 0; d--) {
472 auto a_s_dim = leftmost_s_dim(a);
473 auto b_s_dim = leftmost_s_dim(b);
475 auto ndim = out.ndim();
479 if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
480 bopt = BinaryOpType::VectorVector;
484 }
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
485 bopt = BinaryOpType::VectorScalar;
489 }
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
490 bopt = BinaryOpType::ScalarVector;
498 if (dim == 0 || strides[dim - 1] < 16) {
500 bopt = BinaryOpType::General;
503 stride = strides[dim - 1];
507 case BinaryOpType::VectorVector:
508 binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
510 case BinaryOpType::VectorScalar:
511 binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
513 case BinaryOpType::ScalarVector:
514 binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
517 binary_op_dispatch_dims<T, U>(a, b, out,
op);
522template <
typename T,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
534 if (std::is_same<
decltype(opsv), UseDefaultBinaryOp>::value) {
535 if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
536 if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
543 DefaultScalarVector<T, T, Op>(
op),
544 DefaultVectorScalar<T, T, Op>(
op),
545 DefaultVectorVector<T, T, Op>(
op));
553 DefaultScalarVector<T, T, Op>(
op),
554 DefaultVectorScalar<T, T, Op>(
op),
557 }
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
564 DefaultScalarVector<T, T, Op>(
op),
566 DefaultVectorVector<T, T, Op>(
op));
570 a, b, out,
op, DefaultScalarVector<T, T, Op>(
op), opvs, opvv);
572 }
else if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
573 if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
581 DefaultVectorScalar<T, T, Op>(
op),
582 DefaultVectorVector<T, T, Op>(
op));
586 a, b, out,
op, opsv, DefaultVectorScalar<T, T, Op>(
op), opvv);
588 }
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
591 a, b, out,
op, opsv, opvs, DefaultVectorVector<T, T, Op>(
op));
594 binary_op<T, T>(a, b, out,
op, opsv, opvs, opvv);
598template <
typename T,
typename Op>
599void binary_op(
const array& a,
const array& b, array& out, Op
op) {
600 DefaultScalarVector<T, T, Op> opsv(
op);
601 DefaultVectorScalar<T, T, Op> opvs(
op);
602 DefaultVectorVector<T, T, Op> opvv(
op);
603 binary_op<T, T>(a, b, out,
op, opsv, opvs, opvv);
606template <
typename... Ops>
607void binary(
const array& a,
const array& b, array& out, Ops... ops) {
608 switch (out.dtype()) {
610 binary_op<bool>(a, b, out, ops...);
613 binary_op<uint8_t>(a, b, out, ops...);
616 binary_op<uint16_t>(a, b, out, ops...);
619 binary_op<uint32_t>(a, b, out, ops...);
622 binary_op<uint64_t>(a, b, out, ops...);
625 binary_op<int8_t>(a, b, out, ops...);
628 binary_op<int16_t>(a, b, out, ops...);
631 binary_op<int32_t>(a, b, out, ops...);
634 binary_op<int64_t>(a, b, out, ops...);
637 binary_op<float16_t>(a, b, out, ops...);
640 binary_op<float>(a, b, out, ops...);
643 binary_op<bfloat16_t>(a, b, out, ops...);
646 binary_op<complex64_t>(a, b, out, ops...);
Op op
Definition binary.h:141
Buffer malloc_or_wait(size_t size)
constexpr Dtype bool_
Definition dtype.h:58
constexpr Dtype uint64
Definition dtype.h:63
constexpr Dtype uint16
Definition dtype.h:61
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
constexpr Dtype bfloat16
Definition dtype.h:72
constexpr Dtype int32
Definition dtype.h:67
constexpr Dtype float32
Definition dtype.h:71
constexpr Dtype int16
Definition dtype.h:66
constexpr Dtype int8
Definition dtype.h:65
constexpr Dtype int64
Definition dtype.h:68
constexpr Dtype uint8
Definition dtype.h:60
constexpr Dtype float16
Definition dtype.h:70
constexpr Dtype uint32
Definition dtype.h:62
bool is_donatable(const array &in, const array &out)
Definition utils.h:158
constexpr Dtype complex64
Definition dtype.h:73