// Category of operand layout used to select how a binary op is evaluated.
// NOTE(review): the enumerator list is not part of this listing; the values
// referenced elsewhere in this file are ScalarScalar, ScalarVector,
// VectorScalar, VectorVector and General.
12enum class BinaryOpType {
// Classifies the operand pair (a, b) into a BinaryOpType using each array's
// data_size() and layout flags.
// NOTE(review): the declaration of `bopt`, the `else` lines introducing the
// last two branches, and the return statement are not visible in this listing.
20BinaryOpType get_binary_op_type(
const array& a,
const array& b) {
// Both operands report a data size of exactly one element.
22 if (a.data_size() == 1 && b.data_size() == 1) {
23 bopt = BinaryOpType::ScalarScalar;
24 }
// `a` is a single element and `b` is flagged contiguous.
else if (a.data_size() == 1 && b.flags().contiguous) {
25 bopt = BinaryOpType::ScalarVector;
26 }
// Mirror case: `b` is a single element and `a` is flagged contiguous.
else if (b.data_size() == 1 && a.flags().contiguous) {
27 bopt = BinaryOpType::VectorScalar;
// `&&` binds tighter than `||`, so this condition reads as
// (both row-contiguous) || (both column-contiguous).
29 a.flags().row_contiguous && b.flags().row_contiguous ||
30 a.flags().col_contiguous && b.flags().col_contiguous) {
31 bopt = BinaryOpType::VectorVector;
// Presumably the fallback when none of the tests above match (the
// introducing `else` is not visible) -- TODO confirm.
33 bopt = BinaryOpType::General;
// Chooses the storage for `out` depending on the BinaryOpType: where an input
// is donatable and compatible, `out` takes over that input's shared buffer.
// `donate_with_move` (default false) selects move_shared_buffer() over
// copy_shared_buffer().
// NOTE(review): the remaining parameters, the `switch` header and every
// fallback branch (taken when no input can be donated) are not visible in this
// listing; presumably the fallback allocates fresh storage -- TODO confirm.
38void set_binary_op_output_data(
43 bool donate_with_move =
false) {
// Body of the ScalarScalar case is not visible in this listing.
45 case BinaryOpType::ScalarScalar:
// Only `b` is considered for donation here; it must be donatable and have the
// same item size as `out`.
49 case BinaryOpType::ScalarVector:
50 if (b.is_donatable() && b.itemsize() == out.itemsize()) {
51 if (donate_with_move) {
52 out.move_shared_buffer(b);
54 out.copy_shared_buffer(b);
// Only `a` is considered for donation here, under the same conditions.
64 case BinaryOpType::VectorScalar:
65 if (a.is_donatable() && a.itemsize() == out.itemsize()) {
66 if (donate_with_move) {
67 out.move_shared_buffer(a);
69 out.copy_shared_buffer(a);
// Either input may be donated; `a` is tried first, then `b`.
79 case BinaryOpType::VectorVector:
80 if (a.is_donatable() && a.itemsize() == out.itemsize()) {
81 if (donate_with_move) {
82 out.move_shared_buffer(a);
84 out.copy_shared_buffer(a);
86 }
else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
87 if (donate_with_move) {
88 out.move_shared_buffer(b);
90 out.copy_shared_buffer(b);
// General case adds two requirements on the donor: it must be row-contiguous
// and have the same element count as `out`.
100 case BinaryOpType::General:
101 if (a.is_donatable() && a.flags().row_contiguous &&
102 a.itemsize() == out.itemsize() && a.size() == out.size()) {
103 if (donate_with_move) {
104 out.move_shared_buffer(a);
106 out.copy_shared_buffer(a);
// Same test applied to `b` (the introducing `else if (` is not visible).
109 b.is_donatable() && b.flags().row_contiguous &&
110 b.itemsize() == out.itemsize() && b.size() == out.size()) {
111 if (donate_with_move) {
112 out.move_shared_buffer(b);
114 out.copy_shared_buffer(b);
// Placeholder functor type. Elsewhere in this file,
// std::is_same<decltype(...), UseDefaultBinaryOp> checks whether a caller
// passed this type for one of the specialised functors.
// It declares two call operators matching the two functor call shapes used in
// this file: single output (dst) and dual output (dst_a, dst_b).
// NOTE(review): both operator bodies are not visible in this listing.
123struct UseDefaultBinaryOp {
// Single-output call shape.
124 template <
typename T,
typename U>
125 void operator()(
const T* a,
const T* b, U* dst,
int size) {
// Dual-output call shape.
130 template <
typename T,
typename U>
131 void operator()(
const T* a,
const T* b, U* dst_a, U* dst_b,
int size) {
// Functor adapting an element-wise `Op` to the pointer/size call shape for the
// vector-by-scalar case: each visible application is op(*a, scalar).
// T is the input element type, U the output element type.
// NOTE(review): the `op` member, the definition of `scalar` (presumably read
// from `b` -- TODO confirm), the loops over `size` and the pointer increments
// are not visible in this listing.
137template <
typename T,
typename U,
typename Op>
138struct DefaultVectorScalar {
// Stores a copy of the element-wise operation.
141 DefaultVectorScalar(Op op_) :
op(op_) {}
// Single-output form: writes op(*a, scalar) through `dst`.
143 void operator()(
const T* a,
const T* b, U* dst,
int size) {
146 *dst =
op(*a, scalar);
// Dual-output form: the result of op(*a, scalar) is held in a local named
// `dst`; how it is written to dst_a/dst_b is not visible here.
152 void operator()(
const T* a,
const T* b, U* dst_a, U* dst_b,
int size) {
155 auto dst =
op(*a, scalar);
// Functor adapting an element-wise `Op` to the pointer/size call shape for the
// scalar-by-vector case: each visible application is op(scalar, *b).
// T is the input element type, U the output element type.
// NOTE(review): the `op` member, the definition of `scalar` (presumably read
// from `a` -- TODO confirm), the loops over `size` and the pointer increments
// are not visible in this listing.
165template <
typename T,
typename U,
typename Op>
166struct DefaultScalarVector {
// Stores a copy of the element-wise operation.
169 DefaultScalarVector(Op op_) :
op(op_) {}
// Single-output form: writes op(scalar, *b) through `dst`.
171 void operator()(
const T* a,
const T* b, U* dst,
int size) {
174 *dst =
op(scalar, *b);
// Dual-output form: the result of op(scalar, *b) is held in a local named
// `dst`; how it is written to dst_a/dst_b is not visible here.
180 void operator()(
const T* a,
const T* b, U* dst_a, U* dst_b,
int size) {
183 auto dst =
op(scalar, *b);
// Functor adapting an element-wise `Op` to the pointer/size call shape for the
// vector-by-vector case; the visible application is op(*a, *b).
// T is the input element type, U the output element type.
// NOTE(review): the `op` member, the whole body of the single-output overload,
// and the loop/pointer increments of the dual-output overload are not visible
// in this listing.
193template <
typename T,
typename U,
typename Op>
194struct DefaultVectorVector {
// Stores a copy of the element-wise operation.
197 DefaultVectorVector(Op op_) :
op(op_) {}
// Single-output form (body not visible).
199 void operator()(
const T* a,
const T* b, U* dst,
int size) {
// Dual-output form: the result of op(*a, *b) is held in a local named `dst`;
// how it is written to dst_a/dst_b is not visible here.
208 void operator()(
const T* a,
const T* b, U* dst_a, U* dst_b,
int size) {
210 auto dst =
op(*a, *b);
// Element-wise binary op over a 1-D strided view: for each of out.size()
// outputs, applies `op` to the current elements of `a` and `b`, then advances
// each input offset by that array's stride along axis 0.
// T is the input element type, U the output element type.
// NOTE(review): the declarations of `a_idx` and `b_idx` (presumably
// zero-initialised) and the closing braces are not visible in this listing.
221template <
typename T,
typename U,
typename Op>
222void binary_op_dims1(
const array& a,
const array& b, array& out, Op
op) {
223 const T* a_ptr = a.data<T>();
224 const T* b_ptr = b.data<T>();
225 U* dst = out.data<U>();
// Output is written densely (dst[i]); inputs are read through their strides.
228 for (
size_t i = 0; i < out.size(); ++i) {
229 dst[i] =
op(a_ptr[a_idx], b_ptr[b_idx]);
230 a_idx += a.strides()[0];
231 b_idx += b.strides()[0];
// Strided-block variant of the 1-D kernel: for each index along axis 0 of `a`,
// `op` is invoked with input pointers at the current offsets, the output
// pointer and a block length `stride`; the input offsets then advance by each
// array's axis-0 stride.
// NOTE(review): the function signature, the declarations of `a_idx`/`b_idx`,
// and any advance of `dst` between calls (presumably by `stride`) are not
// visible in this listing -- TODO confirm.
235template <
typename T,
typename U,
typename Op>
242 const T* a_ptr = a.data<T>();
243 const T* b_ptr = b.data<T>();
244 U* dst = out.data<U>();
// `i` is size_t while shape() entries appear to be int (elem_to_loc in this
// file takes the shape as std::vector<int>), so this comparison mixes signed
// and unsigned operands.
247 for (
size_t i = 0; i < a.shape()[0]; i++) {
248 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
249 a_idx += a.strides()[0];
250 b_idx += b.strides()[0];
// Element-wise binary op over a 2-D strided view. Outputs are written
// consecutively through `out_idx`; inputs are read at running offsets that are
// updated with each array's own strides.
// T is the input element type, U the output element type.
// NOTE(review): the declarations of `a_idx`, `b_idx` and `out_idx` and the
// closing braces are not visible in this listing. Loop bounds come from
// a.shape() while b's rewind uses b.shape(), so the code presumably relies on
// both inputs having the same shape -- TODO confirm.
255template <
typename T,
typename U,
typename Op>
256void binary_op_dims2(
const array& a,
const array& b, array& out, Op
op) {
257 const T* a_ptr = a.data<T>();
258 const T* b_ptr = b.data<T>();
259 U* dst = out.data<U>();
263 for (
size_t i = 0; i < a.shape()[0]; ++i) {
264 for (
size_t j = 0; j < a.shape()[1]; ++j) {
265 dst[out_idx++] =
op(a_ptr[a_idx], b_ptr[b_idx]);
// Step along the inner axis.
266 a_idx += a.strides()[1];
267 b_idx += b.strides()[1];
// Undo the accumulated inner-axis advance (strides()[1] * shape()[1]) and take
// one step along axis 0. Presumably executed after the inner loop closes (the
// closing brace is not visible).
269 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
270 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
// Strided-block variant of the 2-D kernel: for every (i, j) index pair of `a`,
// `op` is invoked with input pointers at the current offsets, the output
// pointer and a block length `stride`.
// NOTE(review): the function signature, the declarations of `a_idx`/`b_idx`,
// any advance of `dst` between calls (presumably by `stride`), and the closing
// braces are not visible in this listing -- TODO confirm.
274template <
typename T,
typename U,
typename Op>
281 const T* a_ptr = a.data<T>();
282 const T* b_ptr = b.data<T>();
283 U* dst = out.data<U>();
286 for (
size_t i = 0; i < a.shape()[0]; ++i) {
287 for (
size_t j = 0; j < a.shape()[1]; ++j) {
288 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
// Step along the inner axis.
289 a_idx += a.strides()[1];
290 b_idx += b.strides()[1];
// Undo the accumulated inner-axis advance and take one step along axis 0.
293 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
294 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
// Element-wise binary op over a 3-D strided view. Outputs are written
// consecutively through `out_idx`; inputs are read at running offsets that are
// updated with each array's own strides.
// T is the input element type, U the output element type.
// NOTE(review): the declarations of `a_idx`, `b_idx` and `out_idx` and the
// closing braces are not visible in this listing.
298template <
typename T,
typename U,
typename Op>
299void binary_op_dims3(
const array& a,
const array& b, array& out, Op
op) {
300 const T* a_ptr = a.data<T>();
301 const T* b_ptr = b.data<T>();
302 U* dst = out.data<U>();
306 for (
size_t i = 0; i < a.shape()[0]; ++i) {
307 for (
size_t j = 0; j < a.shape()[1]; ++j) {
308 for (
size_t k = 0; k < a.shape()[2]; ++k) {
309 dst[out_idx++] =
op(a_ptr[a_idx], b_ptr[b_idx]);
// Step along the innermost axis (2).
310 a_idx += a.strides()[2];
311 b_idx += b.strides()[2];
// Undo the accumulated axis-2 advance and take one step along axis 1.
313 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
314 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
// Undo the accumulated axis-1 advance and take one step along axis 0.
316 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
317 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
// Element-wise binary op over a 4-D strided view. Outputs are written
// consecutively through `out_idx`; inputs are read at running offsets that are
// updated with each array's own strides.
// T is the input element type, U the output element type.
// NOTE(review): the declarations of `a_idx`, `b_idx` and `out_idx` and the
// closing braces are not visible in this listing.
321template <
typename T,
typename U,
typename Op>
322void binary_op_dims4(
const array& a,
const array& b, array& out, Op
op) {
323 const T* a_ptr = a.data<T>();
324 const T* b_ptr = b.data<T>();
325 U* dst = out.data<U>();
329 for (
size_t i = 0; i < a.shape()[0]; ++i) {
330 for (
size_t j = 0; j < a.shape()[1]; ++j) {
331 for (
size_t k = 0; k < a.shape()[2]; ++k) {
332 for (
size_t ii = 0; ii < a.shape()[3]; ++ii) {
333 dst[out_idx++] =
op(a_ptr[a_idx], b_ptr[b_idx]);
// Step along the innermost axis (3).
334 a_idx += a.strides()[3];
335 b_idx += b.strides()[3];
// Undo the accumulated axis-3 advance and take one step along axis 2.
337 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
338 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
// Undo the accumulated axis-2 advance and take one step along axis 1.
340 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
341 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
// Undo the accumulated axis-1 advance and take one step along axis 0.
343 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
344 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
// Selects an element-wise kernel based on out.ndim(): the fixed-rank helpers
// binary_op_dims1..4 are called from the switch, and a generic loop that maps
// each linear output index to input offsets with elem_to_loc() follows.
// NOTE(review): the parameter list, the `case` labels (presumably 1 through 4)
// and the label guarding the generic loop (presumably `default`) are not
// visible in this listing -- TODO confirm.
348template <
typename T,
typename U,
typename Op>
349void binary_op_dispatch_dims(
354 switch (out.ndim()) {
356 binary_op_dims1<T, U, Op>(a, b, out,
op);
359 binary_op_dims2<T, U, Op>(a, b, out,
op);
362 binary_op_dims3<T, U, Op>(a, b, out,
op);
365 binary_op_dims4<T, U, Op>(a, b, out,
op);
// Generic path: compute each input offset from the linear output index.
369 const T* a_ptr = a.data<T>();
370 const T* b_ptr = b.data<T>();
371 U* dst = out.data<U>();
372 for (
size_t i = 0; i < out.size(); i++) {
// NOTE(review): elem_to_loc is listed in this file's reference index as
// returning stride_t and taking `int elem`; storing the result in `int` (and
// passing a size_t index) may narrow if stride_t is wider than int -- confirm
// and consider `auto`.
373 int a_idx =
elem_to_loc(i, a.shape(), a.strides());
374 int b_idx =
elem_to_loc(i, b.shape(), b.strides());
375 dst[i] =
op(a_ptr[a_idx], b_ptr[b_idx]);
// Strided-block overload of the dispatcher: forwards to the `stride`-taking
// binary_op_dims1 / binary_op_dims2 helpers, and otherwise walks the output in
// steps of `stride`, locating each block's input offsets with elem_to_loc().
// NOTE(review): the parameter list, the `switch` header, the `case` labels and
// any advance of `dst` inside the generic loop are not visible in this listing.
379template <
typename T,
typename U,
typename Op>
380void binary_op_dispatch_dims(
390 binary_op_dims1<T, U, Op>(a, b, out,
op, stride);
393 binary_op_dims2<T, U, Op>(a, b, out,
op, stride);
// Generic path: one functor call per block of `stride` output elements.
397 const T* a_ptr = a.data<T>();
398 const T* b_ptr = b.data<T>();
399 U* dst = out.data<U>();
400 for (
size_t i = 0; i < out.size(); i += stride) {
// NOTE(review): elem_to_loc is listed in this file's reference index as
// returning stride_t; storing the result in `int` may narrow if stride_t is
// wider than int -- confirm and consider `auto`.
401 int a_idx =
elem_to_loc(i, a.shape(), a.strides());
402 int b_idx =
elem_to_loc(i, b.shape(), b.strides());
403 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
// Main body of the binary-op driver: classify the operands, set up the output
// storage, run the matching specialised functor for the simple layouts, and
// for the general layout try to find trailing dimensions that can still be
// handled by a specialised functor in blocks.
// NOTE(review): the template header, the signature, early returns after each
// simple case, the lambdas' return statements, the declarations of `dim` and
// `stride`, the `switch` header and several braces are not visible in this
// listing.
423 auto bopt = get_binary_op_type(a, b);
424 set_binary_op_output_data(a, b, out, bopt);
// Single-element inputs: one direct application of `op`.
427 if (bopt == BinaryOpType::ScalarScalar) {
428 *(out.data<U>()) =
op(*a.data<T>(), *b.data<T>());
// Simple layouts: a single functor call over the whole buffer. The length is
// taken from the non-scalar operand, or from `out` for vector/vector.
433 if (bopt == BinaryOpType::ScalarVector) {
434 opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
439 if (bopt == BinaryOpType::VectorScalar) {
440 opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
445 if (bopt == BinaryOpType::VectorVector) {
446 opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
// Scans from the last axis toward the first while the array's stride equals
// the output's stride on that axis.
453 auto& strides = out.strides();
454 auto leftmost_rc_dim = [&strides](
const array& arr) {
455 int d = arr.ndim() - 1;
456 for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
460 auto a_rc_dim = leftmost_rc_dim(a);
461 auto b_rc_dim = leftmost_rc_dim(b);
// Scans from the last axis toward the first while the array's stride is zero.
464 auto leftmost_s_dim = [](
const array& arr) {
465 int d = arr.ndim() - 1;
466 for (; d >= 0 && arr.strides()[d] == 0; d--) {
470 auto a_s_dim = leftmost_s_dim(a);
471 auto b_s_dim = leftmost_s_dim(b);
473 auto ndim = out.ndim();
// Re-classify using the scan results; each test takes the larger of the two
// per-operand dimensions and requires it to be below `ndim`. Uses the C++17
// if-with-initializer form.
477 if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
478 bopt = BinaryOpType::VectorVector;
482 }
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
483 bopt = BinaryOpType::VectorScalar;
487 }
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
488 bopt = BinaryOpType::ScalarVector;
// Falls back to General when no leading dimension remains (dim == 0) or the
// output stride at dim - 1 is below 16; otherwise that stride is recorded in
// `stride` (presumably in an `else` branch -- TODO confirm).
496 if (dim == 0 || strides[dim - 1] < 16) {
498 bopt = BinaryOpType::General;
501 stride = strides[dim - 1];
// Block-wise dispatch with the functor matching the re-classified layout.
505 case BinaryOpType::VectorVector:
506 binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
508 case BinaryOpType::VectorScalar:
509 binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
511 case BinaryOpType::ScalarVector:
512 binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
// Element-wise dispatch with the plain `op` (presumably the default/General
// case; its label is not visible).
515 binary_op_dispatch_dims<T, U>(a, b, out,
op);
// Overload that resolves placeholder functors: for each of opsv / opvs / opvv
// whose type is UseDefaultBinaryOp, the corresponding Default* wrapper built
// from `op` is passed on instead; functors of any other type are forwarded
// unchanged. The nested tests cover the combinations of which functors are
// placeholders.
// The tests are ordinary `if` statements on std::is_same<...>::value, so every
// branch is compiled for every instantiation.
// NOTE(review): the function signature, the callee name at the head of most
// forwarded calls, some argument lines and several braces are not visible in
// this listing.
520template <
typename T,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
532 if (std::is_same<
decltype(opsv), UseDefaultBinaryOp>::value) {
533 if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
534 if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
// opsv, opvs and opvv are all placeholders: three default wrappers.
541 DefaultScalarVector<T, T, Op>(
op),
542 DefaultVectorScalar<T, T, Op>(
op),
543 DefaultVectorVector<T, T, Op>(
op));
// opsv and opvs are placeholders (third argument not visible).
551 DefaultScalarVector<T, T, Op>(
op),
552 DefaultVectorScalar<T, T, Op>(
op),
555 }
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
// opsv and opvv are placeholders (middle argument not visible).
562 DefaultScalarVector<T, T, Op>(
op),
564 DefaultVectorVector<T, T, Op>(
op));
// Only opsv is a placeholder.
568 a, b, out,
op, DefaultScalarVector<T, T, Op>(
op), opvs, opvv);
570 }
else if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
571 if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv are placeholders (first functor argument not visible).
579 DefaultVectorScalar<T, T, Op>(
op),
580 DefaultVectorVector<T, T, Op>(
op));
// Only opvs is a placeholder.
584 a, b, out,
op, opsv, DefaultVectorScalar<T, T, Op>(
op), opvv);
586 }
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
// Only opvv is a placeholder.
589 a, b, out,
op, opsv, opvs, DefaultVectorVector<T, T, Op>(
op));
// No placeholder: forward all three caller-supplied functors as given.
592 binary_op<T, T>(a, b, out,
op, opsv, opvs, opvv);
// Convenience overload for callers that supply only the element-wise `op`:
// wraps it in the three default functors (scalar/vector, vector/scalar,
// vector/vector) and forwards to the five-functor binary_op with the output
// element type equal to the input element type T.
// NOTE(review): the closing brace is not visible in this listing.
596template <
typename T,
typename Op>
597void binary_op(
const array& a,
const array& b, array& out, Op
op) {
598 DefaultScalarVector<T, T, Op> opsv(
op);
599 DefaultVectorScalar<T, T, Op> opvs(
op);
600 DefaultVectorVector<T, T, Op> opvv(
op);
601 binary_op<T, T>(a, b, out,
op, opsv, opvs, opvv);
// Runtime-to-compile-time type dispatch: switches on out.dtype() and calls
// binary_op instantiated with a C++ element type, forwarding `ops...`
// unchanged.
// NOTE(review): the `case` labels and `break` statements are not visible in
// this listing; presumably each call sits under the Dtype constant matching
// its template argument -- TODO confirm.
604template <
typename... Ops>
605void binary(
const array& a,
const array& b, array& out, Ops... ops) {
606 switch (out.dtype()) {
608 binary_op<bool>(a, b, out, ops...);
611 binary_op<uint8_t>(a, b, out, ops...);
614 binary_op<uint16_t>(a, b, out, ops...);
617 binary_op<uint32_t>(a, b, out, ops...);
620 binary_op<uint64_t>(a, b, out, ops...);
623 binary_op<int8_t>(a, b, out, ops...);
626 binary_op<int16_t>(a, b, out, ops...);
629 binary_op<int32_t>(a, b, out, ops...);
632 binary_op<int64_t>(a, b, out, ops...);
635 binary_op<float16_t>(a, b, out, ops...);
638 binary_op<float>(a, b, out, ops...);
641 binary_op<bfloat16_t>(a, b, out, ops...);
644 binary_op<complex64_t>(a, b, out, ops...);
Op op
Definition binary.h:139
Buffer malloc_or_wait(size_t size)
constexpr Dtype bool_
Definition dtype.h:60
constexpr Dtype uint64
Definition dtype.h:65
constexpr Dtype uint16
Definition dtype.h:63
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
constexpr Dtype bfloat16
Definition dtype.h:74
constexpr Dtype int32
Definition dtype.h:69
constexpr Dtype float32
Definition dtype.h:73
constexpr Dtype int16
Definition dtype.h:68
constexpr Dtype int8
Definition dtype.h:67
constexpr Dtype int64
Definition dtype.h:70
constexpr Dtype uint8
Definition dtype.h:62
constexpr Dtype float16
Definition dtype.h:72
constexpr Dtype uint32
Definition dtype.h:64
constexpr Dtype complex64
Definition dtype.h:75