12template <
typename T, 
typename U, 
typename Op, 
int D>
 
   24  auto stride_a = a_strides[axis];
 
   25  auto stride_b = b_strides[axis];
 
   26  auto stride_out = out_strides[axis];
 
   29  for (
int i = 0; i < 
N; i++) {
 
   30    if constexpr (D > 1) {
 
   43      std::tie(*out_a, *out_b) = op(*a, *b);
 
   52template <
typename T, 
typename U, 
typename Op>
 
   60      a.shape(), {a.strides(), b.strides(), out_a.strides()});
 
   61  const auto& a_strides = strides[0];
 
   62  const auto& b_strides = strides[1];
 
   63  const auto& out_strides = strides[2];
 
   64  const T* a_ptr = a.data<T>();
 
   65  const T* b_ptr = b.data<T>();
 
   66  U* out_a_ptr = out_a.data<U>();
 
   67  U* out_b_ptr = out_b.data<U>();
 
   69  int ndim = shape.size();
 
  101  auto stride = out_strides[ndim - 3];
 
  102  for (
size_t elem = 0; elem < a.size(); elem += stride) {
 
  119template <
typename T, 
typename U = T, 
typename Op>
 
  123    std::vector<array>& outputs,
 
  126  auto& out_a = outputs[0];
 
  127  auto& out_b = outputs[1];
 
  137  auto a_ptr = a.data<T>();
 
  138  auto b_ptr = b.data<T>();
 
  139  auto out_a_ptr = out_a.data<U>();
 
  140  auto out_b_ptr = out_b.data<U>();
 
  142    std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
 
  144    for (
size_t i = 0; i < b.size(); ++i) {
 
  145      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
 
  151    for (
size_t i = 0; i < a.size(); ++i) {
 
  152      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
 
  158    for (
size_t i = 0; i < a.size(); ++i) {
 
  159      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
 
  168template <
typename Op>
 
  172    std::vector<array>& outputs,
 
  174  switch (outputs[0].dtype()) {
 
constexpr int N
Definition neon_fp16_simd.h:9
 
constexpr Dtype bool_
Definition dtype.h:68
 
constexpr Dtype uint64
Definition dtype.h:73
 
BinaryOpType get_binary_op_type(const array &a, const array &b)
Definition binary.h:19
 
constexpr Dtype uint16
Definition dtype.h:71
 
constexpr Dtype float64
Definition dtype.h:82
 
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
 
constexpr Dtype bfloat16
Definition dtype.h:83
 
@ General
Definition binary.h:16
 
@ ScalarScalar
Definition binary.h:12
 
@ VectorScalar
Definition binary.h:14
 
@ ScalarVector
Definition binary.h:13
 
constexpr Dtype int32
Definition dtype.h:77
 
void binary_op_dispatch_dims(const array &a, const array &b, array &out, Op op, int dim, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides)
Definition binary.h:122
 
constexpr Dtype float32
Definition dtype.h:81
 
std::vector< ShapeElem > Shape
Definition array.h:21
 
void set_binary_op_output_data(const array &a, const array &b, array &out, BinaryOpType bopt, bool donate_with_move=false)
Definition binary.h:37
 
constexpr Dtype int16
Definition dtype.h:76
 
std::vector< int64_t > Strides
Definition array.h:22
 
void binary_op_dims(const T *a, const T *b, U *out, Op op, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides, int axis)
Definition binary.h:89
 
constexpr Dtype int8
Definition dtype.h:75
 
constexpr Dtype int64
Definition dtype.h:78
 
constexpr Dtype uint8
Definition dtype.h:70
 
void binary_op(const array &a, const array &b, array &out, Op op)
Definition binary.h:194
 
constexpr Dtype float16
Definition dtype.h:80
 
constexpr Dtype uint32
Definition dtype.h:72
 
void binary(const array &a, const array &b, array &out, Op op)
Definition binary.h:326
 
constexpr Dtype complex64
Definition dtype.h:84