13enum class TernaryOpType {
20get_ternary_op_type(
const array& a,
const array& b,
const array& c) {
22 if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
23 topt = TernaryOpType::ScalarScalarScalar;
25 (a.flags().row_contiguous && b.flags().row_contiguous &&
26 c.flags().row_contiguous) ||
27 (a.flags().col_contiguous && b.flags().col_contiguous &&
28 c.flags().col_contiguous)) {
29 topt = TernaryOpType::VectorVectorVector;
31 topt = TernaryOpType::General;
36void set_ternary_op_output_data(
42 bool donate_with_move =
false) {
43 auto maybe_donate = [&out, donate_with_move](
const array& x) {
45 if (donate_with_move) {
46 out.move_shared_buffer(x);
48 out.copy_shared_buffer(x);
56 case TernaryOpType::ScalarScalarScalar:
60 case TernaryOpType::VectorVectorVector:
61 if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
69 case TernaryOpType::General:
75template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
82 const T1* a_ptr = a.data<T1>();
83 const T2* b_ptr = b.data<T2>();
84 const T3* c_ptr = c.data<T3>();
86 U* dst = out.data<U>();
90 for (
size_t i = 0; i < out.size(); ++i) {
91 dst[i] =
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
92 a_idx += a.strides()[0];
93 b_idx += b.strides()[0];
94 c_idx += c.strides()[0];
98template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
105 const T1* a_ptr = a.data<T1>();
106 const T2* b_ptr = b.data<T2>();
107 const T3* c_ptr = c.data<T3>();
109 U* dst = out.data<U>();
114 for (
size_t i = 0; i < a.shape()[0]; ++i) {
115 for (
size_t j = 0; j < a.shape()[1]; ++j) {
116 dst[out_idx++] =
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
117 a_idx += a.strides()[1];
118 b_idx += b.strides()[1];
119 c_idx += c.strides()[1];
121 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
122 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
123 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
127template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
128void ternary_op_dims3(
134 const T1* a_ptr = a.data<T1>();
135 const T2* b_ptr = b.data<T2>();
136 const T3* c_ptr = c.data<T3>();
137 U* dst = out.data<U>();
142 for (
size_t i = 0; i < a.shape()[0]; ++i) {
143 for (
size_t j = 0; j < a.shape()[1]; ++j) {
144 for (
size_t k = 0; k < a.shape()[2]; ++k) {
145 dst[out_idx++] =
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
146 a_idx += a.strides()[2];
147 b_idx += b.strides()[2];
148 c_idx += c.strides()[2];
150 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
151 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
152 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
154 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
155 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
156 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
160template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
161void ternary_op_dims4(
167 const T1* a_ptr = a.data<T1>();
168 const T2* b_ptr = b.data<T2>();
169 const T3* c_ptr = c.data<T3>();
171 U* dst = out.data<U>();
176 for (
size_t i = 0; i < a.shape()[0]; ++i) {
177 for (
size_t j = 0; j < a.shape()[1]; ++j) {
178 for (
size_t k = 0; k < a.shape()[2]; ++k) {
179 for (
size_t ii = 0; ii < a.shape()[3]; ++ii) {
180 dst[out_idx++] =
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
181 a_idx += a.strides()[3];
182 b_idx += b.strides()[3];
183 c_idx += c.strides()[3];
185 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
186 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
187 c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
189 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
190 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
191 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
193 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
194 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
195 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
199template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
200void ternary_op_dispatch_dims(
206 switch (out.ndim()) {
208 ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out,
op);
211 ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out,
op);
214 ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out,
op);
217 ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out,
op);
221 const T1* a_ptr = a.data<T1>();
222 const T2* b_ptr = b.data<T2>();
223 const T3* c_ptr = c.data<T3>();
224 U* dst = out.data<U>();
225 for (
size_t i = 0; i < out.size(); i++) {
226 int a_idx =
elem_to_loc(i, a.shape(), a.strides());
227 int b_idx =
elem_to_loc(i, b.shape(), b.strides());
228 int c_idx =
elem_to_loc(i, c.shape(), c.strides());
229 dst[i] =
op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
233template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
240 TernaryOpType topt = get_ternary_op_type(a, b, c);
241 set_ternary_op_output_data(a, b, c, out, topt);
244 if (topt == TernaryOpType::ScalarScalarScalar) {
245 *(out.data<U>()) =
op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
249 ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out,
op);
Op op
Definition binary.h:141
Buffer malloc_or_wait(size_t size)
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
bool is_donatable(const array &in, const array &out)
Definition utils.h:158