13enum class TernaryOpType {
 
   19get_ternary_op_type(
const array& a, 
const array& b, 
const array& c) {
 
   21  if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
 
   22    topt = TernaryOpType::ScalarScalarScalar;
 
   24    topt = TernaryOpType::General;
 
   29void set_ternary_op_output_data(
 
   35    bool donate_with_move = 
false) {
 
   37    case TernaryOpType::ScalarScalarScalar:
 
   41    case TernaryOpType::General:
 
   47template <
typename T1, 
typename T2, 
typename T3, 
typename U, 
typename Op>
 
   54  const T1* a_ptr = a.data<T1>();
 
   55  const T2* b_ptr = b.data<T2>();
 
   56  const T3* c_ptr = c.data<T3>();
 
   58  U* dst = out.data<U>();
 
   62  for (
size_t i = 0; i < out.size(); ++i) {
 
   63    dst[i] = 
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
 
   64    a_idx += a.strides()[0];
 
   65    b_idx += b.strides()[0];
 
   66    c_idx += c.strides()[0];
 
   70template <
typename T1, 
typename T2, 
typename T3, 
typename U, 
typename Op>
 
   77  const T1* a_ptr = a.data<T1>();
 
   78  const T2* b_ptr = b.data<T2>();
 
   79  const T3* c_ptr = c.data<T3>();
 
   81  U* dst = out.data<U>();
 
   86  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
   87    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
   88      dst[out_idx++] = 
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
 
   89      a_idx += a.strides()[1];
 
   90      b_idx += b.strides()[1];
 
   91      c_idx += c.strides()[1];
 
   93    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
   94    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
   95    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
 
   99template <
typename T1, 
typename T2, 
typename T3, 
typename U, 
typename Op>
 
  100void ternary_op_dims3(
 
  106  const T1* a_ptr = a.data<T1>();
 
  107  const T2* b_ptr = b.data<T2>();
 
  108  const T3* c_ptr = c.data<T3>();
 
  109  U* dst = out.data<U>();
 
  114  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
  115    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
  116      for (
size_t k = 0; k < a.shape()[2]; ++k) {
 
  117        dst[out_idx++] = 
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
 
  118        a_idx += a.strides()[2];
 
  119        b_idx += b.strides()[2];
 
  120        c_idx += c.strides()[2];
 
  122      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
 
  123      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
 
  124      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
 
  126    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
  127    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
  128    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
 
  132template <
typename T1, 
typename T2, 
typename T3, 
typename U, 
typename Op>
 
  133void ternary_op_dims4(
 
  139  const T1* a_ptr = a.data<T1>();
 
  140  const T2* b_ptr = b.data<T2>();
 
  141  const T3* c_ptr = c.data<T3>();
 
  143  U* dst = out.data<U>();
 
  148  for (
size_t i = 0; i < a.shape()[0]; ++i) {
 
  149    for (
size_t j = 0; j < a.shape()[1]; ++j) {
 
  150      for (
size_t k = 0; k < a.shape()[2]; ++k) {
 
  151        for (
size_t ii = 0; ii < a.shape()[3]; ++ii) {
 
  152          dst[out_idx++] = 
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
 
  153          a_idx += a.strides()[3];
 
  154          b_idx += b.strides()[3];
 
  155          c_idx += c.strides()[3];
 
  157        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
 
  158        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
 
  159        c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
 
  161      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
 
  162      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
 
  163      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
 
  165    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
 
  166    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
 
  167    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
 
  171template <
typename T1, 
typename T2, 
typename T3, 
typename U, 
typename Op>
 
  172void ternary_op_dispatch_dims(
 
  178  switch (out.ndim()) {
 
  180      ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, 
op);
 
  183      ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, 
op);
 
  186      ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, 
op);
 
  189      ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, 
op);
 
  193  const T1* a_ptr = a.data<T1>();
 
  194  const T2* b_ptr = b.data<T2>();
 
  195  const T3* c_ptr = c.data<T3>();
 
  196  U* dst = out.data<U>();
 
  197  for (
size_t i = 0; i < out.size(); i++) {
 
  198    int a_idx = 
elem_to_loc(i, a.shape(), a.strides());
 
  199    int b_idx = 
elem_to_loc(i, b.shape(), b.strides());
 
  200    int c_idx = 
elem_to_loc(i, c.shape(), c.strides());
 
  201    dst[i] = 
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
 
  205template <
typename T1, 
typename T2, 
typename T3, 
typename U, 
typename Op>
 
  212  TernaryOpType topt = get_ternary_op_type(a, b, c);
 
  213  set_ternary_op_output_data(a, b, c, out, topt);
 
  216  if (topt == TernaryOpType::ScalarScalarScalar) {
 
  217    *(out.data<U>()) = 
op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
 
  221  ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, 
op);
 
Op op
Definition binary.h:141
 
Buffer malloc_or_wait(size_t size)
 
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12