// Category describing how the two inputs of a binary op are laid out.
// The enumerator list is not reproduced in this listing; the enumerators
// referenced by the code below are ScalarScalar, ScalarVector, VectorScalar,
// VectorVector and General.
14enum class BinaryOpType {
22BinaryOpType get_binary_op_type(
const array& a,
const array& b) {
24 if (a.data_size() == 1 && b.data_size() == 1) {
25 bopt = BinaryOpType::ScalarScalar;
26 }
else if (a.data_size() == 1 && b.flags().contiguous) {
27 bopt = BinaryOpType::ScalarVector;
28 }
else if (b.data_size() == 1 && a.flags().contiguous) {
29 bopt = BinaryOpType::VectorScalar;
31 a.flags().row_contiguous && b.flags().row_contiguous ||
32 a.flags().col_contiguous && b.flags().col_contiguous) {
33 bopt = BinaryOpType::VectorVector;
35 bopt = BinaryOpType::General;
// Chooses the storage that `out` will use for a binary op, branching on the
// BinaryOpType cases below. In each visible branch an input hands its buffer
// to `out`: out.move_shared_buffer(x) under `if (donate_with_move)`, followed
// by out.copy_shared_buffer(x) (presumably the else arm; the line between the
// two calls is missing from this listing). donate_with_move defaults to false.
// NOTE(review): the leading parameters, the switch header, the definitions of
// a_donatable / b_donatable and the non-donating fallbacks are not part of
// this listing; confirm against the full header before relying on this
// summary.
40void set_binary_op_output_data(
45 bool donate_with_move =
false) {
// ScalarScalar: handling is not shown in this listing.
49 case BinaryOpType::ScalarScalar:
// ScalarVector: only b's buffer is a reuse candidate.
53 case BinaryOpType::ScalarVector:
55 if (donate_with_move) {
56 out.move_shared_buffer(b);
58 out.copy_shared_buffer(b);
// VectorScalar: only a's buffer is a reuse candidate.
68 case BinaryOpType::VectorScalar:
70 if (donate_with_move) {
71 out.move_shared_buffer(a);
73 out.copy_shared_buffer(a);
// VectorVector: a is tried first, then b when b_donatable holds.
83 case BinaryOpType::VectorVector:
85 if (donate_with_move) {
86 out.move_shared_buffer(a);
88 out.copy_shared_buffer(a);
90 }
else if (b_donatable) {
91 if (donate_with_move) {
92 out.move_shared_buffer(b);
94 out.copy_shared_buffer(b);
// General: an input is reused only if it is donatable, row-contiguous and has
// exactly as many elements as out; a is checked before b.
104 case BinaryOpType::General:
105 if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
106 if (donate_with_move) {
107 out.move_shared_buffer(a);
109 out.copy_shared_buffer(a);
112 b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
113 if (donate_with_move) {
114 out.move_shared_buffer(b);
116 out.copy_shared_buffer(b);
// Empty tag type. The five-template-parameter binary_op overload in this file
// compares the types of its opsv / opvs / opvv arguments against this tag and
// substitutes the matching Default* functor wherever the tag was passed.
125struct UseDefaultBinaryOp {};
// Functor adapting an element-wise Op to the "vector op scalar" case.
// T is the input element type, U the output element type.
// NOTE(review): the member declarations, the initialisation of `scalar`
// (presumably read from *b) and the loop around the store below are missing
// from this listing; confirm against the full header.
127template <
typename T,
typename U,
typename Op>
128struct DefaultVectorScalar {
// Stores a copy of the element-wise op.
131 DefaultVectorScalar(Op op_) :
op(op_) {}
// a, b: input pointers; dst: output pointer; size: element count.
133 void operator()(
const T* a,
const T* b, U* dst,
int size) {
// Each output element combines the current element of a with the scalar.
136 *dst =
op(*a, scalar);
// Functor adapting an element-wise Op to the "scalar op vector" case.
// T is the input element type, U the output element type.
// NOTE(review): the member declarations, the initialisation of `scalar`
// (presumably read from *a) and the loop around the store below are missing
// from this listing; confirm against the full header.
143template <
typename T,
typename U,
typename Op>
144struct DefaultScalarVector {
// Stores a copy of the element-wise op.
147 DefaultScalarVector(Op op_) :
op(op_) {}
// a, b: input pointers; dst: output pointer; size: element count.
149 void operator()(
const T* a,
const T* b, U* dst,
int size) {
// Each output element combines the scalar with the current element of b.
152 *dst =
op(scalar, *b);
// Functor adapting an element-wise Op to the "vector op vector" case.
// T is the input element type, U the output element type.
// NOTE(review): the body of operator() is missing from this listing;
// presumably it applies op to matching elements of a and b for `size`
// elements, mirroring the two sibling functors. Confirm.
159template <
typename T,
typename U,
typename Op>
160struct DefaultVectorVector {
// Stores a copy of the element-wise op.
163 DefaultVectorVector(Op op_) :
op(op_) {}
// a, b: input pointers; dst: output pointer; size: element count.
165 void operator()(
const T* a,
const T* b, U* dst,
int size) {
// Recursive helper that walks D dimensions starting at `axis`.
// D is the number of dimensions still to iterate; Strided selects how op is
// invoked at the innermost level. The recursive call below names this
// template binary_op_dims.
// NOTE(review): the function name line, the leading pointer/op parameters,
// the `axis` parameter declaration, the non-strided op call and the pointer
// advancing between iterations are missing from this listing.
175template <
typename T,
typename U,
typename Op,
int D,
bool Str
ided>
181 const std::vector<int>& shape,
182 const std::vector<size_t>& a_strides,
183 const std::vector<size_t>& b_strides,
184 const std::vector<size_t>& out_strides,
// Per-axis strides of both inputs and the output, and the extent of this axis.
186 auto stride_a = a_strides[axis];
187 auto stride_b = b_strides[axis];
188 auto stride_out = out_strides[axis];
189 auto N = shape[axis];
191 for (
int i = 0; i < N; i++) {
// More than one dimension left: recurse into the next axis.
192 if constexpr (D > 1) {
193 binary_op_dims<T, U, Op, D - 1, Strided>(
194 a, b, out,
op, shape, a_strides, b_strides, out_strides, axis + 1);
// Innermost level in strided mode: op receives stride_out as its last
// argument.
196 if constexpr (Strided) {
197 op(a, b, out, stride_out);
// Runs a binary op over inputs described by shape and per-array strides,
// picking a binary_op_dims instantiation with D = 1, 2 or 3.
// NOTE(review): the leading parameters (a, b, out, op, dim are used below),
// the selection logic between the D = 1/2/3 calls and the arguments of those
// calls are missing from this listing; presumably a switch on `dim`. Confirm.
208template <
typename T,
typename U,
bool Str
ided,
typename Op>
209void binary_op_dispatch_dims(
215 const std::vector<int>& shape,
216 const std::vector<size_t>& a_strides,
217 const std::vector<size_t>& b_strides,
218 const std::vector<size_t>& out_strides) {
// Typed views of the input and output buffers.
219 const T* a_ptr = a.data<T>();
220 const T* b_ptr = b.data<T>();
221 U* out_ptr = out.data<U>();
224 binary_op_dims<T, U, Op, 1, Strided>(
236 binary_op_dims<T, U, Op, 2, Strided>(
248 binary_op_dims<T, U, Op, 3, Strided>(
// Higher-rank path: iterators over a and b are built with dim - 3, the output
// is stepped by out_strides[dim - 4], and the 3-D kernel runs at every step.
261 ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
262 ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
263 size_t stride = out_strides[dim - 4];
264 for (
size_t elem = 0; elem < a.size(); elem += stride) {
265 binary_op_dims<T, U, Op, 3, Strided>(
// Body of the binary_op overload that takes separate input (T) and output (U)
// element types plus the op / opsv / opvs / opvv functors. Its signature is
// missing from this listing.
// Step 1: classify the inputs and set up out's storage accordingly.
295 auto bopt = get_binary_op_type(a, b);
296 set_binary_op_output_data(a, b, out, bopt);
// Step 2: fast paths that need no stride handling.
// Both inputs are single elements: apply op once.
299 if (bopt == BinaryOpType::ScalarScalar) {
300 *(out.data<U>()) =
op(*a.data<T>(), *b.data<T>());
// Scalar a with contiguous b: one opsv call over b.data_size() elements.
305 if (bopt == BinaryOpType::ScalarVector) {
306 opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
// Contiguous a with scalar b: one opvs call over a.data_size() elements.
311 if (bopt == BinaryOpType::VectorScalar) {
312 opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
// Matching contiguous layouts: one opvv call over out.size() elements.
317 if (bopt == BinaryOpType::VectorVector) {
318 opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
// Step 3: general case. The call that produces new_shape / new_strides is
// only partly visible; presumably collapse_contiguous_dims. Confirm.
324 a.shape(), {a.strides(), b.strides(), out.strides()});
325 const auto& a_strides = new_strides[0];
326 const auto& b_strides = new_strides[1];
327 const auto& strides = new_strides[2];
// Scans from the last dimension towards the first while the array's stride
// equals the output's stride. The returned value is not shown in this listing.
330 auto leftmost_rc_dim = [&strides](
const std::vector<size_t>& arr_strides) {
331 int d = arr_strides.size() - 1;
332 for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
336 auto a_rc_dim = leftmost_rc_dim(a_strides);
337 auto b_rc_dim = leftmost_rc_dim(b_strides);
// Scans from the last dimension towards the first while the array's stride is
// zero. The returned value is not shown in this listing.
340 auto leftmost_s_dim = [](
const std::vector<size_t>& arr_strides) {
341 int d = arr_strides.size() - 1;
342 for (; d >= 0 && arr_strides[d] == 0; d--) {
346 auto a_s_dim = leftmost_s_dim(a_strides);
347 auto b_s_dim = leftmost_s_dim(b_strides);
// NOTE(review): the comparisons `d < ndim` below mix int d with the result of
// new_shape.size(); if new_shape is a std::vector this is a signed/unsigned
// comparison. Confirm the lambdas never yield a negative value.
349 auto ndim = new_shape.size();
// Re-classify the trailing dimensions: both inputs match the output layout.
353 if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
354 bopt = BinaryOpType::VectorVector;
// a matches the output layout while b has zero strides.
358 }
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
359 bopt = BinaryOpType::VectorScalar;
// a has zero strides while b matches the output layout.
363 }
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
364 bopt = BinaryOpType::ScalarVector;
// Fall back to the element-wise kernel when no outer dimension remains or the
// stride at dim - 1 is below 16 (presumably a heuristic threshold; confirm).
371 if (dim == 0 || strides[dim - 1] < 16) {
372 bopt = BinaryOpType::General;
// Step 4: run the chosen kernel. The three vectorised categories use the
// Strided = true instantiation with their functor; the remaining call uses
// Strided = false with the plain op.
377 case BinaryOpType::VectorVector:
378 binary_op_dispatch_dims<T, U, true>(
379 a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
381 case BinaryOpType::VectorScalar:
382 binary_op_dispatch_dims<T, U, true>(
383 a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
385 case BinaryOpType::ScalarVector:
386 binary_op_dispatch_dims<T, U, true>(
387 a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
390 binary_op_dispatch_dims<T, U, false>(
391 a, b, out,
op, dim, new_shape, a_strides, b_strides, strides);
396template <
typename T,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
408 if constexpr (std::is_same<
decltype(opsv), UseDefaultBinaryOp>::value) {
409 if constexpr (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
410 if constexpr (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
417 DefaultScalarVector<T, T, Op>(
op),
418 DefaultVectorScalar<T, T, Op>(
op),
419 DefaultVectorVector<T, T, Op>(
op));
427 DefaultScalarVector<T, T, Op>(
op),
428 DefaultVectorScalar<T, T, Op>(
op),
431 }
else if constexpr (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::
439 DefaultScalarVector<T, T, Op>(
op),
441 DefaultVectorVector<T, T, Op>(
op));
445 a, b, out,
op, DefaultScalarVector<T, T, Op>(
op), opvs, opvv);
447 }
else if constexpr (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::
449 if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
457 DefaultVectorScalar<T, T, Op>(
op),
458 DefaultVectorVector<T, T, Op>(
op));
462 a, b, out,
op, opsv, DefaultVectorScalar<T, T, Op>(
op), opvv);
464 }
else if constexpr (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::
468 a, b, out,
op, opsv, opvs, DefaultVectorVector<T, T, Op>(
op));
471 binary_op<T, T>(a, b, out,
op, opsv, opvs, opvv);
475template <
typename T,
typename Op>
476void binary_op(
const array& a,
const array& b, array& out, Op
op) {
477 DefaultScalarVector<T, T, Op> opsv(
op);
478 DefaultVectorScalar<T, T, Op> opvs(
op);
479 DefaultVectorVector<T, T, Op> opvv(
op);
480 binary_op<T, T>(a, b, out,
op, opsv, opvs, opvv);
// Type-erased entry point: switches on out.dtype() and calls binary_op<T>
// with the C++ element type listed on each line below, forwarding `ops...`
// unchanged.
// NOTE(review): the case labels and the statements separating the calls are
// missing from this listing; presumably one Dtype constant per call, each
// followed by a break. Confirm against the full header.
483template <
typename... Ops>
484void binary(
const array& a,
const array& b, array& out, Ops... ops) {
485 switch (out.dtype()) {
487 binary_op<bool>(a, b, out, ops...);
490 binary_op<uint8_t>(a, b, out, ops...);
493 binary_op<uint16_t>(a, b, out, ops...);
496 binary_op<uint32_t>(a, b, out, ops...);
499 binary_op<uint64_t>(a, b, out, ops...);
502 binary_op<int8_t>(a, b, out, ops...);
505 binary_op<int16_t>(a, b, out, ops...);
508 binary_op<int32_t>(a, b, out, ops...);
511 binary_op<int64_t>(a, b, out, ops...);
514 binary_op<float16_t>(a, b, out, ops...);
517 binary_op<float>(a, b, out, ops...);
520 binary_op<bfloat16_t>(a, b, out, ops...);
523 binary_op<complex64_t>(a, b, out, ops...);
Op op
Definition binary.h:129
Buffer malloc_or_wait(size_t size)
constexpr Dtype bool_
Definition dtype.h:67
constexpr Dtype uint64
Definition dtype.h:72
constexpr Dtype uint16
Definition dtype.h:70
std::tuple< std::vector< int >, std::vector< std::vector< int64_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
constexpr Dtype bfloat16
Definition dtype.h:81
constexpr Dtype int32
Definition dtype.h:76
constexpr Dtype float32
Definition dtype.h:80
constexpr Dtype int16
Definition dtype.h:75
constexpr Dtype int8
Definition dtype.h:74
constexpr Dtype int64
Definition dtype.h:77
constexpr Dtype uint8
Definition dtype.h:69
constexpr Dtype float16
Definition dtype.h:79
constexpr Dtype uint32
Definition dtype.h:71
bool is_donatable(const array &in, const array &out)
Definition utils.h:174
constexpr Dtype complex64
Definition dtype.h:82