MLX
scan.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3template <typename U>
4struct CumSum {
5 static constexpr constant U init = static_cast<U>(0);
6
7 template <typename T>
8 U operator()(U a, T b) {
9 return a + b;
10 }
11
12 U simd_scan(U x) {
13 return simd_prefix_inclusive_sum(x);
14 }
15
16 U simd_exclusive_scan(U x) {
17 return simd_prefix_exclusive_sum(x);
18 }
19};
20
21template <typename U>
22struct CumProd {
23 static constexpr constant U init = static_cast<U>(1.0f);
24
25 template <typename T>
26 U operator()(U a, T b) {
27 return a * b;
28 }
29
30 U simd_scan(U x) {
31 return simd_prefix_inclusive_product(x);
32 }
33
34 U simd_exclusive_scan(U x) {
35 return simd_prefix_exclusive_product(x);
36 }
37};
38
39template <>
40struct CumProd<bool> {
41 static constexpr constant bool init = true;
42
43 template <typename T>
44 bool operator()(bool a, T b) {
45 return a & static_cast<bool>(b);
46 }
47
48 bool simd_scan(bool x) {
49 for (int i = 1; i <= 16; i *= 2) {
50 bool other = simd_shuffle_up(x, i);
51 x &= other;
52 }
53 return x;
54 }
55
56 bool simd_exclusive_scan(bool x) {
57 x = simd_scan(x);
58 return simd_shuffle_and_fill_up(x, init, 1);
59 }
60};
61
62template <typename U>
63struct CumMax {
64 static constexpr constant U init = Limits<U>::min;
65
66 template <typename T>
67 U operator()(U a, T b) {
68 return (a >= b) ? a : b;
69 }
70
71 U simd_scan(U x) {
72 for (int i = 1; i <= 16; i *= 2) {
73 U other = simd_shuffle_up(x, i);
74 x = (x >= other) ? x : other;
75 }
76 return x;
77 }
78
79 U simd_exclusive_scan(U x) {
80 x = simd_scan(x);
81 return simd_shuffle_and_fill_up(x, init, 1);
82 }
83};
84
85template <typename U>
86struct CumMin {
87 static constexpr constant U init = Limits<U>::max;
88
89 template <typename T>
90 U operator()(U a, T b) {
91 return (a <= b) ? a : b;
92 }
93
94 U simd_scan(U x) {
95 for (int i = 1; i <= 16; i *= 2) {
96 U other = simd_shuffle_up(x, i);
97 x = (x <= other) ? x : other;
98 }
99 return x;
100 }
101
102 U simd_exclusive_scan(U x) {
103 x = simd_scan(x);
104 return simd_shuffle_and_fill_up(x, init, 1);
105 }
106};
107
108template <typename T, typename U, int N_READS, bool reverse>
109inline void load_unsafe(U values[N_READS], const device T* input) {
110 if (reverse) {
111 for (int i = 0; i < N_READS; i++) {
112 values[N_READS - i - 1] = input[i];
113 }
114 } else {
115 for (int i = 0; i < N_READS; i++) {
116 values[i] = input[i];
117 }
118 }
119}
120
121template <typename T, typename U, int N_READS, bool reverse>
122inline void load_safe(
123 U values[N_READS],
124 const device T* input,
125 int start,
126 int total,
127 U init) {
128 if (reverse) {
129 for (int i = 0; i < N_READS; i++) {
130 values[N_READS - i - 1] =
131 (start + N_READS - i - 1 < total) ? input[i] : init;
132 }
133 } else {
134 for (int i = 0; i < N_READS; i++) {
135 values[i] = (start + i < total) ? input[i] : init;
136 }
137 }
138}
139
140template <typename U, int N_READS, bool reverse>
141inline void write_unsafe(U values[N_READS], device U* out) {
142 if (reverse) {
143 for (int i = 0; i < N_READS; i++) {
144 out[i] = values[N_READS - i - 1];
145 }
146 } else {
147 for (int i = 0; i < N_READS; i++) {
148 out[i] = values[i];
149 }
150 }
151}
152
153template <typename U, int N_READS, bool reverse>
154inline void write_safe(U values[N_READS], device U* out, int start, int total) {
155 if (reverse) {
156 for (int i = 0; i < N_READS; i++) {
157 if (start + N_READS - i - 1 < total) {
158 out[i] = values[N_READS - i - 1];
159 }
160 }
161 } else {
162 for (int i = 0; i < N_READS; i++) {
163 if (start + i < total) {
164 out[i] = values[i];
165 }
166 }
167 }
168}
169
170template <
171 typename T,
172 typename U,
173 typename Op,
174 int N_READS,
175 bool inclusive,
176 bool reverse>
177[[kernel]] void contiguous_scan(
178 const device T* in [[buffer(0)]],
179 device U* out [[buffer(1)]],
180 const constant size_t& axis_size [[buffer(2)]],
181 uint gid [[thread_position_in_grid]],
182 uint lid [[thread_position_in_threadgroup]],
183 uint lsize [[threads_per_threadgroup]],
184 uint simd_size [[threads_per_simdgroup]],
185 uint simd_lane_id [[thread_index_in_simdgroup]],
186 uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
187 Op op;
188
189 // Position the pointers
190 in += (gid / lsize) * axis_size;
191 out += (gid / lsize) * axis_size;
192
193 // Compute the number of simd_groups
194 uint simd_groups = lsize / simd_size;
195
196 // Allocate memory
197 U prefix = Op::init;
198 U values[N_READS];
199 threadgroup U simdgroup_sums[32];
200
201 // Loop over the reduced axis in blocks of size ceildiv(axis_size,
202 // N_READS*lsize)
203 // Read block
204 // Compute inclusive scan of the block
205 // Compute inclusive scan per thread
206 // Compute exclusive scan of thread sums in simdgroup
207 // Write simdgroup sums in SM
208 // Compute exclusive scan of simdgroup sums
209 // Compute the output by scanning prefix, prev_simdgroup, prev_thread,
210 // value
211 // Write block
212
213 for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
214 // Compute the block offset
215 uint offset = r * lsize * N_READS + lid * N_READS;
216
217 // Read the values
218 if (reverse) {
219 if ((offset + N_READS) < axis_size) {
220 load_unsafe<T, U, N_READS, reverse>(
221 values, in + axis_size - offset - N_READS);
222 } else {
223 load_safe<T, U, N_READS, reverse>(
224 values,
225 in + axis_size - offset - N_READS,
226 offset,
227 axis_size,
228 Op::init);
229 }
230 } else {
231 if ((offset + N_READS) < axis_size) {
232 load_unsafe<T, U, N_READS, reverse>(values, in + offset);
233 } else {
234 load_safe<T, U, N_READS, reverse>(
235 values, in + offset, offset, axis_size, Op::init);
236 }
237 }
238
239 // Compute an inclusive scan per thread
240 for (int i = 1; i < N_READS; i++) {
241 values[i] = op(values[i], values[i - 1]);
242 }
243
244 // Compute exclusive scan of thread sums
245 U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
246
247 // Write simdgroup_sums to SM
248 if (simd_lane_id == simd_size - 1) {
249 simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
250 }
251 threadgroup_barrier(mem_flags::mem_threadgroup);
252
253 // Compute exclusive scan of simdgroup_sums
254 if (simd_group_id == 0) {
255 U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
256 simdgroup_sums[simd_lane_id] = prev_simdgroup;
257 }
258 threadgroup_barrier(mem_flags::mem_threadgroup);
259
260 // Compute the output
261 for (int i = 0; i < N_READS; i++) {
262 values[i] = op(values[i], prefix);
263 values[i] = op(values[i], simdgroup_sums[simd_group_id]);
264 values[i] = op(values[i], prev_thread);
265 }
266
267 // Write the values
268 if (reverse) {
269 if (inclusive) {
270 if ((offset + N_READS) < axis_size) {
271 write_unsafe<U, N_READS, reverse>(
272 values, out + axis_size - offset - N_READS);
273 } else {
274 write_safe<U, N_READS, reverse>(
275 values, out + axis_size - offset - N_READS, offset, axis_size);
276 }
277 } else {
278 if (lid == 0 && offset == 0) {
279 out[axis_size - 1] = Op::init;
280 }
281 if ((offset + N_READS + 1) < axis_size) {
282 write_unsafe<U, N_READS, reverse>(
283 values, out + axis_size - offset - 1 - N_READS);
284 } else {
285 write_safe<U, N_READS, reverse>(
286 values,
287 out + axis_size - offset - 1 - N_READS,
288 offset + 1,
289 axis_size);
290 }
291 }
292 } else {
293 if (inclusive) {
294 if ((offset + N_READS) < axis_size) {
295 write_unsafe<U, N_READS, reverse>(values, out + offset);
296 } else {
297 write_safe<U, N_READS, reverse>(
298 values, out + offset, offset, axis_size);
299 }
300 } else {
301 if (lid == 0 && offset == 0) {
302 out[0] = Op::init;
303 }
304 if ((offset + N_READS + 1) < axis_size) {
305 write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
306 } else {
307 write_safe<U, N_READS, reverse>(
308 values, out + offset + 1, offset + 1, axis_size);
309 }
310 }
311 }
312 threadgroup_barrier(mem_flags::mem_threadgroup);
313
314 // Share the prefix
315 if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
316 simdgroup_sums[0] = values[N_READS - 1];
317 }
318 threadgroup_barrier(mem_flags::mem_threadgroup);
319 prefix = simdgroup_sums[0];
320 }
321}
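
The comment block inside contiguous_scan summarizes how each block's output is assembled from four parts: the running prefix carried across blocks, the exclusive scan of simdgroup totals, the exclusive scan of per-thread totals within each simdgroup, and each thread's own inclusive scan of its N_READS values. The host-side C++ sketch below (not part of scan.h) models that decomposition sequentially so the composition order can be checked against a plain prefix sum; the names n_reads, simd_width and n_threads and the sequential simulation of threads are assumptions made purely for the illustration.

// Illustration only: a sequential model of the hierarchical block scan.
// n_reads, simd_width and n_threads are assumed example parameters.
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

std::vector<float> blocked_inclusive_scan(
    const std::vector<float>& x, int n_reads, int simd_width, int n_threads) {
  std::vector<float> out(x.size());
  float prefix = 0.0f; // carried across blocks, like `prefix` in the kernel
  size_t block = static_cast<size_t>(n_threads) * n_reads;
  for (size_t start = 0; start < x.size(); start += block) {
    // 1. Inclusive scan of each thread's n_reads values
    //    ("Compute an inclusive scan per thread").
    std::vector<float> vals(block, 0.0f);
    std::vector<float> thread_total(n_threads, 0.0f);
    for (int t = 0; t < n_threads; t++) {
      float acc = 0.0f;
      for (int i = 0; i < n_reads; i++) {
        size_t idx = start + static_cast<size_t>(t) * n_reads + i;
        acc += (idx < x.size()) ? x[idx] : 0.0f; // load_safe pads with Op::init
        vals[t * n_reads + i] = acc;
      }
      thread_total[t] = acc;
    }
    // 2. Exclusive scan of thread totals within each simdgroup, plus the
    //    per-simdgroup totals (simd_exclusive_scan and simdgroup_sums).
    int n_simdgroups = n_threads / simd_width;
    std::vector<float> prev_thread(n_threads, 0.0f);
    std::vector<float> simd_total(n_simdgroups, 0.0f);
    for (int t = 0; t < n_threads; t++) {
      prev_thread[t] = (t % simd_width == 0)
          ? 0.0f
          : prev_thread[t - 1] + thread_total[t - 1];
      simd_total[t / simd_width] += thread_total[t];
    }
    // 3. Exclusive scan of simdgroup totals across the threadgroup.
    std::vector<float> prev_simdgroup(n_simdgroups, 0.0f);
    for (int s = 1; s < n_simdgroups; s++) {
      prev_simdgroup[s] = prev_simdgroup[s - 1] + simd_total[s - 1];
    }
    // 4. Combine prefix, prev_simdgroup, prev_thread and the per-thread value
    //    ("Compute the output"), then carry the block total forward.
    for (int t = 0; t < n_threads; t++) {
      for (int i = 0; i < n_reads; i++) {
        size_t idx = start + static_cast<size_t>(t) * n_reads + i;
        if (idx < x.size()) {
          out[idx] = prefix + prev_simdgroup[t / simd_width] + prev_thread[t] +
              vals[t * n_reads + i];
        }
      }
    }
    prefix += prev_simdgroup.back() + simd_total.back();
  }
  return out;
}

int main() {
  std::vector<float> x(1000);
  std::iota(x.begin(), x.end(), 1.0f);
  std::vector<float> expect(x.size());
  std::partial_sum(x.begin(), x.end(), expect.begin());
  auto got = blocked_inclusive_scan(x, /*n_reads=*/4, /*simd_width=*/32,
                                    /*n_threads=*/128);
  for (size_t i = 0; i < x.size(); i++) {
    assert(got[i] == expect[i]);
  }
  return 0;
}

Running the sketch asserts that the hierarchical composition reproduces std::partial_sum for an inclusive forward sum; the kernel carries the same running prefix across blocks through simdgroup_sums[0].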
322
323template <
324 typename T,
325 typename U,
326 typename Op,
327 int N_READS,
328 bool inclusive,
329 bool reverse>
330[[kernel]] void strided_scan(
331 const device T* in [[buffer(0)]],
332 device U* out [[buffer(1)]],
333 const constant size_t& axis_size [[buffer(2)]],
334 const constant size_t& stride [[buffer(3)]],
335 uint2 gid [[threadgroup_position_in_grid]],
336 uint2 lid [[thread_position_in_threadgroup]],
337 uint2 lsize [[threads_per_threadgroup]],
338 uint simd_size [[threads_per_simdgroup]]) {
339 Op op;
340
341 // Allocate memory
342 threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
343 U values[N_READS];
344 U prefix[N_READS];
345 for (int i = 0; i < N_READS; i++) {
346 prefix[i] = Op::init;
347 }
348
349 // Compute offsets
350 int offset = gid.y * axis_size * stride;
351 int global_index_x = gid.x * lsize.y * N_READS;
352
353 for (uint j = 0; j < axis_size; j += simd_size) {
354 // Calculate the indices for the current thread
355 uint index_y = j + lid.y;
356 uint check_index_y = index_y;
357 uint index_x = global_index_x + lid.x * N_READS;
358 if (reverse) {
359 index_y = axis_size - 1 - index_y;
360 }
361
362 // Read in SM
363 if (check_index_y < axis_size && (index_x + N_READS) < stride) {
364 for (int i = 0; i < N_READS; i++) {
365 read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
366 in[offset + index_y * stride + index_x + i];
367 }
368 } else {
369 for (int i = 0; i < N_READS; i++) {
370 if (check_index_y < axis_size && (index_x + i) < stride) {
371 read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
372 in[offset + index_y * stride + index_x + i];
373 } else {
374 read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
375 Op::init;
376 }
377 }
378 }
379 threadgroup_barrier(mem_flags::mem_threadgroup);
380
381 // Read strided into registers
382 for (int i = 0; i < N_READS; i++) {
383 values[i] =
384 read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
385 }
386 // Do we need the following barrier? Shouldn't all simd threads execute
387 // simultaneously?
388 simdgroup_barrier(mem_flags::mem_threadgroup);
389
390 // Perform the scan
391 for (int i = 0; i < N_READS; i++) {
392 values[i] = op.simd_scan(values[i]);
393 values[i] = op(values[i], prefix[i]);
394 prefix[i] = simd_shuffle(values[i], simd_size - 1);
395 }
396
397 // Write to SM
398 for (int i = 0; i < N_READS; i++) {
399 read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
400 values[i];
401 }
402 threadgroup_barrier(mem_flags::mem_threadgroup);
403
404 // Write to device memory
405 if (!inclusive) {
406 if (check_index_y == 0) {
407 if ((index_x + N_READS) < stride) {
408 for (int i = 0; i < N_READS; i++) {
409 out[offset + index_y * stride + index_x + i] = Op::init;
410 }
411 } else {
412 for (int i = 0; i < N_READS; i++) {
413 if ((index_x + i) < stride) {
414 out[offset + index_y * stride + index_x + i] = Op::init;
415 }
416 }
417 }
418 }
419 if (reverse) {
420 index_y -= 1;
421 check_index_y += 1;
422 } else {
423 index_y += 1;
424 check_index_y += 1;
425 }
426 }
427 if (check_index_y < axis_size && (index_x + N_READS) < stride) {
428 for (int i = 0; i < N_READS; i++) {
429 out[offset + index_y * stride + index_x + i] =
430 read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
431 }
432 } else {
433 for (int i = 0; i < N_READS; i++) {
434 if (check_index_y < axis_size && (index_x + i) < stride) {
435 out[offset + index_y * stride + index_x + i] =
436 read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
437 }
438 }
439 }
440 }
441}
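
For reference, the result that strided_scan produces for a single outer slice (the kernel selects a slice via gid.y and tiles the work over threadgroups and simdgroups) can be written as a plain loop over the stride independent columns of the scanned axis. The C++ sketch below is a minimal model under that single-slice assumption; strided_scan_reference, op and init are illustrative names, and the inclusive/exclusive and forward/reverse options mirror the kernel's template flags.

// Illustration only: reference semantics of a strided scan over one slice.
#include <cstddef>
#include <vector>

template <typename U, typename Op>
void strided_scan_reference(
    const std::vector<U>& in,
    std::vector<U>& out,
    size_t axis_size, // length of the scanned axis
    size_t stride,    // distance between consecutive elements along that axis
    U init,           // identity of the scan operator (e.g. 0 for a sum)
    Op op,
    bool inclusive,
    bool reverse) {
  for (size_t k = 0; k < stride; k++) { // each independent column
    U acc = init;
    for (size_t j = 0; j < axis_size; j++) {
      size_t jj = reverse ? axis_size - 1 - j : j;
      size_t idx = jj * stride + k;
      if (inclusive) {
        acc = op(acc, in[idx]); // include the current element, then write
        out[idx] = acc;
      } else {
        out[idx] = acc;         // exclusive: everything before this element
        acc = op(acc, in[idx]);
      }
    }
  }
}

Calling this with op as addition and init = 0 gives a cumulative sum along a non-contiguous axis; the kernel reaches the same result by staging tiles through threadgroup memory so the axis lies along simd lanes, scanning with simd_scan, and carrying a per-column prefix with simd_shuffle.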