11template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op,
int D>
24 auto stride_a = a_strides[axis];
25 auto stride_b = b_strides[axis];
26 auto stride_c = c_strides[axis];
27 auto stride_out = out_strides[axis];
30 for (
int i = 0; i < N; i++) {
31 if constexpr (D > 1) {
45 *out = op(*a, *b, *c);
54template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
62 a.
shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
63 const auto& a_strides = strides[0];
64 const auto& b_strides = strides[1];
65 const auto& c_strides = strides[2];
66 const auto& out_strides = strides[3];
68 const T1* a_ptr = a.
data<T1>();
69 const T2* b_ptr = b.
data<T2>();
70 const T3* c_ptr = c.
data<T3>();
71 U* out_ptr = out.
data<T3>();
72 int ndim = shape.size();
107 auto stride = out_strides[ndim - 3];
108 for (
size_t elem = 0; elem < a.
size(); elem += stride) {
127template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
141 const T1* a_ptr = a.
data<T1>();
142 const T2* b_ptr = b.
data<T2>();
143 const T3* c_ptr = c.
data<T3>();
144 U* out_ptr = out.
data<U>();
145 for (
size_t i = 0; i < out.
size(); ++i) {
146 *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:342
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
TernaryOpType get_ternary_op_type(const array &a, const array &b, const array &c)
Definition ternary.h:18
std::vector< ShapeElem > Shape
Definition array.h:21
void set_ternary_op_output_data(const array &a, const array &b, const array &c, array &out, TernaryOpType topt, bool donate_with_move=false)
Definition ternary.h:34
std::vector< int64_t > Strides
Definition array.h:22
void ternary_op_dims(const T1 *a, const T2 *b, const T3 *c, U *out, Op op, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &c_strides, const Strides &out_strides, int axis)
Definition ternary.h:12
void ternary_op(const array &a, const array &b, const array &c, array &out, Op op)
Definition ternary.h:128
void ternary_op_dispatch_dims(const array &a, const array &b, const array &c, array &out, Op op)
Definition ternary.h:55
TernaryOpType
Definition ternary.h:11
@ ScalarScalarScalar
Definition ternary.h:12
@ VectorVectorVector
Definition ternary.h:13
int64_t loc
Definition utils.h:126
void step()
Definition utils.h:74