12template <
typename T, 
typename U, 
typename Op>
 
   19  const T* a_ptr = a.data<T>();
 
   20  const T* b_ptr = b.data<T>();
 
   21  U* dst_a = out_a.data<U>();
 
   22  U* dst_b = out_b.data<U>();
 
   25  for (
size_t i = 0; i < out_a.size(); ++i) {
 
   26    auto dst = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
   28    dst_b[i] = dst.second;
 
   29    a_idx += a.strides()[0];
 
   30    b_idx += b.strides()[0];
 
   34template <
typename T, 
typename U, 
typename Op>
 
   42  const T* a_ptr = a.data<T>();
 
   43  const T* b_ptr = b.data<T>();
 
   44  U* dst_a = out_a.data<U>();
 
   45  U* dst_b = out_b.data<U>();
 
   48  for (
size_t i = 0; i < a.shape()[0]; i++) {
 
   49    op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
 
   50    a_idx += a.strides()[0];
 
   51    b_idx += b.strides()[0];
 
   57template <
typename T, 
typename U, 
typename Op>
 
   64  const T* a_ptr = a.data<T>();
 
   65  const T* b_ptr = b.data<T>();
 
   66  U* dst_a = out_a.data<U>();
 
   67  U* dst_b = out_b.data<U>();
 
   71  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
   72    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
   73      auto dst = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
   74      dst_a[out_idx] = dst.first;
 
   75      dst_b[out_idx++] = dst.second;
 
   76      a_idx += a.strides()[1];
 
   77      b_idx += b.strides()[1];
 
   79    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
   80    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
   84template <
typename T, 
typename U, 
typename Op>
 
   92  const T* a_ptr = a.data<T>();
 
   93  const T* b_ptr = b.data<T>();
 
   94  U* dst_a = out_a.data<U>();
 
   95  U* dst_b = out_b.data<U>();
 
   98  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
   99    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
  100      op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
 
  101      a_idx += a.strides()[1];
 
  102      b_idx += b.strides()[1];
 
  106    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
  107    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
  111template <
typename T, 
typename U, 
typename Op>
 
  118  const T* a_ptr = a.data<T>();
 
  119  const T* b_ptr = b.data<T>();
 
  120  U* dst_a = out_a.data<U>();
 
  121  U* dst_b = out_b.data<U>();
 
  125  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
  126    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
  127      for (
size_t k = 0; k < a.shape()[2]; ++k) {
 
  128        auto dst = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
  129        dst_a[out_idx] = dst.first;
 
  130        dst_b[out_idx++] = dst.second;
 
  131        a_idx += a.strides()[2];
 
  132        b_idx += b.strides()[2];
 
  134      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
 
  135      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
 
  137    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
  138    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
  142template <
typename T, 
typename U, 
typename Op>
 
  149  const T* a_ptr = a.data<T>();
 
  150  const T* b_ptr = b.data<T>();
 
  151  U* dst_a = out_a.data<U>();
 
  152  U* dst_b = out_b.data<U>();
 
  156  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
  157    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
  158      for (
size_t k = 0; k < a.shape()[2]; ++k) {
 
  159        for (
size_t ii = 0; ii < a.shape()[3]; ++ii) {
 
  160          auto dst = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
  161          dst_a[out_idx] = dst.first;
 
  162          dst_b[out_idx++] = dst.second;
 
  163          a_idx += a.strides()[3];
 
  164          b_idx += b.strides()[3];
 
  166        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
 
  167        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
 
  169      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
 
  170      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
 
  172    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
  173    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
  177template <
typename T, 
typename U, 
typename Op>
 
  178void binary_op_dispatch_dims(
 
  184  switch (out_a.ndim()) {
 
  186      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, 
op);
 
  189      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, 
op);
 
  192      binary_op_dims3<T, U, Op>(a, b, out_a, out_b, 
op);
 
  195      binary_op_dims4<T, U, Op>(a, b, out_a, out_b, 
op);
 
  199  const T* a_ptr = a.data<T>();
 
  200  const T* b_ptr = b.data<T>();
 
  201  U* dst_a = out_a.data<U>();
 
  202  U* dst_b = out_b.data<U>();
 
  203  for (
size_t i = 0; i < out_a.size(); i++) {
 
  204    int a_idx = 
elem_to_loc(i, a.shape(), a.strides());
 
  205    int b_idx = 
elem_to_loc(i, b.shape(), b.strides());
 
  206    std::tie(dst_a[i], dst_b[i]) = 
op(a_ptr[a_idx], b_ptr[b_idx]);
 
  210template <
typename T, 
typename U, 
typename Op>
 
  211void binary_op_dispatch_dims(
 
  222      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, 
op, stride);
 
  225      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, 
op, stride);
 
  229  const T* a_ptr = a.data<T>();
 
  230  const T* b_ptr = b.data<T>();
 
  231  U* dst_a = out_a.data<U>();
 
  232  U* dst_b = out_b.data<U>();
 
  233  for (
size_t i = 0; i < out_a.size(); i += stride) {
 
  234    int a_idx = 
elem_to_loc(i, a.shape(), a.strides());
 
  235    int b_idx = 
elem_to_loc(i, b.shape(), b.strides());
 
  236    op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
 
  258  auto bopt = get_binary_op_type(a, b);
 
  259  set_binary_op_output_data(a, b, out_a, bopt);
 
  260  set_binary_op_output_data(a, b, out_b, bopt);
 
  263  if (bopt == BinaryOpType::ScalarScalar) {
 
  264    std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
 
  265        op(*a.data<T>(), *b.data<T>());
 
  270  if (bopt == BinaryOpType::ScalarVector) {
 
  281  if (bopt == BinaryOpType::VectorScalar) {
 
  292  if (bopt == BinaryOpType::VectorVector) {
 
  305  auto& strides = out_a.strides();
 
  306  auto leftmost_rc_dim = [&strides](
const array& arr) {
 
  307    int d = arr.ndim() - 1;
 
  308    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
 
  312  auto a_rc_dim = leftmost_rc_dim(a);
 
  313  auto b_rc_dim = leftmost_rc_dim(b);
 
  316  auto leftmost_s_dim = [](
const array& arr) {
 
  317    int d = arr.ndim() - 1;
 
  318    for (; d >= 0 && arr.strides()[d] == 0; d--) {
 
  322  auto a_s_dim = leftmost_s_dim(a);
 
  323  auto b_s_dim = leftmost_s_dim(b);
 
  325  auto ndim = out_a.ndim();
 
  329  if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
 
  330    bopt = BinaryOpType::VectorVector;
 
  334  } 
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
 
  335    bopt = BinaryOpType::VectorScalar;
 
  339  } 
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
 
  340    bopt = BinaryOpType::ScalarVector;
 
  348  if (dim == 0 || strides[dim - 1] < 16) {
 
  350    bopt = BinaryOpType::General;
 
  353    stride = strides[dim - 1];
 
  357    case BinaryOpType::VectorVector:
 
  358      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
 
  360    case BinaryOpType::VectorScalar:
 
  361      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
 
  363    case BinaryOpType::ScalarVector:
 
  364      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
 
  367      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, 
op);
 
  372template <
typename T, 
typename Op, 
typename OpSV, 
typename OpVS, 
typename OpVV>
 
  376    std::vector<array>& outputs,
 
  384  if (std::is_same<
decltype(opsv), UseDefaultBinaryOp>::value) {
 
  385    if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
 
  386      if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  394            DefaultScalarVector<T, T, Op>(
op),
 
  395            DefaultVectorScalar<T, T, Op>(
op),
 
  396            DefaultVectorVector<T, T, Op>(
op));
 
  405            DefaultScalarVector<T, T, Op>(
op),
 
  406            DefaultVectorScalar<T, T, Op>(
op),
 
  409    } 
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  417          DefaultScalarVector<T, T, Op>(
op),
 
  419          DefaultVectorVector<T, T, Op>(
op));
 
  428          DefaultScalarVector<T, T, Op>(
op),
 
  432  } 
else if (std::is_same<
decltype(opvs), UseDefaultBinaryOp>::value) {
 
  433    if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  442          DefaultVectorScalar<T, T, Op>(
op),
 
  443          DefaultVectorVector<T, T, Op>(
op));
 
  453          DefaultVectorScalar<T, T, Op>(
op),
 
  456  } 
else if (std::is_same<
decltype(opvv), UseDefaultBinaryOp>::value) {
 
  466        DefaultVectorVector<T, T, Op>(
op));
 
  469    binary_op<T, T>(a, b, outputs[0], outputs[1], 
op, opsv, opvs, opvv);
 
  473template <
typename T, 
typename Op>
 
  477    std::vector<array>& outputs,
 
  479  DefaultScalarVector<T, T, Op> opsv(
op);
 
  480  DefaultVectorScalar<T, T, Op> opvs(
op);
 
  481  DefaultVectorVector<T, T, Op> opvv(
op);
 
  482  binary_op<T, T>(a, b, outputs[0], outputs[1], 
op, opsv, opvs, opvv);
 
  485template <
typename... Ops>
 
  489    std::vector<array>& outputs,
 
  491  switch (outputs[0].dtype()) {
 
  493      binary_op<bool>(a, b, outputs, ops...);
 
  496      binary_op<uint8_t>(a, b, outputs, ops...);
 
  499      binary_op<uint16_t>(a, b, outputs, ops...);
 
  502      binary_op<uint32_t>(a, b, outputs, ops...);
 
  505      binary_op<uint64_t>(a, b, outputs, ops...);
 
  508      binary_op<int8_t>(a, b, outputs, ops...);
 
  511      binary_op<int16_t>(a, b, outputs, ops...);
 
  514      binary_op<int32_t>(a, b, outputs, ops...);
 
  517      binary_op<int64_t>(a, b, outputs, ops...);
 
  520      binary_op<float16_t>(a, b, outputs, ops...);
 
  523      binary_op<float>(a, b, outputs, ops...);
 
  526      binary_op<bfloat16_t>(a, b, outputs, ops...);
 
  529      binary_op<complex64_t>(a, b, outputs, ops...);
 
Op op
Definition binary.h:141
 
constexpr Dtype bool_
Definition dtype.h:58
 
constexpr Dtype uint64
Definition dtype.h:63
 
constexpr Dtype uint16
Definition dtype.h:61
 
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
 
constexpr Dtype bfloat16
Definition dtype.h:72
 
constexpr Dtype int32
Definition dtype.h:67
 
constexpr Dtype float32
Definition dtype.h:71
 
constexpr Dtype int16
Definition dtype.h:66
 
constexpr Dtype int8
Definition dtype.h:65
 
constexpr Dtype int64
Definition dtype.h:68
 
constexpr Dtype uint8
Definition dtype.h:60
 
constexpr Dtype float16
Definition dtype.h:70
 
constexpr Dtype uint32
Definition dtype.h:62
 
constexpr Dtype complex64
Definition dtype.h:73