46 std::vector<int> shape_,
47 std::vector<size_t> strides_)
// N-dimensional strided iteration helper. NOTE(review): the enclosing
// function's name/signature line is not visible in this listing.
// Visible behaviour: recursively walks the index space described by
// `shape`/`strides` and calls `callback` with each element's linear offset
// (the sum over dims of index * stride).
57 std::function<
void(
int)> callback,
58 const std::vector<int>& shape,
59 const std::vector<size_t>& strides) {
// Declared first and assigned afterwards so the lambda, which captures by
// reference, can call itself recursively.
60 std::function<void(
int,
int)> loop_inner;
61 loop_inner = [&](
int dim,
int offset) {
// NOTE(review): `shape.size() - 1` is unsigned; an empty `shape` would wrap
// to SIZE_MAX and shape[dim] would be read out of bounds. Confirm callers
// never pass a 0-d shape.
62 if (dim < shape.size() - 1) {
63 int size = shape[dim];
64 size_t stride = strides[dim];
65 for (
int i = 0; i < size; i++) {
// Descend into the next dimension with this dimension's contribution added.
// NOTE(review): `offset + i * stride` is evaluated as size_t and narrowed
// back to the `int` parameter, which can truncate for very large arrays.
66 loop_inner(dim + 1, offset + i * stride);
// Presumably the `else` branch (last dimension): the lines between 66 and 69
// are not visible here. Each element's final offset is handed to `callback`.
69 int size = shape[dim];
70 size_t stride = strides[dim];
71 for (
int i = 0; i < size; i++) {
72 callback(offset + i * stride);
// Returns copies of x's shape and strides with the entries at the given
// reduction axes removed, i.e. the shape/strides of the remaining dims.
// NOTE(review): the `x` parameter line is not visible in this listing.
79std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
81 const std::vector<int>& axes) {
82 std::vector<int> shape = x.shape();
83 std::vector<size_t> strides = x.strides();
// Erase from the last listed axis to the first, so that earlier erasures do
// not shift the positions of axes still to be removed. NOTE(review): that only
// holds if `axes` is sorted ascending. `a` is defined on a line not visible
// here (presumably axes[i]); confirm both.
85 for (
int i = axes.size() - 1; i >= 0; i--) {
87 shape.erase(shape.begin() + a);
88 strides.erase(strides.begin() + a);
91 return std::make_pair(shape, strides);
// Default strided reduction functor built from a scalar fold `op`, called as
// op(U* accumulator, T value). Its call operator processes `size` rows of
// `stride` elements each, folding element j of every row into accumulator[j].
// NOTE(review): the `op` member declaration and the struct's closing lines are
// not visible in this listing.
94template <
typename T,
typename U,
typename Op>
95struct DefaultStridedReduce {
// NOTE(review): single-argument constructor is not `explicit`.
98 DefaultStridedReduce(Op op_) :
op(op_) {}
100 void operator()(
const T* x, U* accumulator,
int size,
size_t stride) {
101 for (
int i = 0; i < size; i++) {
// Restart at the first accumulator slot for each of the `size` rows.
102 U* moving_accumulator = accumulator;
// NOTE(review): signed `int j` is compared against `size_t stride`.
103 for (
int j = 0; j < stride; j++) {
// NOTE(review): the line advancing `x` is not visible here; presumably `x`
// is incremented after each fold.
104 op(moving_accumulator, *x);
105 moving_accumulator++;
// Default contiguous reduction functor wrapping a scalar fold `op`. Its call
// operator receives `size` input elements starting at `x` plus a single
// accumulator pointer. NOTE(review): the `op` member declaration and the call
// operator's body are not visible in this listing.
112template <
typename T,
typename U,
typename Op>
113struct DefaultContiguousReduce {
// NOTE(review): single-argument constructor is not `explicit`.
116 DefaultContiguousReduce(Op op_) :
op(op_) {}
118 void operator()(
const T* x, U* accumulator,
int size) {
// Chooses a ReductionPlan describing how to reduce array `x` over `axes`.
// Most return statements and branch bodies are not visible in this listing.
// The visible logic does three things:
//  - detects the "contiguous data, every axis reduced" case;
//  - for row-contiguous inputs, merges adjacent reduced axes;
//  - otherwise sorts the reduced axes by stride and merges those that are
//    contiguous with each other in memory.
// NOTE(review): `axes` is taken by value, so every call copies the vector;
// `const std::vector<int>&` would avoid that.
126ReductionPlan get_reduction_plan(
const array& x,
const std::vector<int> axes) {
// Dense contiguous data (size == data_size) with every axis reduced.
128 if (x.size() == x.data_size() && axes.size() == x.ndim() &&
129 x.flags().contiguous) {
134 if (x.flags().row_contiguous) {
136 std::vector<int> shape = {x.shape(axes[0])};
137 std::vector<size_t> strides = {x.strides()[axes[0]]};
138 for (
int i = 1; i < axes.size(); i++) {
// axes[i] directly follows axes[i - 1]: fold it into the previous entry.
// Sizes multiply and the stride becomes that of the inner axis.
// NOTE(review): assumes `axes` is sorted ascending; confirm with callers.
139 if (axes[i] - 1 == axes[i - 1]) {
140 shape.back() *= x.shape(axes[i]);
141 strides.back() = x.strides()[axes[i]];
143 shape.push_back(x.shape(axes[i]));
144 strides.push_back(x.strides()[axes[i]]);
// Branch on the stride of the innermost merged reduction run (unit stride vs
// larger stride); the branch bodies are not visible here.
148 if (strides.back() == 1) {
150 }
else if (strides.back() > 1) {
// General path: collect (size, stride) for every reduced axis.
168 std::vector<std::pair<int, size_t>> reductions;
169 for (
auto a : axes) {
170 reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
// Sort by stride, largest first, so the innermost (smallest-stride) axes end up
// at the back.
172 std::sort(reductions.begin(), reductions.end(), [](
auto a,
auto b) {
173 return a.second > b.second;
// Scan from the back. If axis b's stride equals axis a's full extent in memory
// (a.first * a.second), then b sits directly outside a. The two collapse into
// one axis of size a.first * b.first with a's (smaller) stride.
177 for (
int i = reductions.size() - 1; i >= 1; i--) {
178 auto a = reductions[i];
179 auto b = reductions[i - 1];
182 if (b.second == a.first * a.second) {
183 reductions.erase(reductions.begin() + i);
184 reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
// Split the merged (size, stride) pairs into parallel shape/strides vectors.
188 std::vector<int> shape;
189 std::vector<size_t> strides;
190 for (
auto r : reductions) {
191 shape.push_back(r.first);
192 strides.push_back(r.second);
// Innermost reduced run has unit stride (branch body not visible here).
197 if (strides.back() == 1) {
// Innermost reduced run is strided: the visible part scans x's dims from last
// to first. NOTE(review): `size` is declared on a line not visible here.
203 if (strides.back() > 1) {
205 for (
int i = x.ndim() - 1; i >= 0; i--) {
206 if (axes.back() == i) {
209 if (x.strides()[i] != size) {
214 if (size >= strides.back()) {
// Generic reduction driver. NOTE(review): the name and most parameter lines
// are not visible here; this file calls it as
// reduction_op<T, U>(x, out, axes, init, ops, opc, op).
// It asks get_reduction_plan for a plan and dispatches on it. The functors are
// used as follows:
//  - OpS (`ops`) is called as ops(x, acc, size, stride) for strided runs.
//  - OpC (`opc`) is called as opc(x, acc, size) for contiguous runs.
//  - Op (`op`) is the scalar fold op(U*, T) used by the fully general path.
// The branch conditions selecting each section are not visible.
222template <
typename T,
typename U,
typename OpS,
typename OpC,
typename Op>
226 const std::vector<int>& axes,
232 ReductionPlan plan = get_reduction_plan(x, axes);
// Presumably the all-reduce case: one contiguous pass over all x.size()
// elements into `out`.
235 U* out_ptr = out.data<U>();
237 opc(x.data<T>(), out_ptr, x.size());
// Shape/strides of the non-reduced dims, filled by the general sections below.
241 std::vector<int> shape;
242 std::vector<size_t> strides;
// Presumably the contiguous case: each output consumes the next
// `reduction_size` contiguous input elements.
// NOTE(review): the int loop counters below are compared against out.size()
// (signed/unsigned comparison, possible overflow for > INT_MAX outputs).
245 int reduction_size = plan.shape[0];
246 const T* x_ptr = x.data<T>();
247 U* out_ptr = out.data<U>();
248 for (
int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
250 opc(x_ptr, out_ptr, reduction_size);
// Presumably the general-contiguous case. The innermost reduction run is
// contiguous and handled by opc. Any remaining reduction dims stay in
// `plan`. Per-output base offsets (`offset`, computed on a line not visible
// here) come from the non-reduced shape/strides.
256 int reduction_size = plan.shape.back();
257 plan.shape.pop_back();
258 plan.strides.pop_back();
259 const T* x_ptr = x.data<T>();
260 U* out_ptr = out.data<U>();
263 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
// No outer reduction dims left: a single opc call per output.
264 if (plan.shape.size() == 0) {
265 for (
int i = 0; i < out.size(); i++, out_ptr++) {
268 opc(x_ptr + offset, out_ptr, reduction_size);
// Outer reduction dims remain: visit each outer offset and reduce the inner
// contiguous run at that offset into the same output.
271 for (
int i = 0; i < out.size(); i++, out_ptr++) {
275 [&](
int extra_offset) {
276 opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
// Presumably the contiguous-strided case. Outputs are produced
// `reduction_stride` at a time: each block is first filled with `init`, then
// ops folds `reduction_size` rows of input into it.
286 int reduction_size = plan.shape.back();
287 size_t reduction_stride = plan.strides.back();
288 plan.shape.pop_back();
289 plan.strides.pop_back();
290 const T* x_ptr = x.data<T>();
291 U* out_ptr = out.data<U>();
292 for (
int i = 0; i < out.size(); i += reduction_stride) {
293 std::fill_n(out_ptr, reduction_stride, init);
294 ops(x_ptr, out_ptr, reduction_size, reduction_stride);
295 x_ptr += reduction_stride * reduction_size;
296 out_ptr += reduction_stride;
// Presumably the general-strided case: same block-wise scheme as above, but
// input base offsets come from the non-reduced shape/strides.
303 int reduction_size = plan.shape.back();
304 size_t reduction_stride = plan.strides.back();
305 plan.shape.pop_back();
306 plan.strides.pop_back();
307 const T* x_ptr = x.data<T>();
308 U* out_ptr = out.data<U>();
309 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
310 if (plan.shape.size() == 0) {
311 for (
int i = 0; i < out.size(); i += reduction_stride) {
313 std::fill_n(out_ptr, reduction_stride, init);
314 ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
315 out_ptr += reduction_stride;
318 for (
int i = 0; i < out.size(); i += reduction_stride) {
320 std::fill_n(out_ptr, reduction_stride, init);
322 [&](
int extra_offset) {
323 ops(x_ptr + offset + extra_offset,
330 out_ptr += reduction_stride;
// Presumably the fully general case: each output is folded element by element
// with the scalar `op` into `val`. `val` and `offset` are declared on lines
// not visible here.
337 const T* x_ptr = x.data<T>();
338 U* out_ptr = out.data<U>();
339 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
340 for (
int i = 0; i < out.size(); i++, out_ptr++) {
344 [&](
int extra_offset) {
op(&val, *(x_ptr + offset + extra_offset)); },
// Convenience overload of reduction_op. It builds the default strided and
// contiguous functors from a single scalar fold `op`, then forwards everything
// (including `op` itself, for the fully general path) to the
// five-type-parameter reduction_op. NOTE(review): the function name and the
// x/out/init parameter lines are not visible in this listing.
352template <
typename T,
typename U,
typename Op>
356 const std::vector<int>& axes,
359 DefaultStridedReduce<T, U, Op> ops(
op);
360 DefaultContiguousReduce<T, U, Op> opc(
op);
361 reduction_op<T, U>(x, out, axes, init, ops, opc,
op);
Op op
Definition binary.h:141
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Buffer malloc_or_wait(size_t size)
Group init(bool strict=false)
Initialize the distributed backend and return the group containing all discoverable processes.
ReductionOpType
Definition reduce.h:9
@ GeneralReduce
Definition reduce.h:36
@ GeneralContiguousReduce
Definition reduce.h:25
@ ContiguousStridedReduce
Definition reduce.h:19
@ ContiguousReduce
Definition reduce.h:15
@ GeneralStridedReduce
Definition reduce.h:30
@ ContiguousAllReduce
Definition reduce.h:11
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
ReductionOpType type
Definition reduce.h:40
ReductionPlan(ReductionOpType type_, std::vector< int > shape_, std::vector< size_t > strides_)
Definition reduce.h:44
std::vector< int > shape
Definition reduce.h:41
std::vector< size_t > strides
Definition reduce.h:42
ReductionPlan(ReductionOpType type_)
Definition reduce.h:49