21  template <
typename T, 
typename U>
 
   22  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
   32      *dst = 
op(*a, scalar);
 
 
 
   45  template <
typename T, 
typename U>
 
   46  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
   56      *dst = 
op(scalar, *b);
 
 
 
   69  template <
typename T, 
typename U>
 
   70  void operator()(
const T* a, 
const T* b, U* dst, 
int size) {
 
 
 
   88template <
typename T, 
typename U, 
typename Op, 
int D, 
bool Str
ided>
 
   99  auto stride_a = a_strides[axis];
 
  100  auto stride_b = b_strides[axis];
 
  101  auto stride_out = out_strides[axis];
 
  102  auto N = shape[axis];
 
  104  for (
int i = 0; i < N; i++) {
 
  105    if constexpr (D > 1) {
 
  107          a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
 
  109      if constexpr (Strided) {
 
  110        op(a, b, out, stride_out);
 
 
  121template <
typename T, 
typename U, 
bool Str
ided, 
typename Op>
 
  132  const T* a_ptr = a.
data<T>();
 
  133  const T* b_ptr = b.
data<T>();
 
  134  U* out_ptr = out.
data<U>();
 
  176  auto stride = out_strides[dim - 4];
 
  177  for (int64_t elem = 0; elem < a.
size(); elem += stride) {
 
 
  193template <
typename T, 
typename U, 
typename Op>
 
  224      a.
shape(), {a.strides(), b.strides(), out.strides()});
 
  225  const auto& a_strides = new_strides[0];
 
  226  const auto& b_strides = new_strides[1];
 
  227  const auto& strides = new_strides[2];
 
  230  auto leftmost_rc_dim = [&strides](
const auto& arr_strides) {
 
  231    int d = arr_strides.size() - 1;
 
  232    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
 
  236  auto a_rc_dim = leftmost_rc_dim(a_strides);
 
  237  auto b_rc_dim = leftmost_rc_dim(b_strides);
 
  240  auto leftmost_s_dim = [](
const auto& arr_strides) {
 
  241    int d = arr_strides.size() - 1;
 
  242    for (; d >= 0 && arr_strides[d] == 0; d--) {
 
  246  auto a_s_dim = leftmost_s_dim(a_strides);
 
  247  auto b_s_dim = leftmost_s_dim(b_strides);
 
  249  auto ndim = new_shape.size();
 
  253  if (
int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
 
  258  } 
else if (
int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
 
  263  } 
else if (
int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
 
  271  if (dim == 0 || strides[dim - 1] < 16) {
 
  315          a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
 
 
  320template <
typename T, 
typename Op>
 
  325template <
typename Op>
 
  327  switch (out.
dtype()) {
 
 
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
 
size_t size() const
The number of elements in the array.
Definition array.h:88
 
T * data()
Definition array.h:354
 
Dtype dtype() const
Get the array's data type.
Definition array.h:131
 
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:332
 
Simd< T, N > load(const T *x)
Definition base_simd.h:28
 
static constexpr int max_size
Definition base_simd.h:14
 
void store(T *dst, Simd< T, N > x)
Definition base_simd.h:33
 
constexpr Dtype bool_
Definition dtype.h:68
 
constexpr Dtype uint64
Definition dtype.h:73
 
BinaryOpType get_binary_op_type(const array &a, const array &b)
Definition binary.h:19
 
constexpr Dtype uint16
Definition dtype.h:71
 
constexpr Dtype float64
Definition dtype.h:82
 
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
 
constexpr Dtype bfloat16
Definition dtype.h:83
 
@ General
Definition binary.h:16
 
@ VectorVector
Definition binary.h:15
 
@ ScalarScalar
Definition binary.h:12
 
@ VectorScalar
Definition binary.h:14
 
@ ScalarVector
Definition binary.h:13
 
constexpr Dtype int32
Definition dtype.h:77
 
void binary_op_dispatch_dims(const array &a, const array &b, array &out, Op op, int dim, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides)
Definition binary.h:122
 
constexpr Dtype float32
Definition dtype.h:81
 
std::vector< ShapeElem > Shape
Definition array.h:21
 
void set_binary_op_output_data(const array &a, const array &b, array &out, BinaryOpType bopt, bool donate_with_move=false)
Definition binary.h:37
 
constexpr Dtype int16
Definition dtype.h:76
 
std::vector< int64_t > Strides
Definition array.h:22
 
void binary_op_dims(const T *a, const T *b, U *out, Op op, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides, int axis)
Definition binary.h:89
 
constexpr Dtype int8
Definition dtype.h:75
 
constexpr Dtype int64
Definition dtype.h:78
 
constexpr Dtype uint8
Definition dtype.h:70
 
void binary_op(const array &a, const array &b, array &out, Op op)
Definition binary.h:194
 
constexpr Dtype float16
Definition dtype.h:80
 
constexpr Dtype uint32
Definition dtype.h:72
 
void binary(const array &a, const array &b, array &out, Op op)
Definition binary.h:326
 
constexpr Dtype complex64
Definition dtype.h:84
 
int64_t loc
Definition utils.h:126
 
void step()
Definition utils.h:74
 
ScalarVector(Op op_)
Definition binary.h:43
 
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:46
 
Op op
Definition binary.h:41
 
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:22
 
Op op
Definition binary.h:17
 
VectorScalar(Op op_)
Definition binary.h:19
 
VectorVector(Op op_)
Definition binary.h:67
 
Op op
Definition binary.h:65
 
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:70
 
Definition accelerate_simd.h:55