// Layout category of the three inputs to a ternary op, as selected by
// get_ternary_op_type(). The enumerator list is elided from this listing; the
// values referenced further down in the file are ScalarScalarScalar,
// VectorVectorVector and General.
13enum class TernaryOpType {
 
   // Chooses the TernaryOpType for the inputs a, b and c by inspecting their
   // data sizes and contiguity flags. Listing lines 21, 24, 30 and everything
   // after 31 (including the return) are elided from this view.
   20get_ternary_op_type(
const array& a, 
const array& b, 
const array& c) {
 
   // Case 1: each of the three inputs has a data_size() of exactly 1.
   22  if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
 
   23    topt = TernaryOpType::ScalarScalarScalar;
 
   // Case 2: all three inputs are row-contiguous, or all three are
   // col-contiguous (mixed layouts do not qualify).
   25      (a.flags().row_contiguous && b.flags().row_contiguous &&
 
   26       c.flags().row_contiguous) ||
 
   27      (a.flags().col_contiguous && b.flags().col_contiguous &&
 
   28       c.flags().col_contiguous)) {
 
   29    topt = TernaryOpType::VectorVectorVector;
 
   // Remaining case. NOTE(review): presumably the final `else` branch; the
   // `else` line itself is elided from this listing -- confirm.
   31    topt = TernaryOpType::General;
 
   // Sets up the storage of the output array for a ternary op, branching on
   // the TernaryOpType. Parameters on listing lines 37-41 are elided; the call
   // site later in this file passes (a, b, c, out, topt). donate_with_move is
   // optional and defaults to false.
   36void set_ternary_op_output_data(
 
   42    bool donate_with_move = 
false) {
 
   // Helper that hands x's buffer to `out`. It captures `out` by reference and
   // donate_with_move by value. NOTE(review): the guard on listing line 44 is
   // elided -- presumably an is_donatable(x, out) check; confirm. Its result is
   // used as a boolean below.
   43  auto maybe_donate = [&out, donate_with_move](
const array& x) {
 
   // Move the shared buffer when requested, otherwise copy it.
   45      if (donate_with_move) {
 
   46        out.move_shared_buffer(x);
 
   48        out.copy_shared_buffer(x);
 
   56    case TernaryOpType::ScalarScalarScalar:
 
   60    case TernaryOpType::VectorVectorVector:
 
   // Try a, then b, then c; `||` short-circuits at the first success. The body
   // of this `if` (elided) runs only when no input could be donated.
   // NOTE(review): presumably it allocates a fresh buffer for `out` -- confirm.
   61      if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
 
   69    case TernaryOpType::General:
 
   // Strided element-wise kernel over D nested dimensions (the name
   // ternary_op_dims is taken from the recursive call below; the signature on
   // listing lines 75-80 and 86 is elided). T1/T2/T3 are the element types
   // read through a/b/c, U is the element type written through out, and Op is
   // the callable applied to each triple.
   74template <
typename T1, 
typename T2, 
typename T3, 
typename U, 
typename Op, 
int D>
 
   81    const std::vector<int>& shape,
 
   82    const std::vector<size_t>& a_strides,
 
   83    const std::vector<size_t>& b_strides,
 
   84    const std::vector<size_t>& c_strides,
 
   85    const std::vector<size_t>& out_strides,
 
   // Per-operand stride for the dimension handled at this recursion level.
   87  auto stride_a = a_strides[axis];
 
   88  auto stride_b = b_strides[axis];
 
   89  auto stride_c = c_strides[axis];
 
   90  auto stride_out = out_strides[axis];
 
   // NOTE(review): N is defined on an elided line -- presumably shape[axis];
   // confirm.
   93  for (
int i = 0; i < N; i++) {
 
   // Compile-time branch: with more than one dimension left, recurse with
   // D - 1; otherwise apply op to the current elements.
   94    if constexpr (D > 1) {
 
   95      ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
 
  108      *out = 
op(*a, *b, *c);
 
  // Runs a ternary op over arbitrarily strided inputs by dispatching to
  // ternary_op_dims with a compile-time dimension count. T1/T2/T3 are the
  // element types of a/b/c, U is the element type of out, Op is the callable.
  // Several listing lines (119-124, 136-137, 139-151, 153-166, 173+) are
  // elided from this view.
  117template <
typename T1, 
typename T2, 
typename T3, 
typename U, 
typename Op>
 
  118void ternary_op_dispatch_dims(
 
  // NOTE(review): the start of this call is elided -- presumably
  // collapse_contiguous_dims(...) yielding `shape` and `strides`; confirm.
  // The stride lists are passed in the order a, b, c, out.
  125      a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
 
  // Same order as passed above: 0 = a, 1 = b, 2 = c, 3 = out.
  126  const auto& a_strides = strides[0];
 
  127  const auto& b_strides = strides[1];
 
  128  const auto& c_strides = strides[2];
 
  129  const auto& out_strides = strides[3];
 
  131  const T1* a_ptr = a.data<T1>();
 
  132  const T2* b_ptr = b.data<T2>();
 
  133  const T3* c_ptr = c.data<T3>();
 
  // The output holds elements of type U, so the buffer is requested as U.
  // (Requesting it as T3 only type-checks when U and T3 are the same type.)
  134  U* out_ptr = out.data<U>();
 
  135  int ndim = shape.size();
 
  // Low ranks go straight to the fixed-D kernels (D = 1, then D = 2); the
  // conditions selecting between them are elided from this listing.
  138      ternary_op_dims<T1, T2, T3, U, Op, 1>(
 
  152      ternary_op_dims<T1, T2, T3, U, Op, 2>(
 
  // Remaining path: one iterator per input, each constructed with ndim - 2,
  // while the D = 2 kernel is invoked once per loop step below.
  167  ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
 
  168  ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
 
  169  ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
 
  // Output step per kernel invocation. Indexing with ndim - 3 requires
  // ndim >= 3. NOTE(review): presumably guaranteed by the elided dispatch
  // above -- confirm.
  170  size_t stride = out_strides[ndim - 3];
 
  171  for (
size_t elem = 0; elem < a.size(); elem += stride) {
 
  172    ternary_op_dims<T1, T2, T3, U, Op, 2>(
 
  // Entry point for a ternary op: classifies the inputs, prepares the output,
  // then runs one of three code paths. The function name and parameter list
  // (listing lines 191-196) are elided from this view; the body uses a, b, c,
  // out and op.
  190template <
typename T1, 
typename T2, 
typename T3, 
typename U, 
typename Op>
 
  197  TernaryOpType topt = get_ternary_op_type(a, b, c);
 
  198  set_ternary_op_output_data(a, b, c, out, topt);
 
  // Single-element case: one application of op on the first element of each
  // input, written to the first element of out.
  201  if (topt == TernaryOpType::ScalarScalarScalar) {
 
  202    *(out.data<U>()) = 
op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
 
  203  } 
else if (topt == TernaryOpType::VectorVectorVector) {
 
  // All-contiguous case: walk the four buffers with raw pointers.
  204    const T1* a_ptr = a.data<T1>();
 
  205    const T2* b_ptr = b.data<T2>();
 
  206    const T3* c_ptr = c.data<T3>();
 
  207    U* out_ptr = out.data<U>();
 
  // One op application per output element. NOTE(review): the pointer
  // increments (listing lines 210-213) are elided -- confirm each pointer
  // advances by one per iteration.
  208    for (
size_t i = 0; i < out.size(); ++i) {
 
  209      *out_ptr = 
op(*a_ptr, *b_ptr, *c_ptr);
 
  // Fallback for every other layout: the strided dispatcher.
  216    ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, 
op);
 
Op op
Definition binary.h:129
 
Buffer malloc_or_wait(size_t size)
 
std::tuple< std::vector< int >, std::vector< std::vector< int64_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
 
bool is_donatable(const array &in, const array &out)
Definition utils.h:174