MLX
reduce_col.h
// Copyright © 2023-2024 Apple Inc.

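// Column reduction kernel for the case where the reduction is small. Based on
// the row and column sizes it picks one of three strategies: a fully serial
// per-thread reduction over a whole stripe of columns (Case 1), a per-thread
// reduction over N_READS-wide column chunks (Case 2), or a cooperative
// threadgroup reduction through shared memory (Case 3).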
template <
    typename T,
    typename U,
    typename Op,
    int NDIMS,
    int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[thread_position_in_grid]],
    uint3 tsize [[threads_per_grid]]) {
  Op op;
  // `loop` walks the (possibly multi-dimensional) non-contiguous reduction,
  // providing each reduction element's input offset via location() / next().
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;

  // Case 1: Small row, small column.
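  // Each thread serially reduces an entire stripe of up to 31 output columns,
  // reading N_READS elements at a time and handling the ragged tail with a
  // separate loop.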
  if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
    U totals[31];
    for (int i = 0; i < 31; i++) {
      totals[i] = Op::init;
    }

    short stride = reduction_stride;
    short size = reduction_size;
    short blocks = stride / N_READS;
    short extra = stride - blocks * N_READS;

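    // Each thread produces one row of outputs: elem_to_loc maps the flat
    // output index to the strided input offset where that row begins.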
    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
    in += elem_to_loc(out_idx, shape, strides, ndim);

    for (uint r = 0; r < non_col_reductions; r++) {
      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

      for (short i = 0; i < size; i++) {
        for (short j = 0; j < blocks; j++) {
          for (short k = 0; k < N_READS; k++) {
            totals[j * N_READS + k] =
                op(totals[j * N_READS + k],
                   static_cast<U>(row[i * stride + j * N_READS + k]));
          }
        }
        for (short k = 0; k < extra; k++) {
          totals[blocks * N_READS + k] =
              op(totals[blocks * N_READS + k],
                 static_cast<U>(row[i * stride + blocks * N_READS + k]));
        }
      }

      loop.next(reduce_shape, reduce_strides);
    }
    out += out_idx * reduction_stride;
    for (short j = 0; j < stride; j++) {
      out[j] = totals[j];
    }
  }

  // Case 2: Long row, small column.
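  // The row is too long for one thread to keep every column, so threads split
  // it into N_READS-wide column chunks; `safe` marks chunks that lie fully
  // inside the row.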
  else if (reduction_size * non_col_reductions < 32) {
    U totals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      totals[i] = Op::init;
    }

    short size = reduction_size;
    size_t offset = size_t(tid.x) * N_READS;
    bool safe = offset + N_READS <= reduction_stride;
    short extra = reduction_stride - offset;

    size_t out_idx = tid.y + tsize.z * size_t(tid.z);
    in += elem_to_loc(out_idx, shape, strides, ndim) + offset;

    for (uint r = 0; r < non_col_reductions; r++) {
      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

      if (safe) {
        for (short i = 0; i < size; i++) {
          for (short j = 0; j < N_READS; j++) {
            totals[j] =
                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
          }
        }
      } else {
        for (short i = 0; i < size; i++) {
          for (short j = 0; j < extra; j++) {
            totals[j] =
                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
          }
        }
      }

      loop.next(reduce_shape, reduce_strides);
    }
    out += out_idx * reduction_stride + offset;
    if (safe) {
      for (short i = 0; i < N_READS; i++) {
        out[i] = totals[i];
      }
    } else {
      for (short i = 0; i < extra; i++) {
        out[i] = totals[i];
      }
    }
  }

  // Case 3: Long row, medium column.
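  // The whole threadgroup cooperates: threads form a tile.x x 32 tile in
  // which each thread covers N_READS contiguous columns, and the partial
  // results are combined through shared memory with a simd reduction.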
  else {
    threadgroup U shared_vals[1024];
    U totals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      totals[i] = Op::init;
    }

    short stride = reduction_stride;
    short lid = simd_group_id * simd_size + simd_lane_id;
    short2 tile((stride + N_READS - 1) / N_READS, 32);
    short2 offset((lid % tile.x) * N_READS, lid / tile.x);
    short sm_stride = tile.x * N_READS;
    bool safe = offset.x + N_READS <= stride;

    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
    in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;

    // Read cooperatively and contiguously, and aggregate the partial results.
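    // Each thread starts at row offset.y of its column chunk and steps down
    // by simd_size rows until the column axis and the non-column reductions
    // are exhausted.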
    size_t total = non_col_reductions * reduction_size;
    loop.next(offset.y, reduce_shape, reduce_strides);
    for (size_t r = offset.y; r < total; r += simd_size) {
      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

      if (safe) {
        for (int i = 0; i < N_READS; i++) {
          totals[i] = op(static_cast<U>(row[i]), totals[i]);
        }
      } else {
        U vals[N_READS];
        for (int i = 0; i < N_READS; i++) {
          vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
        }
        for (int i = 0; i < N_READS; i++) {
          totals[i] = op(vals[i], totals[i]);
        }
      }

      loop.next(simd_size, reduce_shape, reduce_strides);
    }

    // Each thread holds N_READS partial results, but the simdgroups are not
    // aligned to perform the reduction across the simdgroup, so we write our
    // results to shared memory and read them back arranged by simdgroup.
    for (int i = 0; i < N_READS; i++) {
      shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int i = 0; i < N_READS; i++) {
      totals[i] = op.simd_reduce(
          shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
    }

    // Write the output.
    if (simd_lane_id == 0) {
      short column = simd_group_id * N_READS;
      out += out_idx * reduction_stride + column;
      if (column + N_READS <= stride) {
        for (int i = 0; i < N_READS; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; column + i < stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }
}

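/**
 * Our approach is the following simple looped approach:
 *  1. Each thread keeps running totals for n_reads output columns.
 *  2. The threadgroup loads a BM x BN tile of the input and accumulates it
 *     into the running totals, moving down BM rows at a time until the
 *     column axis and the non-column reductions are exhausted.
 *  3. If BM == 32, transpose through shared memory and simd-reduce the
 *     running totals; otherwise write the partials to shared memory and let
 *     the threads with offset.y == 0 accumulate the columns.
 *  4. Write the results to the output.
 */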
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 4;
  constexpr short tgp_size = n_simdgroups * simd_size;
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;
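  // With tgp_size threads covering a BM x BN tile, each thread loads n_reads
  // contiguous elements per row, so n_read_blocks chunks of n_reads columns
  // make up one BN-wide row of the tile.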

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  looped_elem_to_loc<NDIMS> loop;
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

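  // Map the flat thread index to a (column chunk, row) position in the tile
  // and from there to the absolute output column, BN * gid.x + offset.x.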
  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  size_t column = BN * gid.x + offset.x;
  bool safe = column + n_reads <= reduction_stride;

  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

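  // The threadgroup walks the reduction BM rows at a time; `loop` tracks the
  // current position in the non-contiguous reduction dimensions.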
  size_t total = non_col_reductions * reduction_size;
  loop.next(offset.y, reduce_shape, reduce_strides);
  for (size_t r = offset.y; r < total; r += BM) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(BM, reduce_shape, reduce_strides);
  }

  // When BM == 32 we can use a simd reduction to accumulate across BM: each
  // thread writes its partial output to shared memory, and then each
  // simdgroup does BN / n_simdgroups accumulations.
  if (BM == 32) {
    constexpr int n_outputs = BN / n_simdgroups;
    static_assert(
        BM != 32 || n_outputs == n_reads,
        "The tile should be selected such that n_outputs == n_reads");
    for (int i = 0; i < n_reads; i++) {
      shared_vals[offset.y * BN + offset.x + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
    for (int i = 0; i < n_outputs; i++) {
      totals[i] =
          op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
    }

    // Write the output.
    if (simd_lane_id == 0) {
      size_t out_column = BN * gid.x + out_offset.x;
      out += out_idx * reduction_stride + out_column;
      if (out_column + n_outputs <= reduction_stride) {
        for (int i = 0; i < n_outputs; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; out_column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }

  // Each thread holds n_reads partial results. We write them all out to
  // shared memory and threads with offset.y == 0 aggregate the columns and
  // write the outputs.
  else {
    short x_block = offset.x / n_reads;
    for (int i = 0; i < n_reads; i++) {
      shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (offset.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        for (int j = 1; j < BM; j++) {
          totals[i] =
              op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
        }
      }
    }

    // Write the output.
    if (offset.y == 0) {
      out += out_idx * reduction_stride + column;
      if (safe) {
        for (int i = 0; i < n_reads; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }
}