reduce.h
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/backend/common/utils.h"

namespace mlx::core {

enum ReductionOpType {
  // Self-explanatory. Read everything and produce 1 output.
  ContiguousAllReduce,

  // The input is contiguous and the last axis is reduced
  // N1xR1xN2xR2x...xNnxRn
  ContiguousReduce,

  // The input is contiguous and the last axis is not reduced
  // R1xN1xR2xN2x...xRnxNn
  ContiguousStridedReduce,

  // The input is not contiguous, but the last axis is and it is reduced, so we
  // need to figure out the offsets; after that we can call the contiguous
  // reduce.
  // N3xR1xN1xR4x...xRn
  GeneralContiguousReduce,

  // The input is not contiguous, but the last reduction axis and the last axis
  // are, so we need to figure out the offset; after that we can call the
  // strided reduce.
  GeneralStridedReduce,

  // The input is not contiguous after the reduction axis and it may contain
  // 0-stride axes or transpositions. We could copy the strides and produce a
  // transposed outcome, or we can read the input out of order and write the
  // output in order.
  GeneralReduce
};
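
// A worked example of the notation above (illustrative, not part of the
// original header): for a row-contiguous array of shape (4, 5, 6), reducing
// the last axis reads the data as N1 x R1 with N1 = 4 * 5 = 20 runs of
// R1 = 6 reduced elements (ContiguousReduce); reducing axis 0 instead reads
// it as R1 x N1 with R1 = 4 and N1 = 5 * 6 = 30 (ContiguousStridedReduce);
// reducing all three axes at once is a single ContiguousAllReduce.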

struct ReductionPlan {
  ReductionOpType type;
  std::vector<int> shape;
  std::vector<size_t> strides;

  ReductionPlan(
      ReductionOpType type_,
      std::vector<int> shape_,
      std::vector<size_t> strides_)
      : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
  ReductionPlan(ReductionOpType type_) : type(type_) {}
};

namespace {

// Helper for the n-dimensional strided loop
// Should this be in utils?
inline void nd_loop(
    std::function<void(int)> callback,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  std::function<void(int, int)> loop_inner;
  loop_inner = [&](int dim, int offset) {
    if (dim < shape.size() - 1) {
      int size = shape[dim];
      size_t stride = strides[dim];
      for (int i = 0; i < size; i++) {
        loop_inner(dim + 1, offset + i * stride);
      }
    } else {
      int size = shape[dim];
      size_t stride = strides[dim];
      for (int i = 0; i < size; i++) {
        callback(offset + i * stride);
      }
    }
  };
  loop_inner(0, 0);
}
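
// Illustrative sketch (not part of the original header): nd_loop visits every
// element of an n-dimensional view and hands the callback its linear offset.
// For a 2x3 view of a row-major 2x4 buffer (strides {4, 1}) it produces the
// offsets 0, 1, 2, 4, 5, 6. The function below is hypothetical and exists
// only to demonstrate the traversal order.
inline void nd_loop_example() {
  std::vector<int> offsets;
  nd_loop(
      [&](int offset) { offsets.push_back(offset); },
      /* shape */ {2, 3},
      /* strides */ {4, 1});
  // offsets == {0, 1, 2, 4, 5, 6}
}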

std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
    const array& x,
    const std::vector<int>& axes) {
  std::vector<int> shape = x.shape();
  std::vector<size_t> strides = x.strides();

  for (int i = axes.size() - 1; i >= 0; i--) {
    int a = axes[i];
    shape.erase(shape.begin() + a);
    strides.erase(strides.begin() + a);
  }

  return std::make_pair(shape, strides);
}
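
// Illustrative sketch (not part of the original header): dropping the reduced
// axes keeps the input's strides for the remaining axes, which is what
// elem_to_loc needs to map an output index back to an input offset. Assuming
// x3 is a row-contiguous array of shape {4, 5, 6} (strides {30, 6, 1}), the
// hypothetical helper below gets shape {4, 6} and strides {30, 1} for a
// reduction over axis 1.
inline void shapes_without_reduction_axes_example(const array& x3) {
  auto [shape, strides] = shapes_without_reduction_axes(x3, {1});
  // shape == {4, 6}, strides == {30, 1}
  (void)shape;
  (void)strides;
}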

template <typename T, typename U, typename Op>
struct DefaultStridedReduce {
  Op op;

  DefaultStridedReduce(Op op_) : op(op_) {}

  // Accumulate `size` rows of `stride` contiguous elements into `stride`
  // accumulators, advancing the input pointer as it goes.
  void operator()(const T* x, U* accumulator, int size, size_t stride) {
    for (int i = 0; i < size; i++) {
      U* moving_accumulator = accumulator;
      for (int j = 0; j < stride; j++) {
        op(moving_accumulator, *x);
        moving_accumulator++;
        x++;
      }
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultContiguousReduce {
  Op op;

  DefaultContiguousReduce(Op op_) : op(op_) {}

  // Fold `size` contiguous elements into a single accumulator.
  void operator()(const T* x, U* accumulator, int size) {
    while (size-- > 0) {
      op(accumulator, *x);
      x++;
    }
  }
};
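
// Illustrative sketch (not part of the original header): the two default
// functors with a hypothetical sum op. ExampleSum and the function below
// exist only for this example.
struct ExampleSum {
  void operator()(float* acc, float x) const {
    *acc += x;
  }
};

inline void default_reduce_example() {
  float x[6] = {1, 2, 3, 4, 5, 6};

  // Contiguous: fold all six values into one accumulator -> 21.
  float total = 0;
  DefaultContiguousReduce<float, float, ExampleSum> opc{ExampleSum{}};
  opc(x, &total, 6);

  // Strided: treat x as 2 rows of 3 and fold the rows together -> {5, 7, 9}.
  float acc[3] = {0, 0, 0};
  DefaultStridedReduce<float, float, ExampleSum> ops{ExampleSum{}};
  ops(x, acc, /* size */ 2, /* stride */ 3);
}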

ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
  // The data is all there and we are reducing over everything
  if (x.size() == x.data_size() && axes.size() == x.ndim() &&
      x.flags().contiguous) {
    return ContiguousAllReduce;
  }

  // Row contiguous input so the output is row contiguous
  if (x.flags().row_contiguous) {
    // Merge consecutive axes
    std::vector<int> shape = {x.shape(axes[0])};
    std::vector<size_t> strides = {x.strides()[axes[0]]};
    for (int i = 1; i < axes.size(); i++) {
      if (axes[i] - 1 == axes[i - 1]) {
        shape.back() *= x.shape(axes[i]);
        strides.back() = x.strides()[axes[i]];
      } else {
        shape.push_back(x.shape(axes[i]));
        strides.push_back(x.strides()[axes[i]]);
      }
    }

    if (strides.back() == 1) {
      return ReductionPlan(ContiguousReduce, shape, strides);
    } else if (strides.back() > 1) {
      return ReductionPlan(ContiguousStridedReduce, shape, strides);
    }
  }

  // Let's check if we can optimize our access patterns.
  //
  // 1. We have a reduction axis with stride 1. Simply call
  //    GeneralContiguousReduce and be done with it.
  // 2. We have transpositions and we are not reducing over the axis with
  //    stride 1. However, we are reducing over an axis where everything is
  //    contiguous in memory to the right of that axis. We can call strided
  //    reduce and be done with it.
  // 3. We have weird transpositions and expands. Copy the strides to the
  //    output, then call strided reduce.

  // Sort the reduction axes by stride in order to merge them and figure out
  // if we have a contiguous reduction.
  std::vector<std::pair<int, size_t>> reductions;
  for (auto a : axes) {
    reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
  }
  std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
    return a.second > b.second;
  });
  // Starting from the smallest strides, try to merge adjacent pairs so the
  // contiguous reduction can be bigger than just the last axis.
  for (int i = reductions.size() - 1; i >= 1; i--) {
    auto a = reductions[i];
    auto b = reductions[i - 1];

    // If b.stride == a.shape * a.stride then a and b are contiguous
    if (b.second == a.first * a.second) {
      reductions.erase(reductions.begin() + i);
      reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
    }
  }

  std::vector<int> shape;
  std::vector<size_t> strides;
  for (auto r : reductions) {
    shape.push_back(r.first);
    strides.push_back(r.second);
  }

  // We can call the contiguous reduction op for every weird way the input is
  // structured in the rest of the axes.
  if (strides.back() == 1) {
    return ReductionPlan(GeneralContiguousReduce, shape, strides);
  }

  // Delegate to the general strided reduction op if the axes after
  // strides.back() are contiguous.
  if (strides.back() > 1) {
    int size = 1;
    for (int i = x.ndim() - 1; i >= 0; i--) {
      if (axes.back() == i) {
        continue;
      }
      if (x.strides()[i] != size) {
        break;
      }
      size *= x.shape(i);
    }
    if (size >= strides.back()) {
      return ReductionPlan(GeneralStridedReduce, shape, strides);
    }
  }

  return ReductionPlan(GeneralReduce, shape, strides);
}
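
// Illustrative sketch (not part of the original header): the plans that a few
// concrete layouts map to. x3 is assumed to be a row-contiguous array of
// shape {4, 5, 6}; the function is hypothetical and performs no reduction.
inline void reduction_plan_example(const array& x3) {
  // Reducing the stride-1 axis of a row-contiguous array: ContiguousReduce.
  ReductionPlan p0 = get_reduction_plan(x3, {2});
  // Reducing a middle axis: ContiguousStridedReduce (stride 6 > 1).
  ReductionPlan p1 = get_reduction_plan(x3, {1});
  // Reducing every axis of a contiguous array: ContiguousAllReduce.
  ReductionPlan p2 = get_reduction_plan(x3, {0, 1, 2});
  (void)p0;
  (void)p1;
  (void)p2;
}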

template <typename T, typename U, typename OpS, typename OpC, typename Op>
void reduction_op(
    const array& x,
    array& out,
    const std::vector<int>& axes,
    U init,
    OpS ops,
    OpC opc,
    Op op) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  ReductionPlan plan = get_reduction_plan(x, axes);

  if (plan.type == ContiguousAllReduce) {
    U* out_ptr = out.data<U>();
    *out_ptr = init;
    opc(x.data<T>(), out_ptr, x.size());
    return;
  }

  std::vector<int> shape;
  std::vector<size_t> strides;

  if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
    int reduction_size = plan.shape[0];
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
      *out_ptr = init;
      opc(x_ptr, out_ptr, reduction_size);
    }
    return;
  }

  if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
    int reduction_size = plan.shape.back();
    plan.shape.pop_back();
    plan.strides.pop_back();
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    // Unrolling the following loop (and implementing it in order for
    // ContiguousReduce) should yield an extra performance boost.
    std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
    if (plan.shape.size() == 0) {
      for (int i = 0; i < out.size(); i++, out_ptr++) {
        int offset = elem_to_loc(i, shape, strides);
        *out_ptr = init;
        opc(x_ptr + offset, out_ptr, reduction_size);
      }
    } else {
      for (int i = 0; i < out.size(); i++, out_ptr++) {
        int offset = elem_to_loc(i, shape, strides);
        *out_ptr = init;
        nd_loop(
            [&](int extra_offset) {
              opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
            },
            plan.shape,
            plan.strides);
      }
    }
    return;
  }

  if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
    int reduction_size = plan.shape.back();
    size_t reduction_stride = plan.strides.back();
    plan.shape.pop_back();
    plan.strides.pop_back();
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    for (int i = 0; i < out.size(); i += reduction_stride) {
      std::fill_n(out_ptr, reduction_stride, init);
      ops(x_ptr, out_ptr, reduction_size, reduction_stride);
      x_ptr += reduction_stride * reduction_size;
      out_ptr += reduction_stride;
    }
    return;
  }

  if (plan.type == GeneralStridedReduce ||
      plan.type == ContiguousStridedReduce) {
    int reduction_size = plan.shape.back();
    size_t reduction_stride = plan.strides.back();
    plan.shape.pop_back();
    plan.strides.pop_back();
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
    if (plan.shape.size() == 0) {
      for (int i = 0; i < out.size(); i += reduction_stride) {
        int offset = elem_to_loc(i, shape, strides);
        std::fill_n(out_ptr, reduction_stride, init);
        ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
        out_ptr += reduction_stride;
      }
    } else {
      for (int i = 0; i < out.size(); i += reduction_stride) {
        int offset = elem_to_loc(i, shape, strides);
        std::fill_n(out_ptr, reduction_stride, init);
        nd_loop(
            [&](int extra_offset) {
              ops(x_ptr + offset + extra_offset,
                  out_ptr,
                  reduction_size,
                  reduction_stride);
            },
            plan.shape,
            plan.strides);
        out_ptr += reduction_stride;
      }
    }
    return;
  }

  if (plan.type == GeneralReduce) {
    const T* x_ptr = x.data<T>();
    U* out_ptr = out.data<U>();
    std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
    for (int i = 0; i < out.size(); i++, out_ptr++) {
      int offset = elem_to_loc(i, shape, strides);
      U val = init;
      nd_loop(
          [&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
          plan.shape,
          plan.strides);
      *out_ptr = val;
    }
  }
}

template <typename T, typename U, typename Op>
void reduction_op(
    const array& x,
    array& out,
    const std::vector<int>& axes,
    U init,
    Op op) {
  DefaultStridedReduce<T, U, Op> ops(op);
  DefaultContiguousReduce<T, U, Op> opc(op);
  reduction_op<T, U>(x, out, axes, init, ops, opc, op);
}
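
// Illustrative sketch (not part of the original header): dispatching a sum
// reduction through the convenience overload above, reusing the hypothetical
// ExampleSum op. x and out are assumed to be float arrays with out already
// shaped for a reduction over axes.
inline void reduction_op_example(
    const array& x,
    array& out,
    const std::vector<int>& axes) {
  reduction_op<float, float>(x, out, axes, /* init */ 0.0f, ExampleSum{});
}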

} // namespace

} // namespace mlx::core