14enum class BinaryOpType {
 
   22BinaryOpType get_binary_op_type(
const array& a, 
const array& b) {
 
   24  if (a.data_size() == 1 && b.data_size() == 1) {
 
   25    bopt = BinaryOpType::ScalarScalar;
 
   26  } 
else if (a.data_size() == 1 && b.flags().contiguous) {
 
   27    bopt = BinaryOpType::ScalarVector;
 
   28  } 
else if (b.data_size() == 1 && a.flags().contiguous) {
 
   29    bopt = BinaryOpType::VectorScalar;
 
   31      a.flags().row_contiguous && b.flags().row_contiguous ||
 
   32      a.flags().col_contiguous && b.flags().col_contiguous) {
 
   33    bopt = BinaryOpType::VectorVector;
 
   35    bopt = BinaryOpType::General;
 
   40void set_binary_op_output_data(
 
   45    bool donate_with_move = 
false) {
 
   49    case BinaryOpType::ScalarScalar:
 
   53    case BinaryOpType::ScalarVector:
 
   55        if (donate_with_move) {
 
   56          out.move_shared_buffer(b);
 
   58          out.copy_shared_buffer(b);
 
   68    case BinaryOpType::VectorScalar:
 
   70        if (donate_with_move) {
 
   71          out.move_shared_buffer(a);
 
   73          out.copy_shared_buffer(a);
 
   83    case BinaryOpType::VectorVector:
 
   85        if (donate_with_move) {
 
   86          out.move_shared_buffer(a);
 
   88          out.copy_shared_buffer(a);
 
   90      } 
else if (b_donatable) {
 
   91        if (donate_with_move) {
 
   92          out.move_shared_buffer(b);
 
   94          out.copy_shared_buffer(b);
 
  104    case BinaryOpType::General:
 
  105      if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
 
  106        if (donate_with_move) {
 
  107          out.move_shared_buffer(a);
 
  109          out.copy_shared_buffer(a);
 
  112          b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
 
  113        if (donate_with_move) {
 
  114          out.move_shared_buffer(b);
 
  116          out.copy_shared_buffer(b);
 
  125struct UseDefaultBinaryOp {};
 
  127template <
typename T, 
typename U, 
typename Op>
 
  128struct DefaultVectorScalar {
 
  131  DefaultVectorScalar(Op op_) : 
op(op_) {}
 
  133  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
  136      *dst = 
op(*a, scalar);
 
  143template <
typename T, 
typename U, 
typename Op>
 
  144struct DefaultScalarVector {
 
  147  DefaultScalarVector(Op op_) : 
op(op_) {}
 
  149  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
  152      *dst = 
op(scalar, *b);
 
  159template <
typename T, 
typename U, 
typename Op>
 
  160struct DefaultVectorVector {
 
  163  DefaultVectorVector(Op op_) : 
op(op_) {}
 
  165  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
  175template <
typename T, 
typename U, 
typename Op, 
int D, 
bool Str
ided>
 
  181    const std::vector<int>& shape,
 
  182    const std::vector<size_t>& a_strides,
 
  183    const std::vector<size_t>& b_strides,
 
  184    const std::vector<size_t>& out_strides,
 
  186  auto stride_a = a_strides[axis];
 
  187  auto stride_b = b_strides[axis];
 
  188  auto stride_out = out_strides[axis];
 
  189  auto N = shape[axis];
 
  191  for (
int i = 0; i < N; i++) {
 
  192    if constexpr (D > 1) {
 
  193      binary_op_dims<T, U, Op, D - 1, Strided>(
 
  194          a, b, out, 
op, shape, a_strides, b_strides, out_strides, axis + 1);
 
  196      if constexpr (Strided) {
 
  197        op(a, b, out, stride_out);
 
  208template <
typename T, 
typename U, 
bool Str
ided, 
typename Op>
 
  209void binary_op_dispatch_dims(
 
  215    const std::vector<int>& shape,
 
  216    const std::vector<size_t>& a_strides,
 
  217    const std::vector<size_t>& b_strides,
 
  218    const std::vector<size_t>& out_strides) {
 
  219  const T* a_ptr = a.data<T>();
 
  220  const T* b_ptr = b.data<T>();
 
  221  U* out_ptr = out.data<U>();
 
  224      binary_op_dims<T, U, Op, 1, Strided>(
 
  236      binary_op_dims<T, U, Op, 2, Strided>(
 
  248      binary_op_dims<T, U, Op, 3, Strided>(
 
  261  ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
 
  262  ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
 
  263  size_t stride = out_strides[dim - 4];
 
  264  for (
size_t elem = 0; elem < a.size(); elem += stride) {
 
  265    binary_op_dims<T, U, Op, 3, Strided>(
 
  295  auto bopt = get_binary_op_type(a, b);
 
  296  set_binary_op_output_data(a, b, out, bopt);
 
  299  if (bopt == BinaryOpType::ScalarScalar) {
 
  300    *(out.data<U>()) = 
op(*a.data<T>(), *b.data<T>());
 
  305  if (bopt == BinaryOpType::ScalarVector) {
 
  306    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
 
  311  if (bopt == BinaryOpType::VectorScalar) {
 
  312    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
 
  317  if (bopt == BinaryOpType::VectorVector) {
 
  318    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
 
  324      a.shape(), {a.strides(), b.strides(), out.strides()});
 
  325  const auto& a_strides = new_strides[0];
 
  326  const auto& b_strides = new_strides[1];
 
  327  const auto& strides = new_strides[2];
 
  330  auto leftmost_rc_dim = [&strides](
const std::vector<size_t>& arr_strides) {
 
  331    int d = arr_strides.size() - 1;
 
  332    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
 
  336  auto a_rc_dim = leftmost_rc_dim(a_strides);
 
  337  auto b_rc_dim = leftmost_rc_dim(b_strides);
 
  340  auto leftmost_s_dim = [](
const std::vector<size_t>& arr_strides) {
 
  341    int d = arr_strides.size() - 1;
 
  342    for (; d >= 0 && arr_strides[d] == 0; d--) {
 
  346  auto a_s_dim = leftmost_s_dim(a_strides);
 
  347  auto b_s_dim = leftmost_s_dim(b_strides);
 
  349  auto ndim = new_shape.size();
 
  353  if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
 
  354    bopt = BinaryOpType::VectorVector;
 
  358  } 
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
 
  359    bopt = BinaryOpType::VectorScalar;
 
  363  } 
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
 
  364    bopt = BinaryOpType::ScalarVector;
 
  371  if (dim == 0 || strides[dim - 1] < 16) {
 
  372    bopt = BinaryOpType::General;
 
  377    case BinaryOpType::VectorVector:
 
  378      binary_op_dispatch_dims<T, U, true>(
 
  379          a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
 
  381    case BinaryOpType::VectorScalar:
 
  382      binary_op_dispatch_dims<T, U, true>(
 
  383          a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
 
  385    case BinaryOpType::ScalarVector:
 
  386      binary_op_dispatch_dims<T, U, true>(
 
  387          a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
 
  390      binary_op_dispatch_dims<T, U, false>(
 
  391          a, b, out, 
op, dim, new_shape, a_strides, b_strides, strides);
 
  396template <
typename T, 
typename Op, 
typename OpSV, 
typename OpVS, 
typename OpVV>
 
  408  if constexpr (std::is_same<
decltype(opsv), UseDefaultBinaryOp>::value) {
 
  409    if constexpr (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
 
  410      if constexpr (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  417            DefaultScalarVector<T, T, Op>(
op),
 
  418            DefaultVectorScalar<T, T, Op>(
op),
 
  419            DefaultVectorVector<T, T, Op>(
op));
 
  427            DefaultScalarVector<T, T, Op>(
op),
 
  428            DefaultVectorScalar<T, T, Op>(
op),
 
  431    } 
else if constexpr (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::
 
  439          DefaultScalarVector<T, T, Op>(
op),
 
  441          DefaultVectorVector<T, T, Op>(
op));
 
  445          a, b, out, 
op, DefaultScalarVector<T, T, Op>(
op), opvs, opvv);
 
  447  } 
else if constexpr (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::
 
  449    if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  457          DefaultVectorScalar<T, T, Op>(
op),
 
  458          DefaultVectorVector<T, T, Op>(
op));
 
  462          a, b, out, 
op, opsv, DefaultVectorScalar<T, T, Op>(
op), opvv);
 
  464  } 
else if constexpr (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::
 
  468        a, b, out, 
op, opsv, opvs, DefaultVectorVector<T, T, Op>(
op));
 
  471    binary_op<T, T>(a, b, out, 
op, opsv, opvs, opvv);
 
  475template <
typename T, 
typename Op>
 
  476void binary_op(
const array& a, 
const array& b, array& out, Op 
op) {
 
  477  DefaultScalarVector<T, T, Op> opsv(
op);
 
  478  DefaultVectorScalar<T, T, Op> opvs(
op);
 
  479  DefaultVectorVector<T, T, Op> opvv(
op);
 
  480  binary_op<T, T>(a, b, out, 
op, opsv, opvs, opvv);
 
  483template <
typename... Ops>
 
  484void binary(
const array& a, 
const array& b, array& out, Ops... ops) {
 
  485  switch (out.dtype()) {
 
  487      binary_op<bool>(a, b, out, ops...);
 
  490      binary_op<uint8_t>(a, b, out, ops...);
 
  493      binary_op<uint16_t>(a, b, out, ops...);
 
  496      binary_op<uint32_t>(a, b, out, ops...);
 
  499      binary_op<uint64_t>(a, b, out, ops...);
 
  502      binary_op<int8_t>(a, b, out, ops...);
 
  505      binary_op<int16_t>(a, b, out, ops...);
 
  508      binary_op<int32_t>(a, b, out, ops...);
 
  511      binary_op<int64_t>(a, b, out, ops...);
 
  514      binary_op<float16_t>(a, b, out, ops...);
 
  517      binary_op<float>(a, b, out, ops...);
 
  520      binary_op<bfloat16_t>(a, b, out, ops...);
 
  523      binary_op<complex64_t>(a, b, out, ops...);
 
Op op
Definition binary.h:129
 
Buffer malloc_or_wait(size_t size)
 
constexpr Dtype bool_
Definition dtype.h:67
 
constexpr Dtype uint64
Definition dtype.h:72
 
constexpr Dtype uint16
Definition dtype.h:70
 
std::tuple< std::vector< int >, std::vector< std::vector< int64_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
 
constexpr Dtype bfloat16
Definition dtype.h:81
 
constexpr Dtype int32
Definition dtype.h:76
 
constexpr Dtype float32
Definition dtype.h:80
 
constexpr Dtype int16
Definition dtype.h:75
 
constexpr Dtype int8
Definition dtype.h:74
 
constexpr Dtype int64
Definition dtype.h:77
 
constexpr Dtype uint8
Definition dtype.h:69
 
constexpr Dtype float16
Definition dtype.h:79
 
constexpr Dtype uint32
Definition dtype.h:71
 
bool is_donatable(const array &in, const array &out)
Definition utils.h:174
 
constexpr Dtype complex64
Definition dtype.h:82