14enum class BinaryOpType {
 
   22BinaryOpType get_binary_op_type(
const array& a, 
const array& b) {
 
   24  if (a.data_size() == 1 && b.data_size() == 1) {
 
   25    bopt = BinaryOpType::ScalarScalar;
 
   26  } 
else if (a.data_size() == 1 && b.flags().contiguous) {
 
   27    bopt = BinaryOpType::ScalarVector;
 
   28  } 
else if (b.data_size() == 1 && a.flags().contiguous) {
 
   29    bopt = BinaryOpType::VectorScalar;
 
   31      a.flags().row_contiguous && b.flags().row_contiguous ||
 
   32      a.flags().col_contiguous && b.flags().col_contiguous) {
 
   33    bopt = BinaryOpType::VectorVector;
 
   35    bopt = BinaryOpType::General;
 
   40void set_binary_op_output_data(
 
   45    bool donate_with_move = 
false) {
 
   47    case BinaryOpType::ScalarScalar:
 
   51    case BinaryOpType::ScalarVector:
 
   52      if (b.is_donatable() && b.itemsize() == out.itemsize()) {
 
   53        if (donate_with_move) {
 
   54          out.move_shared_buffer(b);
 
   56          out.copy_shared_buffer(b);
 
   66    case BinaryOpType::VectorScalar:
 
   67      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
 
   68        if (donate_with_move) {
 
   69          out.move_shared_buffer(a);
 
   71          out.copy_shared_buffer(a);
 
   81    case BinaryOpType::VectorVector:
 
   82      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
 
   83        if (donate_with_move) {
 
   84          out.move_shared_buffer(a);
 
   86          out.copy_shared_buffer(a);
 
   88      } 
else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
 
   89        if (donate_with_move) {
 
   90          out.move_shared_buffer(b);
 
   92          out.copy_shared_buffer(b);
 
  102    case BinaryOpType::General:
 
  103      if (a.is_donatable() && a.flags().row_contiguous &&
 
  104          a.itemsize() == out.itemsize() && a.size() == out.size()) {
 
  105        if (donate_with_move) {
 
  106          out.move_shared_buffer(a);
 
  108          out.copy_shared_buffer(a);
 
  111          b.is_donatable() && b.flags().row_contiguous &&
 
  112          b.itemsize() == out.itemsize() && b.size() == out.size()) {
 
  113        if (donate_with_move) {
 
  114          out.move_shared_buffer(b);
 
  116          out.copy_shared_buffer(b);
 
  125struct UseDefaultBinaryOp {
 
  126  template <
typename T, 
typename U>
 
  127  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
  132  template <
typename T, 
typename U>
 
  133  void operator()(
const T* a, 
const T* b, U* dst_a, U* dst_b, 
int size) {
 
  139template <
typename T, 
typename U, 
typename Op>
 
  140struct DefaultVectorScalar {
 
  143  DefaultVectorScalar(Op op_) : 
op(op_) {}
 
  145  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
  148      *dst = 
op(*a, scalar);
 
  154  void operator()(
const T* a, 
const T* b, U* dst_a, U* dst_b, 
int size) {
 
  157      auto dst = 
op(*a, scalar);
 
  167template <
typename T, 
typename U, 
typename Op>
 
  168struct DefaultScalarVector {
 
  171  DefaultScalarVector(Op op_) : 
op(op_) {}
 
  173  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
  176      *dst = 
op(scalar, *b);
 
  182  void operator()(
const T* a, 
const T* b, U* dst_a, U* dst_b, 
int size) {
 
  185      auto dst = 
op(scalar, *b);
 
  195template <
typename T, 
typename U, 
typename Op>
 
  196struct DefaultVectorVector {
 
  199  DefaultVectorVector(Op op_) : 
op(op_) {}
 
  201  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
  210  void operator()(
const T* a, 
const T* b, U* dst_a, U* dst_b, 
int size) {
 
  212      auto dst = 
op(*a, *b);
 
  223template <
typename T, 
typename U, 
typename Op>
 
  224void binary_op_dims1(
const array& a, 
const array& b, array& out, Op 
op) {
 
  225  const T* a_ptr = a.data<T>();
 
  226  const T* b_ptr = b.data<T>();
 
  227  U* dst = out.data<U>();
 
  230  for (
size_t i = 0; i < out.size(); ++i) {
 
  231    dst[i] = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
  232    a_idx += a.strides()[0];
 
  233    b_idx += b.strides()[0];
 
  237template <
typename T, 
typename U, 
typename Op>
 
  244  const T* a_ptr = a.data<T>();
 
  245  const T* b_ptr = b.data<T>();
 
  246  U* dst = out.data<U>();
 
  249  for (
size_t i = 0; i < a.shape()[0]; i++) {
 
  250    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
 
  251    a_idx += a.strides()[0];
 
  252    b_idx += b.strides()[0];
 
  257template <
typename T, 
typename U, 
typename Op>
 
  258void binary_op_dims2(
const array& a, 
const array& b, array& out, Op 
op) {
 
  259  const T* a_ptr = a.data<T>();
 
  260  const T* b_ptr = b.data<T>();
 
  261  U* dst = out.data<U>();
 
  265  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
  266    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
  267      dst[out_idx++] = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
  268      a_idx += a.strides()[1];
 
  269      b_idx += b.strides()[1];
 
  271    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
  272    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
  276template <
typename T, 
typename U, 
typename Op>
 
  283  const T* a_ptr = a.data<T>();
 
  284  const T* b_ptr = b.data<T>();
 
  285  U* dst = out.data<U>();
 
  288  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
  289    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
  290      op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
 
  291      a_idx += a.strides()[1];
 
  292      b_idx += b.strides()[1];
 
  295    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
  296    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
  300template <
typename T, 
typename U, 
typename Op>
 
  301void binary_op_dims3(
const array& a, 
const array& b, array& out, Op 
op) {
 
  302  const T* a_ptr = a.data<T>();
 
  303  const T* b_ptr = b.data<T>();
 
  304  U* dst = out.data<U>();
 
  308  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
  309    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
  310      for (
size_t k = 0; k < a.shape()[2]; ++k) {
 
  311        dst[out_idx++] = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
  312        a_idx += a.strides()[2];
 
  313        b_idx += b.strides()[2];
 
  315      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
 
  316      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
 
  318    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
  319    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
  323template <
typename T, 
typename U, 
typename Op>
 
  324void binary_op_dims4(
const array& a, 
const array& b, array& out, Op 
op) {
 
  325  const T* a_ptr = a.data<T>();
 
  326  const T* b_ptr = b.data<T>();
 
  327  U* dst = out.data<U>();
 
  331  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
  332    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
  333      for (
size_t k = 0; k < a.shape()[2]; ++k) {
 
  334        for (
size_t ii = 0; ii < a.shape()[3]; ++ii) {
 
  335          dst[out_idx++] = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
  336          a_idx += a.strides()[3];
 
  337          b_idx += b.strides()[3];
 
  339        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
 
  340        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
 
  342      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
 
  343      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
 
  345    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
  346    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
  350template <
typename T, 
typename U, 
typename Op>
 
  351void binary_op_dispatch_dims(
 
  356  switch (out.ndim()) {
 
  358      binary_op_dims1<T, U, Op>(a, b, out, 
op);
 
  361      binary_op_dims2<T, U, Op>(a, b, out, 
op);
 
  364      binary_op_dims3<T, U, Op>(a, b, out, 
op);
 
  367      binary_op_dims4<T, U, Op>(a, b, out, 
op);
 
  371  const T* a_ptr = a.data<T>();
 
  372  const T* b_ptr = b.data<T>();
 
  373  U* dst = out.data<U>();
 
  374  for (
size_t i = 0; i < out.size(); i++) {
 
  375    int a_idx = 
elem_to_loc(i, a.shape(), a.strides());
 
  376    int b_idx = 
elem_to_loc(i, b.shape(), b.strides());
 
  377    dst[i] = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
  381template <
typename T, 
typename U, 
typename Op>
 
  382void binary_op_dispatch_dims(
 
  392      binary_op_dims1<T, U, Op>(a, b, out, 
op, stride);
 
  395      binary_op_dims2<T, U, Op>(a, b, out, 
op, stride);
 
  399  const T* a_ptr = a.data<T>();
 
  400  const T* b_ptr = b.data<T>();
 
  401  U* dst = out.data<U>();
 
  402  for (
size_t i = 0; i < out.size(); i += stride) {
 
  403    int a_idx = 
elem_to_loc(i, a.shape(), a.strides());
 
  404    int b_idx = 
elem_to_loc(i, b.shape(), b.strides());
 
  405    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
 
  425  auto bopt = get_binary_op_type(a, b);
 
  426  set_binary_op_output_data(a, b, out, bopt);
 
  429  if (bopt == BinaryOpType::ScalarScalar) {
 
  430    *(out.data<U>()) = 
op(*a.data<T>(), *b.data<T>());
 
  435  if (bopt == BinaryOpType::ScalarVector) {
 
  436    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
 
  441  if (bopt == BinaryOpType::VectorScalar) {
 
  442    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
 
  447  if (bopt == BinaryOpType::VectorVector) {
 
  448    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
 
  455  auto& strides = out.strides();
 
  456  auto leftmost_rc_dim = [&strides](
const array& arr) {
 
  457    int d = arr.ndim() - 1;
 
  458    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
 
  462  auto a_rc_dim = leftmost_rc_dim(a);
 
  463  auto b_rc_dim = leftmost_rc_dim(b);
 
  466  auto leftmost_s_dim = [](
const array& arr) {
 
  467    int d = arr.ndim() - 1;
 
  468    for (; d >= 0 && arr.strides()[d] == 0; d--) {
 
  472  auto a_s_dim = leftmost_s_dim(a);
 
  473  auto b_s_dim = leftmost_s_dim(b);
 
  475  auto ndim = out.ndim();
 
  479  if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
 
  480    bopt = BinaryOpType::VectorVector;
 
  484  } 
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
 
  485    bopt = BinaryOpType::VectorScalar;
 
  489  } 
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
 
  490    bopt = BinaryOpType::ScalarVector;
 
  498  if (dim == 0 || strides[dim - 1] < 16) {
 
  500    bopt = BinaryOpType::General;
 
  503    stride = strides[dim - 1];
 
  507    case BinaryOpType::VectorVector:
 
  508      binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
 
  510    case BinaryOpType::VectorScalar:
 
  511      binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
 
  513    case BinaryOpType::ScalarVector:
 
  514      binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
 
  517      binary_op_dispatch_dims<T, U>(a, b, out, 
op);
 
  522template <
typename T, 
typename Op, 
typename OpSV, 
typename OpVS, 
typename OpVV>
 
  534  if (std::is_same<
decltype(opsv), UseDefaultBinaryOp>::value) {
 
  535    if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
 
  536      if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  543            DefaultScalarVector<T, T, Op>(
op),
 
  544            DefaultVectorScalar<T, T, Op>(
op),
 
  545            DefaultVectorVector<T, T, Op>(
op));
 
  553            DefaultScalarVector<T, T, Op>(
op),
 
  554            DefaultVectorScalar<T, T, Op>(
op),
 
  557    } 
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  564          DefaultScalarVector<T, T, Op>(
op),
 
  566          DefaultVectorVector<T, T, Op>(
op));
 
  570          a, b, out, 
op, DefaultScalarVector<T, T, Op>(
op), opvs, opvv);
 
  572  } 
else if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
 
  573    if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  581          DefaultVectorScalar<T, T, Op>(
op),
 
  582          DefaultVectorVector<T, T, Op>(
op));
 
  586          a, b, out, 
op, opsv, DefaultVectorScalar<T, T, Op>(
op), opvv);
 
  588  } 
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  591        a, b, out, 
op, opsv, opvs, DefaultVectorVector<T, T, Op>(
op));
 
  594    binary_op<T, T>(a, b, out, 
op, opsv, opvs, opvv);
 
  598template <
typename T, 
typename Op>
 
  599void binary_op(
const array& a, 
const array& b, array& out, Op 
op) {
 
  600  DefaultScalarVector<T, T, Op> opsv(
op);
 
  601  DefaultVectorScalar<T, T, Op> opvs(
op);
 
  602  DefaultVectorVector<T, T, Op> opvv(
op);
 
  603  binary_op<T, T>(a, b, out, 
op, opsv, opvs, opvv);
 
  606template <
typename... Ops>
 
  607void binary(
const array& a, 
const array& b, array& out, Ops... ops) {
 
  608  switch (out.dtype()) {
 
  610      binary_op<bool>(a, b, out, ops...);
 
  613      binary_op<uint8_t>(a, b, out, ops...);
 
  616      binary_op<uint16_t>(a, b, out, ops...);
 
  619      binary_op<uint32_t>(a, b, out, ops...);
 
  622      binary_op<uint64_t>(a, b, out, ops...);
 
  625      binary_op<int8_t>(a, b, out, ops...);
 
  628      binary_op<int16_t>(a, b, out, ops...);
 
  631      binary_op<int32_t>(a, b, out, ops...);
 
  634      binary_op<int64_t>(a, b, out, ops...);
 
  637      binary_op<float16_t>(a, b, out, ops...);
 
  640      binary_op<float>(a, b, out, ops...);
 
  643      binary_op<bfloat16_t>(a, b, out, ops...);
 
  646      binary_op<complex64_t>(a, b, out, ops...);
 
Op op
Definition binary.h:141
 
Buffer malloc_or_wait(size_t size)
 
constexpr Dtype bool_
Definition dtype.h:58
 
constexpr Dtype uint64
Definition dtype.h:63
 
constexpr Dtype uint16
Definition dtype.h:61
 
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
 
constexpr Dtype bfloat16
Definition dtype.h:72
 
constexpr Dtype int32
Definition dtype.h:67
 
constexpr Dtype float32
Definition dtype.h:71
 
constexpr Dtype int16
Definition dtype.h:66
 
constexpr Dtype int8
Definition dtype.h:65
 
constexpr Dtype int64
Definition dtype.h:68
 
constexpr Dtype uint8
Definition dtype.h:60
 
constexpr Dtype float16
Definition dtype.h:70
 
constexpr Dtype uint32
Definition dtype.h:62
 
constexpr Dtype complex64
Definition dtype.h:73