11template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op,
int D>
24 auto stride_a = a_strides[axis];
25 auto stride_b = b_strides[axis];
26 auto stride_c = c_strides[axis];
27 auto stride_out = out_strides[axis];
30 for (
int i = 0; i < N; i++) {
31 if constexpr (D > 1) {
45 *out = op(*a, *b, *c);
54template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
63 std::vector<Strides>& strides) {
64 const auto& a_strides = strides[0];
65 const auto& b_strides = strides[1];
66 const auto& c_strides = strides[2];
67 const auto& out_strides = strides[3];
68 int ndim = shape.size();
103 auto stride = out_strides[ndim - 3];
104 for (
size_t elem = 0; elem < size; elem += stride) {
123template <
typename T1,
typename T2,
typename T3,
typename U,
typename Op>
131 const T1* a_ptr = a.
data<T1>();
132 const T2* b_ptr = b.
data<T2>();
133 const T3* c_ptr = c.
data<T3>();
134 U* out_ptr = out.
data<U>();
137 *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
139 for (
size_t i = 0; i < out.
size(); ++i) {
140 *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
148 a.
shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
150 a_ptr, b_ptr, c_ptr, out_ptr, op, out.
size(), shape, strides);
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:349
void ternary_op(const array &a, const array &b, const array &c, array &out, Op op, TernaryOpType topt)
Definition ternary.h:124
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
void ternary_op_dims(const T1 *a, const T2 *b, const T3 *c, U *out, Op op, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &c_strides, const Strides &out_strides, int axis)
Definition ternary.h:12
void ternary_op_dispatch_dims(const T1 *a_ptr, const T2 *b_ptr, const T3 *c_ptr, U *out_ptr, Op op, size_t size, Shape &shape, std::vector< Strides > &strides)
Definition ternary.h:55
TernaryOpType
Definition ternary.h:11
@ ScalarScalarScalar
Definition ternary.h:12
@ VectorVectorVector
Definition ternary.h:13
int64_t loc
Definition utils.h:126
void step()
Definition utils.h:74