MLX
Loading...
Searching...
No Matches
scan.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
// Injects `simd_scan(T)` into a scan-operator struct.
// - Types narrower than 8 bytes forward to the struct's `simd_scan_impl`.
// - 8-byte types run a log-step loop: shift the running value up by
//   1, 2, 4, 8, 16 lanes (lanes with no source receive `init`) and combine
//   with the struct's `operator()`. The shifts assume a 32-wide simdgroup.
// (No comments inside the macro: a `//` on a continued line would swallow
// the following lines.)
#define DEFINE_SIMD_SCAN()                                                 \
  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>    \
  T simd_scan(T x) {                                                       \
    return simd_scan_impl(x);                                              \
  }                                                                        \
                                                                           \
  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true>   \
  T simd_scan(T x) {                                                       \
    for (int delta = 1; delta < 32; delta <<= 1) {                         \
      x = operator()(x, simd_shuffle_and_fill_up(x, init, delta));         \
    }                                                                      \
    return x;                                                              \
  }

// Injects `simd_exclusive_scan(T)` into a scan-operator struct.
// - Types narrower than 8 bytes forward to `simd_exclusive_scan_impl`.
// - 8-byte types take the inclusive `simd_scan` result and move it up one
//   lane, so lane 0 receives `init`.
#define DEFINE_SIMD_EXCLUSIVE_SCAN()                                       \
  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>    \
  T simd_exclusive_scan(T x) {                                             \
    return simd_exclusive_scan_impl(x);                                    \
  }                                                                        \
                                                                           \
  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true>   \
  T simd_exclusive_scan(T x) {                                             \
    return simd_shuffle_and_fill_up(simd_scan(x), init, 1);                \
  }

// Cumulative-sum scan operator: identity 0, combine `a + b`.
// simd_scan / simd_exclusive_scan (called by the scan kernels) are provided
// by the DEFINE_SIMD_* macros. Those forward types narrower than 8 bytes to
// the simd_prefix_*_sum intrinsics below and use a shuffle loop for 8-byte
// types.
template <typename U>
struct CumSum {
  DEFINE_SIMD_SCAN()
  DEFINE_SIMD_EXCLUSIVE_SCAN()

  static constexpr constant U init = static_cast<U>(0);

  template <typename T>
  U operator()(U a, T b) {
    return a + b;
  }

  U simd_scan_impl(U x) {
    return simd_prefix_inclusive_sum(x);
  }

  U simd_exclusive_scan_impl(U x) {
    return simd_prefix_exclusive_sum(x);
  }
};

// Cumulative-product scan operator: identity 1, combine `a * b`.
// simd_scan / simd_exclusive_scan (called by the scan kernels) are provided
// by the DEFINE_SIMD_* macros. Those forward types narrower than 8 bytes to
// the simd_prefix_*_product intrinsics below and use a shuffle loop for
// 8-byte types.
template <typename U>
struct CumProd {
  DEFINE_SIMD_SCAN()
  DEFINE_SIMD_EXCLUSIVE_SCAN()

  static constexpr constant U init = static_cast<U>(1.0f);

  template <typename T>
  U operator()(U a, T b) {
    return a * b;
  }

  U simd_scan_impl(U x) {
    return simd_prefix_inclusive_product(x);
  }

  U simd_exclusive_scan_impl(U x) {
    return simd_prefix_exclusive_product(x);
  }
};

// Boolean specialization of CumProd: the product of bools is a logical AND
// with identity `true`. The scans are written out by hand here instead of
// coming from the DEFINE_SIMD_* macros.
template <>
struct CumProd<bool> {
  static constexpr constant bool init = true;

  template <typename T>
  bool operator()(bool a, T b) {
    return a & static_cast<bool>(b);
  }

  // Inclusive AND-scan over a 32-wide simdgroup using shifts of
  // 1, 2, 4, 8, 16 lanes; lanes with no source lane receive `init`.
  bool simd_scan(bool x) {
    for (int delta = 1; delta < 32; delta <<= 1) {
      bool below = simd_shuffle_and_fill_up(x, init, delta);
      x = x && below;
    }
    return x;
  }

  // Exclusive AND-scan: the inclusive result moved up one lane, with lane 0
  // receiving `init`.
  bool simd_exclusive_scan(bool x) {
    return simd_shuffle_and_fill_up(simd_scan(x), init, 1);
  }
};

// Running-maximum scan operator: identity Limits<U>::min, and the combine
// keeps `a` when `a >= b`, otherwise `b`.
// NOTE: with floating point, a NaN `a` fails `a >= b`, so `b` is returned.
template <typename U>
struct CumMax {
  static constexpr constant U init = Limits<U>::min;

  template <typename T>
  U operator()(U a, T b) {
    return (a >= b) ? a : b;
  }

  // Inclusive max-scan over a 32-wide simdgroup using shifts of
  // 1, 2, 4, 8, 16 lanes; lanes with no source lane receive `init`.
  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_and_fill_up(x, init, i);
      x = (x >= other) ? x : other;
    }
    return x;
  }

  // Exclusive max-scan: the inclusive result moved up one lane, with lane 0
  // receiving `init`.
  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

// Running-minimum scan operator: identity Limits<U>::max, and the combine
// keeps `a` when `a <= b`, otherwise `b`.
// NOTE: with floating point, a NaN `a` fails `a <= b`, so `b` is returned.
template <typename U>
struct CumMin {
  static constexpr constant U init = Limits<U>::max;

  template <typename T>
  U operator()(U a, T b) {
    return (a <= b) ? a : b;
  }

  // Inclusive min-scan over a 32-wide simdgroup using shifts of
  // 1, 2, 4, 8, 16 lanes; lanes with no source lane receive `init`.
  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_and_fill_up(x, init, i);
      x = (x <= other) ? x : other;
    }
    return x;
  }

  // Exclusive min-scan: the inclusive result moved up one lane, with lane 0
  // receiving `init`.
  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

// Copies N_READS consecutive elements of `input` into `values`, converting
// T -> U. With `reverse` the order is mirrored: values[N_READS - 1] receives
// input[0]. No bounds checks; the caller guarantees all N_READS are valid.
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
  for (int src = 0; src < N_READS; src++) {
    const int dst = reverse ? (N_READS - 1 - src) : src;
    values[dst] = input[src];
  }
}

// Bounds-checked variant of load_unsafe. values[k] stands for logical
// position `start + k`. Slots whose position is >= `total` receive `init`,
// and their input element is not read.
template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(
    U values[N_READS],
    const device T* input,
    int start,
    int total,
    U init) {
  for (int src = 0; src < N_READS; src++) {
    const int dst = reverse ? (N_READS - 1 - src) : src;
    values[dst] = (start + dst < total) ? input[src] : init;
  }
}

// Stores all N_READS entries of `values` to consecutive `out` slots. With
// `reverse` the order is mirrored: out[0] receives values[N_READS - 1].
// No bounds checks.
template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U* out) {
  for (int dst = 0; dst < N_READS; dst++) {
    out[dst] = values[reverse ? (N_READS - 1 - dst) : dst];
  }
}

// Bounds-checked variant of write_unsafe. values[k] stands for logical
// position `start + k` and is stored only when that position is < `total`.
template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
  for (int dst = 0; dst < N_READS; dst++) {
    const int src = reverse ? (N_READS - 1 - dst) : dst;
    if (start + src < total) {
      out[dst] = values[src];
    }
  }
}

// Scans a contiguous axis of length `axis_size`. Each threadgroup owns one
// row (selected by gid.y / gid.z) and walks it in blocks of
// N_READS * lsize.x elements:
//   1. each thread loads N_READS consecutive values (mirrored when reverse),
//   2. computes an inclusive scan of its own values,
//   3. computes an exclusive simdgroup scan of the per-thread totals,
//   4. computes an exclusive scan of the per-simdgroup totals through
//      threadgroup memory,
//   5. combines the running prefix, simdgroup offset and thread offset into
//      every value,
//   6. writes the block (shifted by one element for exclusive scans),
//   7. broadcasts the block's last value as the prefix for the next block.
template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void contiguous_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int simd_size = 32;
  Op op;

  // Position the pointers at the start of this threadgroup's row
  size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;
  in += offset;
  out += offset;

  // Compute the number of simd_groups
  uint simd_groups = lsize.x / simd_size;

  // Allocate memory
  U prefix = Op::init;
  U values[N_READS];
  threadgroup U simdgroup_sums[32];

  // Number of blocks along the axis (loop invariant)
  size_t n_blocks = ceildiv(axis_size, N_READS * lsize.x);

  for (uint r = 0; r < n_blocks; r++) {
    // Position along the row of this thread's first element in this block
    uint thread_offset = r * lsize.x * N_READS + lid.x * N_READS;

    // Read the values
    if (reverse) {
      if ((thread_offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(
            values, in + axis_size - thread_offset - N_READS);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values,
            in + axis_size - thread_offset - N_READS,
            thread_offset,
            axis_size,
            Op::init);
      }
    } else {
      if ((thread_offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(values, in + thread_offset);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values, in + thread_offset, thread_offset, axis_size, Op::init);
      }
    }

    // Compute an inclusive scan per thread
    for (int i = 1; i < N_READS; i++) {
      values[i] = op(values[i], values[i - 1]);
    }

    // Compute exclusive scan of thread sums
    U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);

    // Every simdgroup must have read the previous block's prefix out of
    // simdgroup_sums[0] (end of the loop body) before any lane overwrites
    // that slot below. Without this barrier a fast simdgroup can clobber it.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write simdgroup_sums to SM
    if (simd_lane_id == simd_size - 1) {
      simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute exclusive scan of simdgroup_sums. Entries past simd_groups
    // are stale, but an exclusive scan never lets them reach a used lane.
    if (simd_group_id == 0) {
      U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
      simdgroup_sums[simd_lane_id] = prev_simdgroup;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute the output
    for (int i = 0; i < N_READS; i++) {
      values[i] = op(values[i], prefix);
      values[i] = op(values[i], simdgroup_sums[simd_group_id]);
      values[i] = op(values[i], prev_thread);
    }

    // Write the values. Exclusive scans are shifted by one element, and
    // the first output element receives Op::init.
    if (reverse) {
      if (inclusive) {
        if ((thread_offset + N_READS) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - thread_offset - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values,
              out + axis_size - thread_offset - N_READS,
              thread_offset,
              axis_size);
        }
      } else {
        if (lid.x == 0 && thread_offset == 0) {
          out[axis_size - 1] = Op::init;
        }
        if ((thread_offset + N_READS + 1) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - thread_offset - 1 - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values,
              out + axis_size - thread_offset - 1 - N_READS,
              thread_offset + 1,
              axis_size);
        }
      }
    } else {
      if (inclusive) {
        if ((thread_offset + N_READS) < axis_size) {
          write_unsafe<U, N_READS, reverse>(values, out + thread_offset);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + thread_offset, thread_offset, axis_size);
        }
      } else {
        if (lid.x == 0 && thread_offset == 0) {
          out[0] = Op::init;
        }
        if ((thread_offset + N_READS + 1) < axis_size) {
          write_unsafe<U, N_READS, reverse>(values, out + thread_offset + 1);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + thread_offset + 1, thread_offset + 1, axis_size);
        }
      }
    }
    // All reads of simdgroup_sums[simd_group_id] above must finish before
    // slot 0 is reused for the prefix.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Share the prefix
    if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
      simdgroup_sums[0] = values[N_READS - 1];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    prefix = simdgroup_sums[0];
  }
}

// Scans an axis whose consecutive elements are `stride` apart. Each
// threadgroup owns a BN-wide column strip of one outer slice and walks the
// axis BM rows at a time:
//   * the tile is loaded cooperatively into threadgroup memory (rows
//     mirrored along the axis when reverse; out-of-range cells = Op::init),
//   * each simdgroup scans n_scans columns, one row per lane, and carries
//     a per-column running prefix from tile to tile,
//   * the scanned tile is written back, shifted by one row for exclusive
//     scans.
template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void strided_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    const constant size_t& stride [[buffer(3)]],
    const constant size_t& stride_blocks [[buffer(4)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int simd_size = 32;
  constexpr int BM = 32; // tile rows (along the scanned axis)
  constexpr int BN = 32; // tile columns
  // Row pitch padded by 16 bytes' worth of U; presumably to reduce
  // threadgroup-memory bank conflicts on column reads (TODO confirm).
  constexpr int BN_pad = 32 + 16 / sizeof(U);
  constexpr int n_simds = BN / N_READS;
  constexpr int n_scans = BN / n_simds;
  Op op;

  threadgroup U tile[BM * BN_pad];
  U col_vals[n_scans];
  U col_prefix[n_scans];
  for (int k = 0; k < n_scans; k++) {
    col_prefix[k] = Op::init;
  }

  // Locate this threadgroup's outer slice and column strip
  size_t full_gid = gid.y + gsize.y * size_t(gid.z);
  size_t slice_start = full_gid / stride_blocks * axis_size * stride;
  size_t strip_col = full_gid % stride_blocks * BN;

  // Tile cell this thread loads/stores (N_READS consecutive columns) ...
  uint load_row = (lid.x * N_READS) / BN;
  uint load_col = (lid.x * N_READS) % BN;
  // ... and the cells it scans: one row per lane, n_scans columns per
  // simdgroup.
  uint scan_row = simd_lane_id;
  uint scan_col = simd_group_id * n_scans;

  // Columns that remain in this strip before running past `stride`
  uint cols_left = stride - strip_col;
  in += slice_start + strip_col + load_col;
  out += slice_start + strip_col + load_col;
  threadgroup U* tile_io = tile + load_row * BN_pad + load_col;
  threadgroup U* tile_scan = tile + scan_row * BN_pad + scan_col;

  for (uint j = 0; j < axis_size; j += BM) {
    // `row` is the position in scan order (used for bounds checks);
    // `mem_row` is the matching row in memory, mirrored when reverse.
    uint row = j + load_row;
    uint mem_row = reverse ? axis_size - 1 - row : row;

    // Stage the tile in threadgroup memory
    if (row < axis_size && (load_col + N_READS) < cols_left) {
      for (int k = 0; k < N_READS; k++) {
        tile_io[k] = in[mem_row * stride + k];
      }
    } else {
      for (int k = 0; k < N_READS; k++) {
        if (row < axis_size && (load_col + k) < cols_left) {
          tile_io[k] = in[mem_row * stride + k];
        } else {
          tile_io[k] = Op::init;
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Pull this lane's row of the simdgroup's columns into registers
    for (int k = 0; k < n_scans; k++) {
      col_vals[k] = tile_scan[k];
    }
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // Scan each column down the lanes, add the carried prefix, and take the
    // last lane's value as the prefix for the next tile.
    for (int k = 0; k < n_scans; k++) {
      col_vals[k] = op(op.simd_scan(col_vals[k]), col_prefix[k]);
      col_prefix[k] = simd_shuffle(col_vals[k], simd_size - 1);
    }

    // Put the scanned values back into the tile
    for (int k = 0; k < n_scans; k++) {
      tile_scan[k] = col_vals[k];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Exclusive scans: the first row in scan order receives Op::init, and
    // every scanned row moves one row further along the scan direction.
    if (!inclusive) {
      if (row == 0) {
        if ((load_col + N_READS) < cols_left) {
          for (int k = 0; k < N_READS; k++) {
            out[mem_row * stride + k] = Op::init;
          }
        } else {
          for (int k = 0; k < N_READS; k++) {
            if ((load_col + k) < cols_left) {
              out[mem_row * stride + k] = Op::init;
            }
          }
        }
      }
      mem_row = reverse ? mem_row - 1 : mem_row + 1;
      row += 1;
    }

    // Store the scanned tile to device memory
    if (row < axis_size && (load_col + N_READS) < cols_left) {
      for (int k = 0; k < N_READS; k++) {
        out[mem_row * stride + k] = tile_io[k];
      }
    } else {
      for (int k = 0; k < N_READS; k++) {
        if (row < axis_size && (load_col + k) < cols_left) {
          out[mem_row * stride + k] = tile_io[k];
        }
      }
    }
  }
}
static constant constexpr const uint8_t simd_size
Definition ops.h:22
uint64_t simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta)
Definition utils.h:342
uint64_t simd_shuffle(uint64_t data, uint16_t lane)
Definition utils.h:367
T ceildiv(T N, U M)
Compute ceil((float)N/(float)M)
Definition utils.h:272
Op op
Definition binary.h:129
#define DEFINE_SIMD_SCAN()
Definition scan.h:5
#define DEFINE_SIMD_EXCLUSIVE_SCAN()
Definition scan.h:19
void contiguous_scan(const device T *in, device U *out, const constant size_t &axis_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_group_id)
Definition scan.h:211
void strided_scan(const device T *in, device U *out, const constant size_t &axis_size, const constant size_t &stride, const constant size_t &stride_blocks, uint3 gid, uint3 gsize, uint3 lid, uint simd_lane_id, uint simd_group_id)
Definition scan.h:366
void write_unsafe(U values[N_READS], device U *out)
Definition scan.h:175
void load_unsafe(U values[N_READS], const device T *input)
Definition scan.h:143
void write_safe(U values[N_READS], device U *out, int start, int total)
Definition scan.h:188
void load_safe(U values[N_READS], const device T *input, int start, int total, U init)
Definition scan.h:156
CumMax
Definition scan.h:97
static constexpr constant U init
Definition scan.h:98
U operator()(U a, T b)
Definition scan.h:101
U simd_scan(U x)
Definition scan.h:105
U simd_exclusive_scan(U x)
Definition scan.h:113
CumMin
Definition scan.h:120
U simd_scan(U x)
Definition scan.h:128
U simd_exclusive_scan(U x)
Definition scan.h:136
static constexpr constant U init
Definition scan.h:121
U operator()(U a, T b)
Definition scan.h:124
bool simd_exclusive_scan(bool x)
Definition scan.h:90
bool simd_scan(bool x)
Definition scan.h:82
bool operator()(bool a, T b)
Definition scan.h:78
CumProd
Definition scan.h:53
CumSum
Definition scan.h:32
Definition utils.h:17