MLX
 
Loading...
Searching...
No Matches
reduce.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
6
7namespace mlx::core {
8
10 // Self-explanatory. Read everything and produce 1 output.
12
13 // The input is contiguous and the last axis is reduced
14 // N1xR1xN2xR2x...xNnxRn
16
17 // The input is contiguous and the last axis is not reduced
18 // R1xN1xR2xN2x...xRnxNn
20
21 // The input is not contiguous but the last axis is and it is reduced so we
22 // need to figure out the offsets but we can call the contiguous reduce after
23 // that.
24 // N3xR1xN1xR4x...xRn
26
27 // The input is not contiguous but the last reduction axis and the last axis
28 // are so we need to figure out the offset but we can call the strided reduce
29 // after that.
31
32 // The input is not contiguous after the reduction axis and it may contain
33 // 0-stride axes or transpositions. We could copy the strides and produce a
34 // transposed outcome or we can read the input out of order and write the
35 // output in order.
37};
38
43
44 ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
45 : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
47};
48
49ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
50
51// Helper for the ndimensional strided loop
52// Should this be in utils?
54 std::function<void(int)> callback,
55 const Shape& shape,
56 const Strides& strides);
57
58std::pair<Shape, Strides> shapes_without_reduction_axes(
59 const array& x,
60 const std::vector<int>& axes);
61
62template <typename T, typename U, typename Op>
64 Op op;
65
66 DefaultStridedReduce(Op op_) : op(op_) {}
67
68 void operator()(const T* x, U* accumulator, int size, size_t stride) {
69 for (int i = 0; i < size; i++) {
70 U* moving_accumulator = accumulator;
71 for (int j = 0; j < stride; j++) {
72 op(moving_accumulator, *x);
73 moving_accumulator++;
74 x++;
75 }
76 }
77 }
78};
79
80template <typename T, typename U, typename Op>
82 Op op;
83
84 DefaultContiguousReduce(Op op_) : op(op_) {}
85
86 void operator()(const T* x, U* accumulator, int size) {
87 while (size-- > 0) {
88 op(accumulator, *x);
89 x++;
90 }
91 }
92};
93
94template <typename T, typename U, typename OpS, typename OpC, typename Op>
96 const array& x,
97 array& out,
98 const std::vector<int>& axes,
99 U init,
100 OpS ops,
101 OpC opc,
102 Op op) {
104 ReductionPlan plan = get_reduction_plan(x, axes);
105
106 if (plan.type == ContiguousAllReduce) {
107 U* out_ptr = out.data<U>();
108 *out_ptr = init;
109 opc(x.data<T>(), out_ptr, x.size());
110 return;
111 }
112
113 if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
114 int reduction_size = plan.shape[0];
115 const T* x_ptr = x.data<T>();
116 U* out_ptr = out.data<U>();
117 for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
118 *out_ptr = init;
119 opc(x_ptr, out_ptr, reduction_size);
120 }
121 return;
122 }
123
124 if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
125 int reduction_size = plan.shape.back();
126 plan.shape.pop_back();
127 plan.strides.pop_back();
128 const T* x_ptr = x.data<T>();
129 U* out_ptr = out.data<U>();
130 // Unrolling the following loop (and implementing it in order for
131 // ContiguousReduce) should hold extra performance boost.
132 auto [shape, strides] = shapes_without_reduction_axes(x, axes);
133 if (plan.shape.size() == 0) {
134 for (int i = 0; i < out.size(); i++, out_ptr++) {
135 int offset = elem_to_loc(i, shape, strides);
136 *out_ptr = init;
137 opc(x_ptr + offset, out_ptr, reduction_size);
138 }
139 } else {
140 for (int i = 0; i < out.size(); i++, out_ptr++) {
141 int offset = elem_to_loc(i, shape, strides);
142 *out_ptr = init;
143 nd_loop(
144 [&](int extra_offset) {
145 opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
146 },
147 plan.shape,
148 plan.strides);
149 }
150 }
151 return;
152 }
153
154 if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
155 int reduction_size = plan.shape.back();
156 size_t reduction_stride = plan.strides.back();
157 plan.shape.pop_back();
158 plan.strides.pop_back();
159 const T* x_ptr = x.data<T>();
160 U* out_ptr = out.data<U>();
161 for (int i = 0; i < out.size(); i += reduction_stride) {
162 std::fill_n(out_ptr, reduction_stride, init);
163 ops(x_ptr, out_ptr, reduction_size, reduction_stride);
164 x_ptr += reduction_stride * reduction_size;
165 out_ptr += reduction_stride;
166 }
167 return;
168 }
169
170 if (plan.type == GeneralStridedReduce ||
172 int reduction_size = plan.shape.back();
173 size_t reduction_stride = plan.strides.back();
174 plan.shape.pop_back();
175 plan.strides.pop_back();
176 const T* x_ptr = x.data<T>();
177 U* out_ptr = out.data<U>();
178 auto [shape, strides] = shapes_without_reduction_axes(x, axes);
179 if (plan.shape.size() == 0) {
180 for (int i = 0; i < out.size(); i += reduction_stride) {
181 int offset = elem_to_loc(i, shape, strides);
182 std::fill_n(out_ptr, reduction_stride, init);
183 ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
184 out_ptr += reduction_stride;
185 }
186 } else {
187 for (int i = 0; i < out.size(); i += reduction_stride) {
188 int offset = elem_to_loc(i, shape, strides);
189 std::fill_n(out_ptr, reduction_stride, init);
190 nd_loop(
191 [&](int extra_offset) {
192 ops(x_ptr + offset + extra_offset,
193 out_ptr,
194 reduction_size,
195 reduction_stride);
196 },
197 plan.shape,
198 plan.strides);
199 out_ptr += reduction_stride;
200 }
201 }
202 return;
203 }
204
205 if (plan.type == GeneralReduce) {
206 const T* x_ptr = x.data<T>();
207 U* out_ptr = out.data<U>();
208 auto [shape, strides] = shapes_without_reduction_axes(x, axes);
209 for (int i = 0; i < out.size(); i++, out_ptr++) {
210 int offset = elem_to_loc(i, shape, strides);
211 U val = init;
212 nd_loop(
213 [&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
214 plan.shape,
215 plan.strides);
216 *out_ptr = val;
217 }
218 }
219}
220
221template <typename T, typename U, typename Op>
223 const array& x,
224 array& out,
225 const std::vector<int>& axes,
226 U init,
227 Op op) {
230 reduction_op<T, U>(x, out, axes, init, ops, opc, op);
231}
232
233} // namespace mlx::core
Definition array.h:24
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:342
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
std::pair< Shape, Strides > shapes_without_reduction_axes(const array &x, const std::vector< int > &axes)
ReductionOpType
Definition reduce.h:9
@ GeneralReduce
Definition reduce.h:36
@ GeneralContiguousReduce
Definition reduce.h:25
@ ContiguousStridedReduce
Definition reduce.h:19
@ ContiguousReduce
Definition reduce.h:15
@ GeneralStridedReduce
Definition reduce.h:30
@ ContiguousAllReduce
Definition reduce.h:11
int64_t elem_to_loc(int elem, const Shape &shape, const Strides &strides)
Definition utils.h:12
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
void reduction_op(const array &x, array &out, const std::vector< int > &axes, U init, OpS ops, OpC opc, Op op)
Definition reduce.h:95
ReductionPlan get_reduction_plan(const array &x, const std::vector< int > &axes)
void nd_loop(std::function< void(int)> callback, const Shape &shape, const Strides &strides)
void operator()(const T *x, U *accumulator, int size)
Definition reduce.h:86
Op op
Definition reduce.h:82
DefaultContiguousReduce(Op op_)
Definition reduce.h:84
Definition reduce.h:63
void operator()(const T *x, U *accumulator, int size, size_t stride)
Definition reduce.h:68
DefaultStridedReduce(Op op_)
Definition reduce.h:66
Op op
Definition reduce.h:64
Definition reduce.h:39
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
Definition reduce.h:44
Shape shape
Definition reduce.h:41
ReductionOpType type
Definition reduce.h:40
Strides strides
Definition reduce.h:42
ReductionPlan(ReductionOpType type_)
Definition reduce.h:46