13enum class TernaryOpType {
20get_ternary_op_type(
const array& a,
const array& b,
const array& c) {
22 if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
23 topt = TernaryOpType::ScalarScalarScalar;
25 (a.flags().row_contiguous && b.flags().row_contiguous &&
26 c.flags().row_contiguous) ||
27 (a.flags().col_contiguous && b.flags().col_contiguous &&
28 c.flags().col_contiguous)) {
29 topt = TernaryOpType::VectorVectorVector;
31 topt = TernaryOpType::General;
36void set_ternary_op_output_data(
42 bool donate_with_move =
false) {
43 auto maybe_donate = [&out, donate_with_move](
const array& x) {
45 if (donate_with_move) {
46 out.move_shared_buffer(x);
48 out.copy_shared_buffer(x);
56 case TernaryOpType::ScalarScalarScalar:
60 case TernaryOpType::VectorVectorVector:
61 if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
69 case TernaryOpType::General:
74template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op,
int D>
81 const std::vector<int>& shape,
82 const std::vector<size_t>& a_strides,
83 const std::vector<size_t>& b_strides,
84 const std::vector<size_t>& c_strides,
85 const std::vector<size_t>& out_strides,
87 auto stride_a = a_strides[axis];
88 auto stride_b = b_strides[axis];
89 auto stride_c = c_strides[axis];
90 auto stride_out = out_strides[axis];
93 for (
int i = 0; i < N; i++) {
94 if constexpr (D > 1) {
95 ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
108 *out =
op(*a, *b, *c);
117template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
118void ternary_op_dispatch_dims(
125 a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
126 const auto& a_strides = strides[0];
127 const auto& b_strides = strides[1];
128 const auto& c_strides = strides[2];
129 const auto& out_strides = strides[3];
131 const T1* a_ptr = a.data<T1>();
132 const T2* b_ptr = b.data<T2>();
133 const T3* c_ptr = c.data<T3>();
134 U* out_ptr = out.data<T3>();
135 int ndim = shape.size();
138 ternary_op_dims<T1, T2, T3, U, Op, 1>(
152 ternary_op_dims<T1, T2, T3, U, Op, 2>(
167 ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
168 ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
169 ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
170 size_t stride = out_strides[ndim - 3];
171 for (
size_t elem = 0; elem < a.size(); elem += stride) {
172 ternary_op_dims<T1, T2, T3, U, Op, 2>(
190template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
197 TernaryOpType topt = get_ternary_op_type(a, b, c);
198 set_ternary_op_output_data(a, b, c, out, topt);
201 if (topt == TernaryOpType::ScalarScalarScalar) {
202 *(out.data<U>()) =
op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
203 }
else if (topt == TernaryOpType::VectorVectorVector) {
204 const T1* a_ptr = a.data<T1>();
205 const T2* b_ptr = b.data<T2>();
206 const T3* c_ptr = c.data<T3>();
207 U* out_ptr = out.data<U>();
208 for (
size_t i = 0; i < out.size(); ++i) {
209 *out_ptr =
op(*a_ptr, *b_ptr, *c_ptr);
216 ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out,
op);
Op op
Definition binary.h:129
Buffer malloc_or_wait(size_t size)
std::tuple< std::vector< int >, std::vector< std::vector< int64_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
bool is_donatable(const array &in, const array &out)
Definition utils.h:174