16 template <
typename T,
typename U>
17 void operator()(
const T* a,
const T* b, U* dst,
int size) {
27 *dst = Op{}(*a, scalar);
36 template <
typename T,
typename U>
37 void operator()(
const T* a,
const T* b, U* dst,
int size) {
47 *dst = Op{}(scalar, *b);
56 template <
typename T,
typename U>
57 void operator()(
const T* a,
const T* b, U* dst,
int size) {
75template <
typename T,
typename U,
typename Op,
int D,
bool Str
ided>
85 auto stride_a = a_strides[axis];
86 auto stride_b = b_strides[axis];
87 auto stride_out = out_strides[axis];
90 for (
int i = 0; i < N; i++) {
91 if constexpr (D > 1) {
93 a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
95 if constexpr (Strided) {
96 Op{}(a, b, out, stride_out);
107template <
typename T,
typename U,
bool Str
ided,
typename Op>
121 a, b, out, shape, a_strides, b_strides, out_strides, 0);
125 a, b, out, shape, a_strides, b_strides, out_strides, 0);
129 a, b, out, shape, a_strides, b_strides, out_strides, 0);
135 auto stride = out_strides[dim - 4];
136 for (int64_t elem = 0; elem < size; elem += stride) {
151template <
typename T,
typename U,
typename Op>
154 auto a_ptr = a.
data<T>();
155 auto b_ptr = b.
data<T>();
157 auto out_ptr = out.
data<U>();
159 *out_ptr = Op{}(*a_ptr, *b_ptr);
183 a.
shape(), {a.strides(), b.strides(), out.strides()});
184 auto& a_strides = new_strides[0];
185 auto& b_strides = new_strides[1];
186 auto& strides = new_strides[2];
189 auto leftmost_rc_dim = [&strides](
const auto& arr_strides) {
190 int d = arr_strides.size() - 1;
191 for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
195 auto a_rc_dim = leftmost_rc_dim(a_strides);
196 auto b_rc_dim = leftmost_rc_dim(b_strides);
199 auto leftmost_s_dim = [](
const auto& arr_strides) {
200 int d = arr_strides.size() - 1;
201 for (; d >= 0 && arr_strides[d] == 0; d--) {
205 auto a_s_dim = leftmost_s_dim(a_strides);
206 auto b_s_dim = leftmost_s_dim(b_strides);
208 auto ndim = new_shape.size();
213 if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
218 }
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
223 }
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
231 if (dim == 0 || strides[dim - 1] < 16) {
288template <
typename T,
typename Op>
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:349
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:327
Simd< T, N > load(const T *x)
Definition base_simd.h:28
static constexpr int max_size
Definition base_simd.h:14
void store(T *dst, Simd< T, N > x)
Definition base_simd.h:33
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
BinaryOpType
Definition binary.h:11
@ General
Definition binary.h:16
@ VectorVector
Definition binary.h:15
@ ScalarScalar
Definition binary.h:12
@ VectorScalar
Definition binary.h:14
@ ScalarVector
Definition binary.h:13
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
void binary_op_dispatch_dims(const T *a, const T *b, U *out, int dim, int size, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides)
Definition binary.h:108
void binary_op(const array &a, const array &b, array &out, BinaryOpType bopt)
Definition binary.h:152
void binary_op_dims(const T *a, const T *b, U *out, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides, int axis)
Definition binary.h:76
int64_t loc
Definition utils.h:126
void step()
Definition utils.h:74
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:37
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:17
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:57
Definition accelerate_simd.h:55