MLX
Loading...
Searching...
No Matches
reduce.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
6
7namespace mlx::core {
8
// Strategy selected for a reduction, based on how the input is laid out in
// memory relative to the axes being reduced. In the shape sketches below, N
// marks a kept axis and R marks a reduced axis.
enum ReductionOpType {
  // Self-explanatory. Read everything and produce 1 output.
  ContiguousAllReduce,

  // The input is contiguous and the last axis is reduced
  // N1xR1xN2xR2x...xNnxRn
  ContiguousReduce,

  // The input is contiguous and the last axis is not reduced
  // R1xN1xR2xN2x...xRnxNn
  ContiguousStridedReduce,

  // The input is not contiguous but the last axis is and it is reduced so we
  // need to figure out the offsets but we can call the contiguous reduce after
  // that.
  // N3xR1xN1xR4x...xRn
  GeneralContiguousReduce,

  // The input is not contiguous but the last reduction axis and the last axis
  // are so we need to figure out the offset but we can call the strided reduce
  // after that.
  GeneralStridedReduce,

  // The input is not contiguous after the reduction axis and it may contain
  // 0-stride axes or transpositions. We could copy the strides and produce a
  // transposed outcome or we can read the input out of order and write the
  // output in order.
  GeneralReduce
};
38
41 std::vector<int> shape;
42 std::vector<size_t> strides;
43
45 ReductionOpType type_,
46 std::vector<int> shape_,
47 std::vector<size_t> strides_)
48 : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
50};
51
52ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
53
// Helper for the n-dimensional strided loop.
// reduction_op() adds the int handed to `callback` to its input pointer, so
// it is presumably the element offset of the current index under
// `shape`/`strides` — definition not visible here, confirm.
// TODO: Should this be in utils?
void nd_loop(
    std::function<void(int)> callback,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides);
60
61std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
62 const array& x,
63 const std::vector<int>& axes);
64
// Default strided reduction kernel.
//
// Treats `x` as `size` consecutive rows of `stride` elements each and folds
// every row element-wise into the `stride`-long `accumulator`, calling
// op(&accumulator[j], x[row * stride + j]) for each element.
//
//   T   input element type
//   U   accumulator element type
//   Op  callable invoked as op(U*, T)
template <typename T, typename U, typename Op>
struct DefaultStridedReduce {
  Op op;

  DefaultStridedReduce(Op op_) : op(op_) {}

  // `accumulator` must already hold the initial values; nothing is written
  // when `size` or `stride` is zero.
  void operator()(const T* x, U* accumulator, int size, size_t stride) {
    for (int row = 0; row < size; row++) {
      // Every row restarts at the front of the accumulator.
      U* moving_accumulator = accumulator;
      // size_t index: matches the type of `stride`, avoiding a
      // signed/unsigned comparison and int overflow on very large strides.
      for (size_t j = 0; j < stride; j++) {
        op(moving_accumulator, *x);
        moving_accumulator++;
        x++;
      }
    }
  }
};
82
// Default contiguous reduction kernel.
//
// Folds `size` consecutive elements starting at `x` into the single value
// pointed to by `accumulator`, calling op(accumulator, x[i]) for each one.
//
//   T   input element type
//   U   accumulator element type
//   Op  callable invoked as op(U*, T)
template <typename T, typename U, typename Op>
struct DefaultContiguousReduce {
  Op op;

  DefaultContiguousReduce(Op op_) : op(op_) {}

  // `accumulator` must already hold the initial value; a non-positive `size`
  // leaves it untouched.
  void operator()(const T* x, U* accumulator, int size) {
    for (int remaining = size; remaining > 0; --remaining, ++x) {
      op(accumulator, *x);
    }
  }
};
96
97template <typename T, typename U, typename OpS, typename OpC, typename Op>
99 const array& x,
100 array& out,
101 const std::vector<int>& axes,
102 U init,
103 OpS ops,
104 OpC opc,
105 Op op) {
107 ReductionPlan plan = get_reduction_plan(x, axes);
108
109 if (plan.type == ContiguousAllReduce) {
110 U* out_ptr = out.data<U>();
111 *out_ptr = init;
112 opc(x.data<T>(), out_ptr, x.size());
113 return;
114 }
115
116 std::vector<int> shape;
117 std::vector<size_t> strides;
118
119 if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
120 int reduction_size = plan.shape[0];
121 const T* x_ptr = x.data<T>();
122 U* out_ptr = out.data<U>();
123 for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
124 *out_ptr = init;
125 opc(x_ptr, out_ptr, reduction_size);
126 }
127 return;
128 }
129
130 if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
131 int reduction_size = plan.shape.back();
132 plan.shape.pop_back();
133 plan.strides.pop_back();
134 const T* x_ptr = x.data<T>();
135 U* out_ptr = out.data<U>();
136 // Unrolling the following loop (and implementing it in order for
137 // ContiguousReduce) should hold extra performance boost.
138 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
139 if (plan.shape.size() == 0) {
140 for (int i = 0; i < out.size(); i++, out_ptr++) {
141 int offset = elem_to_loc(i, shape, strides);
142 *out_ptr = init;
143 opc(x_ptr + offset, out_ptr, reduction_size);
144 }
145 } else {
146 for (int i = 0; i < out.size(); i++, out_ptr++) {
147 int offset = elem_to_loc(i, shape, strides);
148 *out_ptr = init;
149 nd_loop(
150 [&](int extra_offset) {
151 opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
152 },
153 plan.shape,
154 plan.strides);
155 }
156 }
157 return;
158 }
159
160 if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
161 int reduction_size = plan.shape.back();
162 size_t reduction_stride = plan.strides.back();
163 plan.shape.pop_back();
164 plan.strides.pop_back();
165 const T* x_ptr = x.data<T>();
166 U* out_ptr = out.data<U>();
167 for (int i = 0; i < out.size(); i += reduction_stride) {
168 std::fill_n(out_ptr, reduction_stride, init);
169 ops(x_ptr, out_ptr, reduction_size, reduction_stride);
170 x_ptr += reduction_stride * reduction_size;
171 out_ptr += reduction_stride;
172 }
173 return;
174 }
175
176 if (plan.type == GeneralStridedReduce ||
178 int reduction_size = plan.shape.back();
179 size_t reduction_stride = plan.strides.back();
180 plan.shape.pop_back();
181 plan.strides.pop_back();
182 const T* x_ptr = x.data<T>();
183 U* out_ptr = out.data<U>();
184 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
185 if (plan.shape.size() == 0) {
186 for (int i = 0; i < out.size(); i += reduction_stride) {
187 int offset = elem_to_loc(i, shape, strides);
188 std::fill_n(out_ptr, reduction_stride, init);
189 ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
190 out_ptr += reduction_stride;
191 }
192 } else {
193 for (int i = 0; i < out.size(); i += reduction_stride) {
194 int offset = elem_to_loc(i, shape, strides);
195 std::fill_n(out_ptr, reduction_stride, init);
196 nd_loop(
197 [&](int extra_offset) {
198 ops(x_ptr + offset + extra_offset,
199 out_ptr,
200 reduction_size,
201 reduction_stride);
202 },
203 plan.shape,
204 plan.strides);
205 out_ptr += reduction_stride;
206 }
207 }
208 return;
209 }
210
211 if (plan.type == GeneralReduce) {
212 const T* x_ptr = x.data<T>();
213 U* out_ptr = out.data<U>();
214 std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
215 for (int i = 0; i < out.size(); i++, out_ptr++) {
216 int offset = elem_to_loc(i, shape, strides);
217 U val = init;
218 nd_loop(
219 [&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
220 plan.shape,
221 plan.strides);
222 *out_ptr = val;
223 }
224 }
225}
226
227template <typename T, typename U, typename Op>
229 const array& x,
230 array& out,
231 const std::vector<int>& axes,
232 U init,
233 Op op) {
236 reduction_op<T, U>(x, out, axes, init, ops, opc, op);
237}
238
239} // namespace mlx::core
Definition array.h:20
size_t nbytes() const
The number of bytes in the array.
Definition array.h:89
size_t size() const
The number of elements in the array.
Definition array.h:84
void set_data(allocator::Buffer buffer, deleter_t d=allocator::free)
T * data()
Definition array.h:313
Op op
Definition binary.h:141
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
ReductionOpType
Definition reduce.h:9
@ GeneralReduce
Definition reduce.h:36
@ GeneralContiguousReduce
Definition reduce.h:25
@ ContiguousStridedReduce
Definition reduce.h:19
@ ContiguousReduce
Definition reduce.h:15
@ GeneralStridedReduce
Definition reduce.h:30
@ ContiguousAllReduce
Definition reduce.h:11
std::pair< std::vector< int >, std::vector< size_t > > shapes_without_reduction_axes(const array &x, const std::vector< int > &axes)
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
void nd_loop(std::function< void(int)> callback, const std::vector< int > &shape, const std::vector< size_t > &strides)
void reduction_op(const array &x, array &out, const std::vector< int > &axes, U init, OpS ops, OpC opc, Op op)
Definition reduce.h:98
ReductionPlan get_reduction_plan(const array &x, const std::vector< int > &axes)
void operator()(const T *x, U *accumulator, int size)
Definition reduce.h:89
Op op
Definition reduce.h:85
DefaultContiguousReduce(Op op_)
Definition reduce.h:87
Definition reduce.h:66
void operator()(const T *x, U *accumulator, int size, size_t stride)
Definition reduce.h:71
DefaultStridedReduce(Op op_)
Definition reduce.h:69
Op op
Definition reduce.h:67
Definition reduce.h:39
ReductionOpType type
Definition reduce.h:40
ReductionPlan(ReductionOpType type_, std::vector< int > shape_, std::vector< size_t > strides_)
Definition reduce.h:44
std::vector< int > shape
Definition reduce.h:41
std::vector< size_t > strides
Definition reduce.h:42
ReductionPlan(ReductionOpType type_)
Definition reduce.h:49