21 template <
typename T,
typename U>
22 void operator()(
const T* a,
const T* b, U* dst,
int size) {
32 *dst =
op(*a, scalar);
45 template <
typename T,
typename U>
46 void operator()(
const T* a,
const T* b, U* dst,
int size) {
56 *dst =
op(scalar, *b);
69 template <
typename T,
typename U>
70 void operator()(
const T* a,
const T* b, U* dst,
int size) {
88template <
typename T,
typename U,
typename Op,
int D,
bool Str
ided>
99 auto stride_a = a_strides[axis];
100 auto stride_b = b_strides[axis];
101 auto stride_out = out_strides[axis];
102 auto N = shape[axis];
104 for (
int i = 0; i < N; i++) {
105 if constexpr (D > 1) {
107 a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
109 if constexpr (Strided) {
110 op(a, b, out, stride_out);
121template <
typename T,
typename U,
bool Str
ided,
typename Op>
132 const T* a_ptr = a.
data<T>();
133 const T* b_ptr = b.
data<T>();
134 U* out_ptr = out.
data<U>();
176 auto stride = out_strides[dim - 4];
177 for (int64_t elem = 0; elem < a.
size(); elem += stride) {
193template <
typename T,
typename U,
typename Op>
224 a.
shape(), {a.strides(), b.strides(), out.strides()});
225 const auto& a_strides = new_strides[0];
226 const auto& b_strides = new_strides[1];
227 const auto& strides = new_strides[2];
230 auto leftmost_rc_dim = [&strides](
const auto& arr_strides) {
231 int d = arr_strides.size() - 1;
232 for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
236 auto a_rc_dim = leftmost_rc_dim(a_strides);
237 auto b_rc_dim = leftmost_rc_dim(b_strides);
240 auto leftmost_s_dim = [](
const auto& arr_strides) {
241 int d = arr_strides.size() - 1;
242 for (; d >= 0 && arr_strides[d] == 0; d--) {
246 auto a_s_dim = leftmost_s_dim(a_strides);
247 auto b_s_dim = leftmost_s_dim(b_strides);
249 auto ndim = new_shape.size();
253 if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
258 }
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
263 }
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
271 if (dim == 0 || strides[dim - 1] < 16) {
315 a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
320template <
typename T,
typename Op>
325template <
typename Op>
327 switch (out.
dtype()) {
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:354
Dtype dtype() const
Get the arrays data type.
Definition array.h:131
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:332
Simd< T, N > load(const T *x)
Definition base_simd.h:28
static constexpr int max_size
Definition base_simd.h:14
void store(T *dst, Simd< T, N > x)
Definition base_simd.h:33
constexpr Dtype bool_
Definition dtype.h:68
constexpr Dtype uint64
Definition dtype.h:73
BinaryOpType get_binary_op_type(const array &a, const array &b)
Definition binary.h:19
constexpr Dtype uint16
Definition dtype.h:71
constexpr Dtype float64
Definition dtype.h:82
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
constexpr Dtype bfloat16
Definition dtype.h:83
@ General
Definition binary.h:16
@ VectorVector
Definition binary.h:15
@ ScalarScalar
Definition binary.h:12
@ VectorScalar
Definition binary.h:14
@ ScalarVector
Definition binary.h:13
constexpr Dtype int32
Definition dtype.h:77
void binary_op_dispatch_dims(const array &a, const array &b, array &out, Op op, int dim, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides)
Definition binary.h:122
constexpr Dtype float32
Definition dtype.h:81
std::vector< ShapeElem > Shape
Definition array.h:21
void set_binary_op_output_data(const array &a, const array &b, array &out, BinaryOpType bopt, bool donate_with_move=false)
Definition binary.h:37
constexpr Dtype int16
Definition dtype.h:76
std::vector< int64_t > Strides
Definition array.h:22
void binary_op_dims(const T *a, const T *b, U *out, Op op, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides, int axis)
Definition binary.h:89
constexpr Dtype int8
Definition dtype.h:75
constexpr Dtype int64
Definition dtype.h:78
constexpr Dtype uint8
Definition dtype.h:70
void binary_op(const array &a, const array &b, array &out, Op op)
Definition binary.h:194
constexpr Dtype float16
Definition dtype.h:80
constexpr Dtype uint32
Definition dtype.h:72
void binary(const array &a, const array &b, array &out, Op op)
Definition binary.h:326
constexpr Dtype complex64
Definition dtype.h:84
int64_t loc
Definition utils.h:126
void step()
Definition utils.h:74
ScalarVector(Op op_)
Definition binary.h:43
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:46
Op op
Definition binary.h:41
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:22
Op op
Definition binary.h:17
VectorScalar(Op op_)
Definition binary.h:19
VectorVector(Op op_)
Definition binary.h:67
Op op
Definition binary.h:65
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:70
Definition accelerate_simd.h:55