54 std::function<
void(
int)> callback,
60 const std::vector<int>& axes);
62template <
typename T,
typename U,
typename Op>
68 void operator()(
const T* x, U* accumulator,
int size,
size_t stride) {
69 for (
int i = 0; i < size; i++) {
70 U* moving_accumulator = accumulator;
71 for (
int j = 0; j < stride; j++) {
72 op(moving_accumulator, *x);
80template <
typename T,
typename U,
typename Op>
94template <
typename T,
typename U,
typename OpS,
typename OpC,
typename Op>
98 const std::vector<int>& axes,
107 U* out_ptr = out.
data<U>();
109 opc(x.
data<T>(), out_ptr, x.
size());
114 int reduction_size = plan.
shape[0];
115 const T* x_ptr = x.
data<T>();
116 U* out_ptr = out.
data<U>();
117 for (
int i = 0; i < out.
size(); i++, out_ptr++, x_ptr += reduction_size) {
119 opc(x_ptr, out_ptr, reduction_size);
125 int reduction_size = plan.
shape.back();
126 plan.
shape.pop_back();
128 const T* x_ptr = x.
data<T>();
129 U* out_ptr = out.
data<U>();
133 if (plan.
shape.size() == 0) {
134 for (
int i = 0; i < out.
size(); i++, out_ptr++) {
137 opc(x_ptr + offset, out_ptr, reduction_size);
140 for (
int i = 0; i < out.
size(); i++, out_ptr++) {
144 [&](
int extra_offset) {
145 opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
155 int reduction_size = plan.
shape.back();
156 size_t reduction_stride = plan.
strides.back();
157 plan.
shape.pop_back();
159 const T* x_ptr = x.
data<T>();
160 U* out_ptr = out.
data<U>();
161 for (
int i = 0; i < out.
size(); i += reduction_stride) {
162 std::fill_n(out_ptr, reduction_stride, init);
163 ops(x_ptr, out_ptr, reduction_size, reduction_stride);
164 x_ptr += reduction_stride * reduction_size;
165 out_ptr += reduction_stride;
172 int reduction_size = plan.
shape.back();
173 size_t reduction_stride = plan.
strides.back();
174 plan.
shape.pop_back();
176 const T* x_ptr = x.
data<T>();
177 U* out_ptr = out.
data<U>();
179 if (plan.
shape.size() == 0) {
180 for (
int i = 0; i < out.
size(); i += reduction_stride) {
182 std::fill_n(out_ptr, reduction_stride, init);
183 ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
184 out_ptr += reduction_stride;
187 for (
int i = 0; i < out.
size(); i += reduction_stride) {
189 std::fill_n(out_ptr, reduction_stride, init);
191 [&](
int extra_offset) {
192 ops(x_ptr + offset + extra_offset,
199 out_ptr += reduction_stride;
206 const T* x_ptr = x.
data<T>();
207 U* out_ptr = out.
data<U>();
209 for (
int i = 0; i < out.
size(); i++, out_ptr++) {
213 [&](
int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
221template <
typename T,
typename U,
typename Op>
225 const std::vector<int>& axes,
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:342
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Buffer malloc_or_wait(size_t size)
std::pair< Shape, Strides > shapes_without_reduction_axes(const array &x, const std::vector< int > &axes)
ReductionOpType
Definition reduce.h:9
@ GeneralReduce
Definition reduce.h:36
@ GeneralContiguousReduce
Definition reduce.h:25
@ ContiguousStridedReduce
Definition reduce.h:19
@ ContiguousReduce
Definition reduce.h:15
@ GeneralStridedReduce
Definition reduce.h:30
@ ContiguousAllReduce
Definition reduce.h:11
int64_t elem_to_loc(int elem, const Shape &shape, const Strides &strides)
Definition utils.h:12
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
void reduction_op(const array &x, array &out, const std::vector< int > &axes, U init, OpS ops, OpC opc, Op op)
Definition reduce.h:95
ReductionPlan get_reduction_plan(const array &x, const std::vector< int > &axes)
void nd_loop(std::function< void(int)> callback, const Shape &shape, const Strides &strides)
void operator()(const T *x, U *accumulator, int size)
Definition reduce.h:86
Op op
Definition reduce.h:82
DefaultContiguousReduce(Op op_)
Definition reduce.h:84
void operator()(const T *x, U *accumulator, int size, size_t stride)
Definition reduce.h:68
DefaultStridedReduce(Op op_)
Definition reduce.h:66
Op op
Definition reduce.h:64
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
Definition reduce.h:44
Shape shape
Definition reduce.h:41
ReductionOpType type
Definition reduce.h:40
Strides strides
Definition reduce.h:42
ReductionPlan(ReductionOpType type_)
Definition reduce.h:46