MLX
Loading...
Searching...
No Matches
reduce_col.h
Go to the documentation of this file.
// Copyright © 2023-2024 Apple Inc.

// Column reduction for small reduction sizes. Each thread accumulates
// n_reads adjacent output columns over a strided subset of the reduction
// rows; if the threadgroup has more than one row of threads (lsize.y > 1)
// the per-thread partials are combined through threadgroup memory and the
// lid.y == 0 threads write the final values.
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  constexpr int n_reads = 4;
  Op op;
  // NOTE(review): this declaration was dropped by the doc extraction
  // (original line 22). The helper providing next()/location() used below
  // is defined in utils.h — confirm the exact type name against that header.
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;

  U totals[n_reads];
  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // First output column handled by this thread; threads entirely past the
  // row width have nothing to do.
  size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
  if (column >= reduction_stride) {
    return;
  }
  // safe == all n_reads columns are in range, so the inner loops can skip
  // per-element bounds checks.
  bool safe = column + n_reads <= reduction_stride;

  // Map the flat output index back to the corresponding input location.
  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  // Walk the reduction rows with a stride of lsize.y so the threadgroup's
  // rows of threads interleave over the reduction.
  size_t total_rows = non_col_reductions * reduction_size;
  loop.next(lid.y, reduce_shape, reduce_strides);
  for (size_t r = lid.y; r < total_rows; r += lsize.y) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      // Out-of-range columns contribute the identity element.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }
    loop.next(lsize.y, reduce_shape, reduce_strides);
  }

  // Combine the partials of the lsize.y thread rows through shared memory.
  if (lsize.y > 1) {
    // lsize.y should be <= 8
    threadgroup U shared_vals[32 * 8 * n_reads];
    for (int i = 0; i < n_reads; i++) {
      shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (lid.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = shared_vals[lid.x * n_reads + i];
      }
      for (uint j = 1; j < lsize.y; j++) {
        for (int i = 0; i < n_reads; i++) {
          totals[i] =
              op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
                 totals[i]);
        }
      }
    }
  }

  // Only the first row of threads writes out, clipping the tail columns
  // when the block extends past reduction_stride.
  if (lid.y == 0) {
    out += out_idx * reduction_stride + column;
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
95
// Column reduction for very long reduction columns. The reduction rows are
// additionally split across gid.z, so each threadgroup produces a partial
// result; each gid.z block writes its partials to a separate out_size-sized
// slab of the output (a second pass elsewhere must combine the slabs).
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_longcolumn(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  Op op;
  // NOTE(review): this declaration was dropped by the doc extraction
  // (original line 115). The helper providing next()/location() used below
  // is defined in utils.h — confirm the exact type name against that header.
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;

  // Each thread handles one column (lid.x) of one output row.
  size_t out_idx = gid.x + gsize.x * size_t(gid.y);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + lid.x;

  // Accumulate one element per row over the rows assigned to this thread:
  // rows are distributed over both lid.y and gid.z.
  U total = Op::init;
  size_t total_rows = non_col_reductions * reduction_size;
  loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
  for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
       r += lsize.y * gsize.z) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
    total = op(static_cast<U>(*row), total);
    loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
  }

  // Reduce the lsize.y per-thread partials through threadgroup memory and
  // let the first row of threads write the block's partial result.
  threadgroup U shared_vals[32 * 32];
  shared_vals[lid.y * lsize.x + lid.x] = total;
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (lid.y == 0) {
    for (uint i = 1; i < lsize.y; i++) {
      total = op(total, shared_vals[i * lsize.x + lid.x]);
    }
    out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
  }
}
142
// General looped column reduction over a BM x BN tile. Each thread owns
// n_reads contiguous columns and loops over the reduction rows with stride
// BM. The partials are then combined either with a simdgroup reduction
// (BM == 32) or via threadgroup memory (otherwise).
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  // Each thread reads n_reads contiguous columns; n_read_blocks threads
  // cover one row of the BN-wide tile.
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  // NOTE(review): this declaration was dropped by the doc extraction
  // (original line 179). The helper providing next()/location() used below
  // is defined in utils.h — confirm the exact type name against that header.
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // offset.x = this thread's first column within the tile,
  // offset.y = this thread's row within the tile.
  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  size_t column = BN * gid.x + offset.x;
  // safe == all n_reads columns are in range (no per-element bounds check).
  bool safe = column + n_reads <= reduction_stride;

  // Map the flat output index back to the corresponding input location.
  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  // Loop over the reduction rows with stride BM.
  size_t total = non_col_reductions * reduction_size;
  loop.next(offset.y, reduce_shape, reduce_strides);
  for (size_t r = offset.y; r < total; r += BM) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      // Out-of-range columns contribute the identity element.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM so each thread writes
  // the partial output to SM and then each simdgroup does BN / n_simdgroups
  // accumulations.
  if (BM == 32) {
    constexpr int n_outputs = BN / n_simdgroups;
    static_assert(
        BM != 32 || n_outputs == n_reads,
        "The tile should be selected such that n_outputs == n_reads");
    for (int i = 0; i < n_reads; i++) {
      shared_vals[offset.y * BN + offset.x + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
    for (int i = 0; i < n_outputs; i++) {
      totals[i] =
          op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
    }

    // Write the output.
    if (simd_lane_id == 0) {
      size_t out_column = BN * gid.x + out_offset.x;
      out += out_idx * reduction_stride + out_column;
      if (out_column + n_outputs <= reduction_stride) {
        for (int i = 0; i < n_outputs; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; out_column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }

  // Each thread holds n_reads partial results. We write them all out to shared
  // memory and threads with offset.y == 0 aggregate the columns and write the
  // outputs.
  else {
    short x_block = offset.x / n_reads;
    for (int i = 0; i < n_reads; i++) {
      shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (offset.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        for (int j = 1; j < BM; j++) {
          totals[i] =
              op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
        }
      }
    }

    // Write the output.
    if (offset.y == 0) {
      out += out_idx * reduction_stride + column;
      if (safe) {
        for (int i = 0; i < n_reads; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }
}
285
// First pass of a two-pass column reduction for very large reductions.
// The grid carries outer_blocks (= 32) copies of the output space in the
// y/z dimensions; each copy (block_idx) reduces every outer_blocks-th group
// of BM rows and writes its partial results to its own out_size-sized slab
// of `out`. A second pass combines the 32 slabs.
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_2pass(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;
  constexpr int n_outputs = BN / n_simdgroups;
  constexpr short outer_blocks = 32;
  // BM == 32 is required so the simdgroup reduction below covers all rows.
  static_assert(BM == 32, "BM should be equal to 32");

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  // NOTE(review): this declaration was dropped by the doc extraction
  // (original line 315). The helper providing next()/location() used below
  // is defined in utils.h — confirm the exact type name against that header.
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // offset.x = this thread's first column within the tile,
  // offset.y = this thread's row within the tile.
  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  size_t column = BN * gid.x + offset.x;
  bool safe = column + n_reads <= reduction_stride;

  // Split the flat grid index into the partial-block index (which of the
  // outer_blocks row subsets this threadgroup reduces) and the output index.
  size_t full_idx = gid.y + gsize.y * size_t(gid.z);
  size_t block_idx = full_idx / out_size;
  size_t out_idx = full_idx % out_size;
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  // Loop over this block's subset of rows with stride outer_blocks * BM.
  size_t total = non_col_reductions * reduction_size;
  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
  for (size_t r = offset.y + block_idx * BM; r < total;
       r += outer_blocks * BM) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      // Out-of-range columns contribute the identity element.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM so each thread writes
  // the partial output to SM and then each simdgroup does BN / n_simdgroups
  // accumulations.
  for (int i = 0; i < n_reads; i++) {
    shared_vals[offset.y * BN + offset.x + i] = totals[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
  for (int i = 0; i < n_outputs; i++) {
    totals[i] =
        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
  }

  // Write the output. full_idx (not out_idx) indexes the write so each
  // block_idx lands in its own partial-result slab.
  if (simd_lane_id == 0) {
    size_t out_column = BN * gid.x + out_offset.x;
    out += full_idx * reduction_stride + out_column;
    if (out_column + n_outputs <= reduction_stride) {
      for (int i = 0; i < n_outputs; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; out_column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
static constant constexpr const uint8_t simd_size
Definition ops.h:22
METAL_FUNC stride_t elem_to_loc(uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
Definition utils.h:87
Op op
Definition binary.h:129
void col_reduce_2pass(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
Definition reduce_col.h:287
void col_reduce_looped(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
Our approach is the following simple looped approach:
Definition reduce_col.h:155
void col_reduce_longcolumn(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
Definition reduce_col.h:97
void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
Definition reduce_col.h:4
Definition utils.h:197
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:202
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:229