MLX
 
reduce_row.h

// Copyright © 2023-2024 Apple Inc.

// Row reduction utilities
// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup
// - `threadgroup_reduce` collaborative reduction in the threadgroup such that
//   lid.x == 0 holds the reduced value
// - `thread_reduce` simple loop and reduce the row
// (An illustrative sketch of how these compose follows `thread_reduce` below.)

/**
 * The thread group collaboratively reduces across the rows with bounds
 * checking. In the end each thread holds a part of the reduction.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* inputs[N_WRITES],
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  Op op;

  // Set up the accumulator registers
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = Op::init;
  }

  // Loop over the reduction size within thread group
  for (int i = 0; i < blocks; i++) {
    for (int j = 0; j < N_WRITES; j++) {
      for (int k = 0; k < N_READS; k++) {
        totals[j] = op(static_cast<U>(inputs[j][k]), totals[j]);
      }

      inputs[j] += lsize_x * N_READS;
    }
  }

  // Separate case for the last set as we close the reduction size
  int index = lid_x * N_READS;
  if (index + N_READS <= extra) {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; i < N_READS; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }
    }
  } else {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; index + i < extra; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }
    }
  }
}
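
// Worked example of the blocks/extra contract above (editor's note): with
// lsize_x = 32 threads and N_READS = 4, a 1000-element row is split by the
// caller into blocks = 1000 / (32 * 4) = 7 full passes and
// extra = 1000 - 7 * 128 = 104 trailing elements. In the tail, threads whose
// index = lid_x * N_READS satisfies index + 4 <= 104 (i.e. lid_x <= 25) take
// the unchecked branch and read a full vector of 4; the remaining threads take
// the bounds-checked branch and, since index >= extra, read nothing.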

/**
 * Consecutive rows in a contiguous array.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const constant size_t& reduction_size,
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers
  const device T* inputs[N_WRITES];
  inputs[0] = in + lid_x * N_READS;
  for (int i = 1; i < N_WRITES; i++) {
    inputs[i] = inputs[i - 1] + reduction_size;
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}

/**
 * Consecutive rows in an arbitrarily ordered array.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const int64_t row_idx,
    int blocks,
    int extra,
    const constant int* shape,
    const constant int64_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers
  const device T* inputs[N_WRITES];
  in += lid_x * N_READS;
  for (int i = 0; i < N_WRITES; i++) {
    inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}

/**
 * Reduce within the threadgroup.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce(
    thread U totals[N_WRITES],
    threadgroup U* shared_vals,
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  // Simdgroup first
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = op.simd_reduce(totals[i]);
  }

  // Across simdgroups
  if (simd_per_group > 1) {
    if (simd_lane_id == 0) {
      for (int i = 0; i < N_WRITES; i++) {
        shared_vals[simd_group_id * N_WRITES + i] = totals[i];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    U values[N_WRITES];
    for (int i = 0; i < N_WRITES; i++) {
      values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
                                           : op.init;
    }

    for (int i = 0; i < N_WRITES; i++) {
      totals[i] = op.simd_reduce(values[i]);
    }
  }
}
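
// Worked example for threadgroup_reduce (editor's note): a 256-thread
// threadgroup has 8 simdgroups of 32 threads. The first simd_reduce leaves one
// partial per simdgroup; lane 0 of each simdgroup stores its N_WRITES partials
// to shared_vals. After the barrier every thread reloads one partial if
// lid.x < 8 (op.init otherwise), so the second simd_reduce in the first
// simdgroup combines the 8 partials and lid.x == 0 ends up with the reduced
// value; the other simdgroups merely reduce op.init values.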

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC void
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
  Op op;
  for (int i = 0; i < blocks; i++) {
    U vals[N_READS];
    for (int j = 0; j < N_READS; j++) {
      vals[j] = row[j];
    }
    for (int j = 0; j < N_READS; j++) {
      total = op(vals[j], total);
    }
    row += N_READS;
  }
  for (int i = 0; i < extra; i++) {
    total = op(*row++, total);
  }
}
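
// ---------------------------------------------------------------------------
// Editor's sketch (not part of the original header): a minimal kernel showing
// how the utilities above compose for one contiguous row per threadgroup with
// a single output (N_WRITES = 1). The kernel name and buffer layout are
// hypothetical; `Op` is assumed to be an MLX reduce op providing `init`,
// `operator()`, and `simd_reduce`, as used throughout this file.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void example_single_row_reduce(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& row_size [[buffer(2)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  threadgroup U shared_vals[simd_size];
  U total[1];

  // Per-thread partial reduction across the row.
  int blocks = row_size / (lsize.x * N_READS);
  int extra = row_size - blocks * (lsize.x * N_READS);
  per_thread_row_reduce<T, U, Op, N_READS, 1>(
      total, in, row_size, blocks, extra, lsize.x, lid.x);

  // Combine the partials across the threadgroup; lid.x == 0 holds the result.
  threadgroup_reduce<T, U, Op, N_READS, 1>(
      total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  if (lid.x == 0) {
    out[0] = total[0];
  }
}
// ---------------------------------------------------------------------------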

// Reduction kernels
// - `row_reduce_small` depending on the number of non-row reductions and the
//   row size, either a single thread loops over everything or a simdgroup
//   collaboratively reduces the non-row reductions. In the first case one
//   thread is responsible for one output; in the second, one simdgroup is
//   responsible for one output.
// - `row_reduce_simple` simple contiguous row reduction
// - `row_reduce_looped` simply loop and reduce each row for each non-row
//   reduction. One threadgroup is responsible for one output.
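
// Editor's note on the launch geometry these kernels assume: in
// `row_reduce_small`'s thread-per-output branch the output index comes from
// `tid` and `tsize`, so the host launches one thread per output; in its
// simd-per-output branch, and in `row_reduce_simple` and `row_reduce_looped`,
// the output index is gid.y + gsize.y * gid.z, so the host launches one
// threadgroup (or simdgroup) per output, split across the grid's y and z
// dimensions. `row_reduce_simple` additionally covers N_WRITES consecutive
// outputs per threadgroup.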

template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int64_t& row_size [[buffer(2)]],
    const constant int64_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 tid [[thread_position_in_grid]],
    uint3 tsize [[threads_per_grid]]) {
  Op op;

  U total_val = Op::init;
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);

  // Precompute some row reduction numbers
  const device T* row;
  int blocks = IdxT(row_size) / N_READS;
  int extra = IdxT(row_size) % N_READS;

  if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
    // Simple loop over non_row_reductions and reduce the row in the thread.
    IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);

    for (uint r = 0; r < non_row_reductions; r++) {
      row = in + loop.location();
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(reduce_shape, reduce_strides);
    }

    out[out_idx] = total_val;
  } else {
    // Collaboratively reduce over non_row_reductions in the simdgroup. Each
    // thread reduces every 32nd row and then a simple simd reduce.
    IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);

    loop.next(simd_lane_id, reduce_shape, reduce_strides);

    for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
      row = in + loop.location();
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(simd_size, reduce_shape, reduce_strides);
    }

    total_val = op.simd_reduce(total_val);

    if (simd_lane_id == 0) {
      out[out_idx] = total_val;
    }
  }
}
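
// Worked example of the branch selection above (editor's note): with
// row_size = 6 and non_row_reductions = 4 the condition holds (4 <= 8), so a
// single thread loops over all 4 rows for its output. With row_size = 6 and
// non_row_reductions = 40 both clauses fail, so a simdgroup shares the rows:
// lanes 0-7 reduce two rows each (r and r + 32), lanes 8-31 one row each, and
// the final simd_reduce combines the 32 partials.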

template <
    typename T,
    typename U,
    typename Op,
    typename IdxT = int64_t,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& out_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  threadgroup U shared_vals[simd_size * N_WRITES];
  U totals[N_WRITES];

  // Move to the row
  IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
  if (out_idx + N_WRITES > out_size) {
    out_idx = out_size - N_WRITES;
  }
  in += out_idx * IdxT(reduction_size);
  out += out_idx;

  // Each thread reduces across the row
  int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
  int extra = reduction_size - blocks * (lsize.x * N_READS);
  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, in, reduction_size, blocks, extra, lsize.x, lid.x);

  // Reduce across the threadgroup
  threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Write the output
  if (lid.x == 0) {
    for (int i = 0; i < N_WRITES; i++) {
      out[i] = totals[i];
    }
  }
}
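
// Worked example of the out_idx clamp above (editor's note): with
// out_size = 10 and N_WRITES = 4, threadgroups start at outputs 0, 4, and 8;
// the last group would write past the end, so it is shifted back to
// out_idx = 6 and recomputes rows 6 and 7. The overlap is benign because both
// threadgroups write identical fully reduced values.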

template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int64_t& row_size [[buffer(2)]],
    const constant int64_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U shared_vals[simd_size];
  U total = Op::init;

  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);

  // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
  // needs a small refactor.
  in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim) + lid.x * N_READS;

  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;
  int blocks = IdxT(row_size) / (lsize.x * N_READS);
  int extra = row_size - blocks * (lsize.x * N_READS);

  for (IdxT i = 0; i < non_row_reductions; i++) {
    row = in + loop.location();

    // Each thread reduces across the row
    U row_total;
    per_thread_row_reduce<T, U, Op, N_READS, 1>(
        &row_total, &row, blocks, extra, lsize.x, lid.x);

    // Aggregate across rows
    total = op(total, row_total);

    loop.next(reduce_shape, reduce_strides);
  }

  // Reduce across the threadgroup
  threadgroup_reduce<T, U, Op, N_READS, 1>(
      &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Write the output
  if (lid.x == 0) {
    out[out_idx] = total;
  }
}