MLX
Loading...
Searching...
No Matches
reduce_row.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3// Row reduction utilities
4// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup
5// - `threadgroup_reduce` collaborative reduction in the threadgroup such that
6// lid.x == 0 holds the reduced value
7// - `thread_reduce` simple loop and reduce the row
8
/**
 * The thread group collaboratively reduces across the rows with bounds
 * checking.
 *
 * For every w in [0, N_WRITES) the calling thread folds elements read through
 * `inputs[w]` into `totals[w]`, starting from `Op::init`.
 *
 * @param totals  Per-thread accumulators; overwritten with `Op::init` first.
 * @param inputs  One read pointer per accumulator. The pointers are advanced
 *                in place by `lsize_x * N_READS` for each full block, so the
 *                caller's array is modified.
 * @param blocks  Number of full passes in which every thread reads N_READS
 *                elements.
 * @param extra   Number of elements left after the full passes; reads in the
 *                final partial pass are bounds-checked against it.
 * @param lsize_x Stride multiplier for advancing between blocks (callers in
 *                this file pass the threadgroup size along x).
 * @param lid_x   Index of this thread, used to place it within the tail.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* inputs[N_WRITES],
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  Op op;

  // Start every accumulator from the reduction's identity value.
  for (int w = 0; w < N_WRITES; w++) {
    totals[w] = Op::init;
  }

  // Full passes: no bounds checks needed, every thread reads N_READS values.
  for (int b = 0; b < blocks; b++) {
    for (int w = 0; w < N_WRITES; w++) {
      const device T* src = inputs[w];
      for (int r = 0; r < N_READS; r++) {
        totals[w] = op(static_cast<U>(src[r]), totals[w]);
      }

      inputs[w] = src + lsize_x * N_READS;
    }
  }

  // Final partial pass: only `extra` elements remain, so each thread works
  // out how many of its N_READS slots are still inside the row.
  int index = lid_x * N_READS;
  const int remaining = extra - index;
  if (remaining >= N_READS) {
    // This thread's whole chunk is in range.
    for (int w = 0; w < N_WRITES; w++) {
      const device T* src = inputs[w];
      for (int r = 0; r < N_READS; r++) {
        totals[w] = op(static_cast<U>(src[r]), totals[w]);
      }
    }
  } else {
    // Chunk straddles (or lies past) the end: read only the valid prefix.
    // When `remaining` is <= 0 the inner loop does not execute.
    for (int w = 0; w < N_WRITES; w++) {
      const device T* src = inputs[w];
      for (int r = 0; r < remaining; r++) {
        totals[w] = op(static_cast<U>(src[r]), totals[w]);
      }
    }
  }
}
60
/**
 * Row reduction over N_WRITES rows that sit `reduction_size` elements apart.
 *
 * Builds one read pointer per accumulator: the first points at
 * `in + lid_x * N_READS` and each following pointer is `reduction_size`
 * elements after the previous one. The work is then forwarded to the
 * pointer-array overload of `per_thread_row_reduce`.
 *
 * @param totals          Per-thread accumulators, one per row.
 * @param in              Pointer to the first of the N_WRITES rows.
 * @param reduction_size  Element distance between the starts of two
 *                        successive rows.
 * @param blocks          Number of full passes (forwarded unchanged).
 * @param extra           Leftover element count (forwarded unchanged).
 * @param lsize_x         Forwarded unchanged.
 * @param lid_x           Index of this thread; selects its starting offset.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const constant size_t& reduction_size,
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers. There is one pointer per output row, so the
  // loop must be bounded by N_WRITES (the size of `inputs`), not N_READS:
  // using N_READS overruns the array when N_READS > N_WRITES and leaves
  // pointers uninitialized when N_READS < N_WRITES.
  const device T* inputs[N_WRITES];
  inputs[0] = in + lid_x * N_READS;
  for (int i = 1; i < N_WRITES; i++) {
    inputs[i] = inputs[i - 1] + reduction_size;
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
88
/**
 * Row reduction over N_WRITES rows whose locations are looked up through
 * `elem_to_loc`.
 *
 * The read pointer for accumulator i is
 * `in + lid_x * N_READS + elem_to_loc(row_idx + i, shape, strides, ndim)`.
 * The work is then forwarded to the pointer-array overload of
 * `per_thread_row_reduce`.
 *
 * @param totals   Per-thread accumulators, one per row.
 * @param in       Base input pointer.
 * @param row_idx  Index of the first row; rows row_idx .. row_idx+N_WRITES-1
 *                 are reduced.
 * @param blocks   Number of full passes (forwarded unchanged).
 * @param extra    Leftover element count (forwarded unchanged).
 * @param shape    Shape passed to `elem_to_loc`.
 * @param strides  Strides passed to `elem_to_loc`.
 * @param ndim     Number of dimensions passed to `elem_to_loc`.
 * @param lsize_x  Forwarded unchanged.
 * @param lid_x    Index of this thread; selects its starting offset.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const size_t row_idx,
    int blocks,
    int extra,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers. `inputs` holds one pointer per output row, so
  // the loop is bounded by N_WRITES (the array size) rather than N_READS;
  // the two only coincide when the template arguments happen to be equal.
  const device T* inputs[N_WRITES];
  in += lid_x * N_READS;
  for (int i = 0; i < N_WRITES; i++) {
    inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
119
/**
 * Reduce within the threadgroup.
 *
 * First reduces each accumulator across the simdgroup. When the threadgroup
 * contains more than one simdgroup, lane 0 of every simdgroup publishes its
 * partial results to `shared_vals`, and after a threadgroup barrier the
 * partials are reduced once more with a simdgroup reduction.
 *
 * @param totals          Per-thread accumulators; overwritten with the
 *                        reduced values.
 * @param shared_vals     Threadgroup scratch space, indexed up to
 *                        `simd_per_group * N_WRITES` elements.
 * @param lid             Thread position in the threadgroup (only `x` is
 *                        used).
 * @param simd_lane_id    Lane index of this thread within its simdgroup.
 * @param simd_per_group  Number of simdgroups in the threadgroup.
 * @param simd_group_id   Index of this thread's simdgroup.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce(
    thread U totals[N_WRITES],
    threadgroup U* shared_vals,
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  // Stage 1: reduce inside each simdgroup.
  for (int w = 0; w < N_WRITES; w++) {
    totals[w] = op.simd_reduce(totals[w]);
  }

  // A single simdgroup already holds the final answer.
  if (simd_per_group <= 1) {
    return;
  }

  // Stage 2: one lane per simdgroup publishes its partial results.
  if (simd_lane_id == 0) {
    threadgroup U* slot = shared_vals + simd_group_id * N_WRITES;
    for (int w = 0; w < N_WRITES; w++) {
      slot[w] = totals[w];
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Stage 3: threads with lid.x < simd_per_group each pick up one
  // simdgroup's partials; everyone else contributes the identity value.
  const bool has_partial = lid.x < simd_per_group;
  U partials[N_WRITES];
  for (int w = 0; w < N_WRITES; w++) {
    partials[w] = has_partial ? shared_vals[lid.x * N_WRITES + w] : op.init;
  }

  for (int w = 0; w < N_WRITES; w++) {
    totals[w] = op.simd_reduce(partials[w]);
  }
}
163
/**
 * Reduce one row sequentially in a single thread.
 *
 * Folds `blocks * N_READS + extra` consecutive elements starting at `row`
 * into `total` using `Op`. `total` is not reset here, so the caller supplies
 * the starting value and can chain several rows into the same accumulator.
 *
 * @param total   Accumulator, updated in place.
 * @param row     Pointer to the first element of the row.
 * @param blocks  Number of full groups of N_READS elements.
 * @param extra   Number of trailing elements after the full groups.
 */
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC void
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
  Op op;

  const device T* cursor = row;
  for (int b = 0; b < blocks; b++, cursor += N_READS) {
    // Load the whole group first, then fold it in, keeping the loads and
    // the dependent reduction chain in separate loops.
    U staged[N_READS];
    for (int k = 0; k < N_READS; k++) {
      staged[k] = cursor[k];
    }
    for (int k = 0; k < N_READS; k++) {
      total = op(staged[k], total);
    }
  }

  // Trailing elements that do not fill a complete group.
  for (int k = 0; k < extra; k++) {
    total = op(cursor[k], total);
  }
}
182
// Reduction kernels
// - `row_reduce_small` depending on the non-row reductions and the row size,
//   it either just loops over everything or a simdgroup collaboratively
//   reduces the non-row reductions. In the first case one thread is
//   responsible for one output; in the second case one simdgroup is
//   responsible for one output.
// - `row_reduce_simple` simple contiguous row reduction
// - `row_reduce_looped` simply loop and reduce each row for each non-row
//   reduction. One threadgroup is responsible for one output.
191
/**
 * Row reduction kernel for small reductions.
 *
 * Every output element is the reduction of `non_row_reductions` rows of
 * `row_size` elements each. Two strategies are selected at run time:
 *
 *  - If `(non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8`
 *    each thread computes one output on its own by looping over all rows.
 *  - Otherwise the lanes of a simdgroup share one output: each lane handles
 *    rows `simd_lane_id, simd_lane_id + simd_size, ...` and the partial
 *    results are combined with a simdgroup reduction; lane 0 writes it.
 *
 * @param in                  Input buffer.
 * @param out                 Output buffer, one element per output index.
 * @param row_size            Number of elements in one row.
 * @param non_row_reductions  Number of rows folded into one output.
 * @param shape, strides, ndim
 *        Passed to `elem_to_loc` to find where an output's data starts in
 *        `in`.
 * @param reduce_shape, reduce_strides, reduce_ndim
 *        Passed to the `loop` helper to find each row's offset.
 */
template <
    typename T,
    typename U,
    typename Op,
    int NDIMS = 0,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& row_size [[buffer(2)]],
    const constant size_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 tid [[thread_position_in_grid]],
    uint3 tsize [[threads_per_grid]]) {
  Op op;

  U total_val = Op::init;
  // Iterator over the reduced dimensions. The declaration is required by the
  // `loop.location(...)` / `loop.next(...)` calls below.
  looped_elem_to_loc<NDIMS> loop;

  // Precompute some row reduction numbers
  const device T* row;
  int blocks = row_size / N_READS;
  int extra = row_size % N_READS;

  if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
    // Simple loop over non_row_reductions and reduce the row in the thread.
    // Linearize the 2D thread position with x as the fastest dimension: the
    // stride of tid.y is the grid width (tsize.x). Using tsize.y here makes
    // distinct threads collide on the same output whenever tsize.y > 1.
    size_t out_idx = tid.x + tsize.x * size_t(tid.y);
    in += elem_to_loc(out_idx, shape, strides, ndim);

    for (uint r = 0; r < non_row_reductions; r++) {
      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(reduce_shape, reduce_strides);
    }

    out[out_idx] = total_val;
  } else {
    // Collaboratively reduce over non_row_reductions in the simdgroup. Each
    // thread reduces every simd_size-th row and then a simple simd reduce.
    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
    in += elem_to_loc(out_idx, shape, strides, ndim);

    // Move this lane's iterator to its first row.
    loop.next(simd_lane_id, reduce_shape, reduce_strides);

    for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(simd_size, reduce_shape, reduce_strides);
    }

    total_val = op.simd_reduce(total_val);

    if (simd_lane_id == 0) {
      out[out_idx] = total_val;
    }
  }
}
257
/**
 * Contiguous row reduction kernel.
 *
 * Each threadgroup reduces N_WRITES rows of `reduction_size` elements, where
 * row k starts at `in + k * reduction_size`, and writes N_WRITES consecutive
 * outputs. If a threadgroup's N_WRITES outputs would run past `out_size`,
 * its starting index is pulled back to `out_size - N_WRITES`, so it
 * recomputes some outputs that another threadgroup also writes.
 * NOTE(review): this clamp assumes out_size >= N_WRITES; confirm the host
 * side guarantees it.
 *
 * @param in              Input buffer.
 * @param out             Output buffer.
 * @param reduction_size  Number of elements in one row.
 * @param out_size        Number of output elements.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  threadgroup U shared_vals[simd_size * N_WRITES];
  U totals[N_WRITES];

  // Locate the first of this threadgroup's N_WRITES rows, clamping so the
  // last threadgroup stays within the output.
  const size_t group_linear = gid.y + gsize.y * size_t(gid.z);
  size_t out_idx = N_WRITES * group_linear;
  const bool overruns = out_idx + N_WRITES > out_size;
  out_idx = overruns ? out_size - N_WRITES : out_idx;

  in += out_idx * reduction_size;
  out += out_idx;

  // Split the row into full passes of the whole threadgroup plus a tail.
  const uint reads_per_pass = lsize.x * N_READS;
  int blocks = reduction_size / reads_per_pass;
  int extra = reduction_size - blocks * reads_per_pass;

  // Per-thread partial reduction over the rows.
  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, in, reduction_size, blocks, extra, lsize.x, lid.x);

  // Combine the partials of all threads in the threadgroup.
  threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Only the first thread writes the results.
  if (lid.x != 0) {
    return;
  }
  for (int w = 0; w < N_WRITES; w++) {
    out[w] = totals[w];
  }
}
304
/**
 * Looped row reduction kernel: one threadgroup is responsible for one output.
 *
 * For its output element the threadgroup walks over `non_row_reductions`
 * rows. Each row is reduced cooperatively by the threads with
 * `per_thread_row_reduce`, the per-row partials are accumulated per thread,
 * and a final `threadgroup_reduce` combines the threads. The thread with
 * `lid.x == 0` writes the result.
 *
 * @param in                  Input buffer.
 * @param out                 Output buffer, one element per output index.
 * @param row_size            Number of elements in one row.
 * @param non_row_reductions  Number of rows folded into one output.
 * @param shape, strides, ndim
 *        Passed to `elem_to_loc` to find where an output's data starts in
 *        `in`.
 * @param reduce_shape, reduce_strides, reduce_ndim
 *        Passed to the `loop` helper to find each row's offset.
 */
template <
    typename T,
    typename U,
    typename Op,
    int NDIMS = 0,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& row_size [[buffer(2)]],
    const constant size_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U shared_vals[simd_size];
  U total = Op::init;

  size_t out_idx = gid.y + gsize.y * size_t(gid.z);

  // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
  // needs a small refactor. The pointer-array overload used below reads
  // straight from the pointer it is given, so the per-thread offset has to be
  // applied here.
  in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;

  // Iterator over the reduced dimensions. The declaration is required by the
  // `loop.location(...)` / `loop.next(...)` calls below.
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;
  int blocks = row_size / (lsize.x * N_READS);
  int extra = row_size - blocks * (lsize.x * N_READS);

  for (size_t i = 0; i < non_row_reductions; i++) {
    row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);

    // Each thread reduces across the row. `row_total` is initialized by
    // per_thread_row_reduce, which resets its accumulators to Op::init.
    U row_total;
    per_thread_row_reduce<T, U, Op, N_READS, 1>(
        &row_total, &row, blocks, extra, lsize.x, lid.x);

    // Aggregate across rows
    total = op(total, row_total);

    loop.next(reduce_shape, reduce_strides);
  }

  // Reduce across the threadgroup
  threadgroup_reduce<T, U, Op, N_READS, 1>(
      &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Write the output
  if (lid.x == 0) {
    out[out_idx] = total;
  }
}
static constant constexpr const uint8_t simd_size
Definition ops.h:22
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:87
Op op
Definition binary.h:141
static constexpr int REDUCE_N_READS
Definition defines.h:12
static constexpr int REDUCE_N_WRITES
Definition defines.h:13
void row_reduce_small(const device T *in, device U *out, const constant size_t &row_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, uint simd_lane_id, uint3 gid, uint3 gsize, uint3 tid, uint3 tsize)
Definition reduce_row.h:198
METAL_FUNC void per_thread_row_reduce(thread U totals[N_WRITES], const device T *inputs[N_WRITES], int blocks, int extra, uint lsize_x, uint lid_x)
The thread group collaboratively reduces across the rows with bounds checking.
Definition reduce_row.h:19
METAL_FUNC void threadgroup_reduce(thread U totals[N_WRITES], threadgroup U *shared_vals, uint3 lid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Reduce within the threadgroup.
Definition reduce_row.h:129
void row_reduce_simple(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:264
void row_reduce_looped(const device T *in, device U *out, const constant size_t &row_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
Definition reduce_row.h:311
METAL_FUNC void thread_reduce(thread U &total, const device T *row, int blocks, int extra)
Definition reduce_row.h:166
Definition utils.h:334
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:339
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:366