MLX
 
Loading...
Searching...
No Matches
reduce_col.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
// Column reduction for small reductions.
//
// Thread mapping (as established by the index math below):
//   - gid.x and lid.x tile the `reduction_stride` output columns, with
//     `n_reads` contiguous columns per thread.
//   - lid.y / lsize.y stride over the `total_rows` reduction rows; the
//     per-row partials are combined through threadgroup memory when
//     lsize.y > 1.
//   - (gid.y, gid.z) select which output element (`out_idx`) is produced.
//
// `reduce_shape` / `reduce_strides` / `reduce_ndim` describe the (possibly
// multi-dimensional) reduction space that `loop` walks row by row.
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
[[kernel]] void col_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  // Each thread accumulates n_reads adjacent output columns.
  constexpr int n_reads = 4;
  Op op;
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;

  // Per-column accumulators, seeded with the reduction identity.
  U totals[n_reads];
  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // First column owned by this thread; threads entirely past the end of
  // the stride exit immediately.
  // NOTE(review): threads returning here skip the threadgroup_barrier
  // further down — presumably the host-side dispatch guarantees either
  // lsize.y == 1 or that no thread in a group takes this early exit;
  // confirm against the launch code.
  IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
  if (column >= reduction_stride) {
    return;
  }
  // `safe`: all n_reads columns are in bounds, so the hot loop can skip
  // per-element bounds checks.
  bool safe = column + n_reads <= reduction_stride;

  // Map this output element back to its input offset.
  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
  in += in_idx + column;

  // Walk the reduction rows starting at lid.y with a stride of lsize.y.
  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
  loop.next(lid.y, reduce_shape, reduce_strides);
  for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
    row = in + loop.location();
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      // Out-of-bounds columns read the identity so they cannot perturb
      // the accumulated value.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }
    loop.next(lsize.y, reduce_shape, reduce_strides);
  }

  // Combine partials across the threadgroup's y dimension.
  if (lsize.y > 1) {
    // lsize.y should be <= 8
    // (the shared buffer is sized for at most 32 x 8 threads per group —
    // an implicit launch contract; confirm against the host dispatch).
    threadgroup U shared_vals[32 * 8 * n_reads];
    for (int i = 0; i < n_reads; i++) {
      shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Row 0 folds all other rows' partials into its own.
    if (lid.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = shared_vals[lid.x * n_reads + i];
      }
      for (uint j = 1; j < lsize.y; j++) {
        for (int i = 0; i < n_reads; i++) {
          totals[i] =
              op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
                 totals[i]);
        }
      }
    }
  }

  // Row 0 writes the final values for this thread's columns.
  if (lid.y == 0) {
    out += out_idx * IdxT(reduction_stride) + column;
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        out[i] = totals[i];
      }
    } else {
      // In the unsafe case column + n_reads > reduction_stride, so this
      // bound also implies i < n_reads.
      for (int i = 0; column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
95
// Column reduction for very long reduction columns.
//
// Each thread handles exactly one column (lid.x) of one output element
// (gid.x, gid.y). The reduction rows are strided across both lsize.y
// (within a threadgroup) and gsize.z (across threadgroups), so every
// gid.z slice produces an independent partial result per column.
// The partials are written to out[gid.z * out_size + ...]; presumably a
// follow-up pass reduces the gsize.z partials into the final output —
// verify against the host-side dispatch.
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
[[kernel]] void col_reduce_longcolumn(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  Op op;
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;

  // Map this output element to its input offset; lid.x picks the column.
  IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
  in += in_idx + lid.x;

  // Accumulate a disjoint subset of rows: start at gid.z * lsize.y + lid.y
  // and step by lsize.y * gsize.z.
  U total = Op::init;
  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
  loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
  for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;
       r += lsize.y * gsize.z) {
    row = in + loop.location();
    total = op(static_cast<U>(*row), total);
    loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
  }

  // Combine across lid.y through threadgroup memory. The buffer is sized
  // for at most 32 x 32 threads per group — an implicit launch contract;
  // confirm against the host dispatch.
  threadgroup U shared_vals[32 * 32];
  shared_vals[lid.y * lsize.x + lid.x] = total;
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (lid.y == 0) {
    for (uint i = 1; i < lsize.y; i++) {
      total = op(total, shared_vals[i * lsize.x + lid.x]);
    }
    // One partial output per gid.z slice, offset by out_size per slice.
    out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =
        total;
  }
}
143
// Looped column reduction over a BM (rows) x BN (columns) tile per
// threadgroup.
//
// The threadgroup is a fixed 8 simdgroups (tgp_size threads). Each thread
// owns n_reads contiguous columns and strides down the rows in steps of
// BM, accumulating into registers. The cross-row combination then takes
// one of two paths:
//   - BM == 32: stage partials in threadgroup memory, then use a simd
//     reduction so each simdgroup produces n_outputs = BN / n_simdgroups
//     final columns.
//   - otherwise: stage partials column-major in threadgroup memory and
//     let the offset.y == 0 threads fold the BM rows serially.
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int BM,
    int BN>
[[kernel]] void col_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  // Each thread reads n_reads contiguous columns; n_read_blocks threads
  // cover one BN-wide row of the tile.
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;

  // Seed accumulators with the reduction identity.
  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // Flatten the thread id, then derive this thread's (column, row) offset
  // inside the BM x BN tile.
  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  IdxT column = BN * gid.x + offset.x;
  // `safe`: all n_reads columns are in bounds.
  bool safe = column + n_reads <= reduction_stride;

  // Map this output element to its input offset.
  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
  in += in_idx + column;

  // Stride down the rows in steps of BM, accumulating in registers.
  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
  loop.next(offset.y, reduce_shape, reduce_strides);
  for (IdxT r = offset.y; r < total; r += BM) {
    row = in + loop.location();

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      // Out-of-bounds columns read the identity so they cannot perturb
      // the accumulated value.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM so each thread writes
  // the partial output to SM and then each simdgroup does BN / n_simdgroups
  // accumulations.
  if (BM == 32) {
    constexpr int n_outputs = BN / n_simdgroups;
    // Guard only applies when this branch is live (plain `if`, so both
    // branches are compiled for every BM).
    static_assert(
        BM != 32 || n_outputs == n_reads,
        "The tile should be selected such that n_outputs == n_reads");
    for (int i = 0; i < n_reads; i++) {
      shared_vals[offset.y * BN + offset.x + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Transposed read: lane id indexes the row, so simd_reduce folds the
    // 32 rows of each column.
    short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
    for (int i = 0; i < n_outputs; i++) {
      totals[i] =
          op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
    }

    // Write the output.
    if (simd_lane_id == 0) {
      IdxT out_column = BN * gid.x + out_offset.x;
      out += out_idx * IdxT(reduction_stride) + out_column;
      if (out_column + n_outputs <= reduction_stride) {
        for (int i = 0; i < n_outputs; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; out_column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }

  // Each thread holds n_reads partial results. We write them all out to shared
  // memory and threads with offset.y == 0 aggregate the columns and write the
  // outputs.
  else {
    short x_block = offset.x / n_reads;
    // Column-major staging: each column's BM partials are contiguous.
    for (int i = 0; i < n_reads; i++) {
      shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (offset.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        for (int j = 1; j < BM; j++) {
          totals[i] =
              op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
        }
      }
    }

    // Write the output.
    if (offset.y == 0) {
      out += out_idx * IdxT(reduction_stride) + column;
      if (safe) {
        for (int i = 0; i < n_reads; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }
}
293
// First pass of a two-pass column reduction (requires BM == 32).
//
// Like col_reduce_looped, but the reduction rows are additionally split
// into outer_blocks (= 32) blocks: the linear grid index `full_idx` is
// decomposed into (block_idx, out_idx) by dividing by out_size, and each
// block_idx accumulates rows offset.y + block_idx * BM stepping by
// outer_blocks * BM. Each (block_idx, out_idx) pair writes its own
// partial row at out[full_idx * reduction_stride + ...], so the output
// holds outer_blocks partials per element — presumably reduced by a
// second pass (hence "2pass"); verify against the host dispatch.
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int BM,
    int BN>
[[kernel]] void col_reduce_2pass(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  // Each thread reads n_reads contiguous columns; n_read_blocks threads
  // cover one BN-wide row of the tile.
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;
  constexpr int n_outputs = BN / n_simdgroups;
  constexpr short outer_blocks = 32;
  static_assert(BM == 32, "BM should be equal to 32");

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;

  // Seed accumulators with the reduction identity.
  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // Flatten the thread id, then derive this thread's (column, row) offset
  // inside the BM x BN tile.
  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  IdxT column = BN * gid.x + offset.x;
  // `safe`: all n_reads columns are in bounds.
  bool safe = column + n_reads <= reduction_stride;

  // Decompose the linear grid index into (block_idx, out_idx).
  IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
  IdxT block_idx = full_idx / IdxT(out_size);
  IdxT out_idx = full_idx % IdxT(out_size);
  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
  in += in_idx + column;

  // Accumulate this block's share of rows: start at
  // offset.y + block_idx * BM and step by outer_blocks * BM.
  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
  for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
    row = in + loop.location();

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      // Out-of-bounds columns read the identity so they cannot perturb
      // the accumulated value.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM so each thread writes
  // the partial output to SM and then each simdgroup does BN / n_simdgroups
  // accumulations.
  for (int i = 0; i < n_reads; i++) {
    shared_vals[offset.y * BN + offset.x + i] = totals[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Transposed read: lane id indexes the row, so simd_reduce folds the
  // 32 rows of each column.
  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
  for (int i = 0; i < n_outputs; i++) {
    totals[i] =
        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
  }

  // Write the output.
  if (simd_lane_id == 0) {
    IdxT out_column = BN * gid.x + out_offset.x;
    // Indexed by full_idx (not out_idx): each block writes its own
    // partial row of the intermediate buffer.
    out += full_idx * IdxT(reduction_stride) + out_column;
    if (out_column + n_outputs <= reduction_stride) {
      for (int i = 0; i < n_outputs; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; out_column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
static constant constexpr const uint8_t simd_size
Definition ops.h:22
METAL_FUNC IdxT elem_to_loc(IdxT elem, constant const int *shape, constant const int64_t *strides, int ndim)
Definition utils.h:93
void col_reduce_longcolumn(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
Definition reduce_col.h:97
void col_reduce_looped(const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
Our approach is the following simple looped approach:
Definition reduce_col.h:163
void col_reduce_2pass(const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
Definition reduce_col.h:302
void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
Definition reduce_col.h:4
Definition utils.h:197