MLX
Loading...
Searching...
No Matches
sort.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#define MLX_MTL_CONST static constant constexpr const
4#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
5
6using namespace metal;
7
8// Based on GPU merge sort algorithm at
9// https://github.com/NVIDIA/cccl/tree/main/cub/cub
10
12// Thread-level sort
14
// Exchanges the values held by two references in the thread address space.
// Goes through a temporary copy, so T only needs to be copyable and
// assignable.
template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
  const T previous_b = b;
  b = a;
  a = previous_b;
}
21
// Strict "less than" comparison functor; the default CompareOp of the sort
// kernels in this file.
//
// `init` is Limits<T>::max. Other code in this file writes CompareOp::init
// into slots that lie past the end of the sorted axis.
template <typename T>
struct LessThan {
  static constexpr constant T init = Limits<T>::max;

  // Returns true when `lhs` orders strictly before `rhs`.
  METAL_FUNC bool operator()(T lhs, T rhs) {
    const bool lhs_first = lhs < rhs;
    return lhs_first;
  }
};
30
// Sorts the N_PER_THREAD values held by a single thread, in place, using an
// odd-even transposition network: N_PER_THREAD passes, where pass `i`
// compare-exchanges the adjacent pairs that start at index (i & 1).
//
// Template parameters:
//   val_t / idx_t  value and index element types
//   ARG_SORT       when true, `idxs` is permuted in lock-step with `vals`;
//                  when false, `idxs` is left untouched, because callers in
//                  this file only load it when ARG_SORT is set
//   N_PER_THREAD   number of elements per thread (compile-time)
//   CompareOp      functor; op(x, y) is true when x must come before y
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short N_PER_THREAD,
    typename CompareOp>
struct ThreadSort {
  static METAL_FUNC void sort(
      thread val_t (&vals)[N_PER_THREAD],
      thread idx_t (&idxs)[N_PER_THREAD]) {
    CompareOp op;

    // Both trip counts are compile-time constants, so request full unrolling.
    MLX_MTL_LOOP_UNROLL
    for (short i = 0; i < N_PER_THREAD; ++i) {
      MLX_MTL_LOOP_UNROLL
      for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
        if (op(vals[j + 1], vals[j])) {
          thread_swap(vals[j + 1], vals[j]);
          // Only move indices when they carry meaning; otherwise `idxs` may
          // hold indeterminate values that must not be read.
          if (ARG_SORT) {
            thread_swap(idxs[j + 1], idxs[j]);
          }
        }
      }
    }
  }
};
55
57// Threadgroup-level sort
59
60template <
61 typename val_t,
62 typename idx_t,
63 bool ARG_SORT,
64 short BLOCK_THREADS,
65 short N_PER_THREAD,
66 typename CompareOp>
// Binary search for the split point used to merge two runs held in
// threadgroup memory.
//
// As / Bs     start of the two runs (the caller documents both as sorted)
// A_sz / B_sz number of elements in each run
// sort_md     how many merged outputs precede the requested split
//
// Returns `p`, the number of elements to take from As; the caller takes the
// remaining (sort_md - p) elements from Bs.
static METAL_FUNC int merge_partition(
    const threadgroup val_t* As,
    const threadgroup val_t* Bs,
    short A_sz,
    short B_sz,
    short sort_md) {
  CompareOp op;

  // Search window for p: at least (sort_md - B_sz) elements must come from
  // As, and at most min(sort_md, A_sz) can.
  short lo = max(0, sort_md - B_sz);
  short hi = min(sort_md, A_sz);

  for (; lo < hi;) {
    short mid = lo + (hi - lo) / 2;
    auto a_val = As[mid];
    auto b_val = Bs[sort_md - 1 - mid];

    if (!op(b_val, a_val)) {
      // As[mid] does not come after its counterpart in Bs: split is right
      // of mid.
      lo = mid + 1;
      continue;
    }
    hi = mid;
  }

  return hi;
}
95
// Sequentially merges the heads of two runs into this thread's output
// arrays, producing exactly N_PER_THREAD elements.
//
// As / Bs          first unread element of each run
// As_idx / Bs_idx  matching index runs; either may be nullptr, in which case
//                  `idxs` is not written
// A_sz / B_sz      number of elements still available in each run
// vals / idxs      per-thread outputs
//
// A run that is exhausted is never dereferenced: its head is treated as
// CompareOp::init instead. If both runs run out before N_PER_THREAD outputs
// have been produced, the remaining outputs are CompareOp::init with index 0.
static METAL_FUNC void merge_step(
    const threadgroup val_t* As,
    const threadgroup val_t* Bs,
    const threadgroup idx_t* As_idx,
    const threadgroup idx_t* Bs_idx,
    short A_sz,
    short B_sz,
    thread val_t (&vals)[N_PER_THREAD],
    thread idx_t (&idxs)[N_PER_THREAD]) {
  CompareOp op;
  short a_idx = 0;
  short b_idx = 0;

  // Index runs are optional (the value-only sort passes nullptr).
  const bool has_idxs = (As_idx != nullptr) && (Bs_idx != nullptr);

  for (int i = 0; i < N_PER_THREAD; ++i) {
    const bool a_left = a_idx < A_sz;
    const bool b_left = b_idx < B_sz;

    // Guard the loads so that an exhausted run is not read past its end.
    val_t a = a_left ? As[a_idx] : val_t(CompareOp::init);
    val_t b = b_left ? Bs[b_idx] : val_t(CompareOp::init);

    // Take from Bs only if it still has elements and either As is exhausted
    // or b strictly precedes a (ties go to As, keeping the merge stable).
    bool pred = b_left && (!a_left || op(b, a));

    vals[i] = pred ? b : a;
    if (has_idxs) {
      if (pred) {
        idxs[i] = Bs_idx[b_idx];
      } else if (a_left) {
        idxs[i] = As_idx[a_idx];
      } else {
        idxs[i] = idx_t(0);
      }
    }

    b_idx += short(pred);
    a_idx += short(!pred);
  }
}
121
// Sorts the N_PER_THREAD * BLOCK_THREADS elements in `tgp_vals` (and, when
// ARG_SORT is set, the parallel `tgp_idxs`) in place.
//
// Each thread first sorts its own N_PER_THREAD-element slice, then the
// threadgroup repeatedly merges neighbouring sorted runs, doubling the number
// of cooperating threads on every round until one run spans the whole block.
//
// Must be called by every thread of the threadgroup, since it contains
// threadgroup barriers.
static METAL_FUNC void sort(
    threadgroup val_t* tgp_vals [[threadgroup(0)]],
    threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
    int size_sorted_axis,
    uint3 lid [[thread_position_in_threadgroup]]) {
  // First threadgroup slot owned by this thread
  int idx = lid.x * N_PER_THREAD;

  // Copy this thread's slice into thread-local arrays
  thread val_t thread_vals[N_PER_THREAD];
  thread idx_t thread_idxs[N_PER_THREAD];
  for (int k = 0; k < N_PER_THREAD; ++k) {
    thread_vals[k] = tgp_vals[idx + k];
    if (ARG_SORT) {
      thread_idxs[k] = tgp_idxs[idx + k];
    }
  }

  // Sort the slice locally; skipped when the slice starts at or beyond
  // size_sorted_axis
  if (idx < size_sorted_axis) {
    thread_sort_t::sort(thread_vals, thread_idxs);
  }

  // Merge rounds: `merge_threads` threads cooperate on one pair of runs
  int merge_threads = 2;
  while (merge_threads <= BLOCK_THREADS) {
    // Publish the current per-thread results to threadgroup memory. The
    // leading barrier keeps these writes from overtaking reads that other
    // threads may still be doing from the previous round.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int k = 0; k < N_PER_THREAD; ++k) {
      tgp_vals[idx + k] = thread_vals[k];
      if (ARG_SORT) {
        tgp_idxs[idx + k] = thread_idxs[k];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Which pair of runs this thread works on, and its position within it
    int merge_group = lid.x / merge_threads;
    int merge_lane = lid.x % merge_threads;

    int sort_sz = N_PER_THREAD * merge_threads;
    int sort_st = sort_sz * merge_group;
    int half_sz = sort_sz / 2;

    // Run A is tgp_vals[A_st : A_st + A_sz), run B follows immediately after;
    // both are already sorted
    int A_st = sort_st;
    int B_st = sort_st + half_sz;
    int A_sz = half_sz;
    int B_sz = sort_sz - half_sz;

    const threadgroup val_t* As = tgp_vals + A_st;
    const threadgroup val_t* Bs = tgp_vals + B_st;

    // Locate where this lane's N_PER_THREAD outputs begin in each run:
    //   Ci = merge(As[partition:], Bs[sort_md - partition:])
    // so that the concatenation of all Ci is sorted
    int sort_md = N_PER_THREAD * merge_lane;
    int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
    int b_skip = sort_md - partition;

    const threadgroup idx_t* As_idx = nullptr;
    const threadgroup idx_t* Bs_idx = nullptr;
    if (ARG_SORT) {
      As_idx = tgp_idxs + A_st + partition;
      Bs_idx = tgp_idxs + B_st + b_skip;
    }

    // Merge from the partition point into the thread-local arrays
    merge_step(
        As + partition,
        Bs + b_skip,
        As_idx,
        Bs_idx,
        A_sz - partition,
        B_sz - b_skip,
        thread_vals,
        thread_idxs);

    merge_threads *= 2;
  }

  // Store the final per-thread results back to threadgroup memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int k = 0; k < N_PER_THREAD; ++k) {
    tgp_vals[idx + k] = thread_vals[k];
    if (ARG_SORT) {
      tgp_idxs[idx + k] = thread_idxs[k];
    }
  }
}
208};
209
211// Kernel sort
213
214template <
215 typename T,
216 typename U,
217 bool ARG_SORT,
218 short BLOCK_THREADS,
219 short N_PER_THREAD,
220 typename CompareOp = LessThan<T>>
222 using val_t = T;
223 using idx_t = uint;
225 val_t,
226 idx_t,
227 ARG_SORT,
228 BLOCK_THREADS,
229 N_PER_THREAD,
230 CompareOp>;
231
232 MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
233
// Sorts one segment that fits entirely in a single threadgroup.
//
// inp / out                   device buffers; offset here by tid.y times the
//                             respective segment stride
// size_sorted_axis            number of valid elements along the sorted axis
// in/out_stride_sorted_axis   element stride along the sorted axis
// in/out_stride_segment_axis  element stride between consecutive segments
// tgp_vals / tgp_idxs         threadgroup scratch of N_PER_BLOCK elements;
//                             tgp_idxs is only touched when ARG_SORT is set
//
// Writes the sorted indices to `out` when ARG_SORT is set and the sorted
// values otherwise.
static METAL_FUNC void block_sort(
    const device T* inp,
    device U* out,
    const constant int& size_sorted_axis,
    const constant int& in_stride_sorted_axis,
    const constant int& out_stride_sorted_axis,
    const constant int& in_stride_segment_axis,
    const constant int& out_stride_segment_axis,
    threadgroup val_t* tgp_vals,
    threadgroup idx_t* tgp_idxs,
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // tid.y selects the segment
  const device T* seg_in = inp + tid.y * in_stride_segment_axis;
  device U* seg_out = out + tid.y * out_stride_segment_axis;

  // Stage the segment in threadgroup memory; slots past the end of the axis
  // are filled with CompareOp::init
  for (short slot = lid.x; slot < N_PER_BLOCK; slot += BLOCK_THREADS) {
    if (slot < size_sorted_axis) {
      tgp_vals[slot] = seg_in[slot * in_stride_sorted_axis];
    } else {
      tgp_vals[slot] = val_t(CompareOp::init);
    }
    if (ARG_SORT) {
      tgp_idxs[slot] = slot;
    }
  }

  // Sort the staged elements
  threadgroup_barrier(mem_flags::mem_threadgroup);

  block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Emit the first size_sorted_axis results
  if (ARG_SORT) {
    for (int pos = lid.x; pos < size_sorted_axis; pos += BLOCK_THREADS) {
      seg_out[pos * out_stride_sorted_axis] = tgp_idxs[pos];
    }
  } else {
    for (int pos = lid.x; pos < size_sorted_axis; pos += BLOCK_THREADS) {
      seg_out[pos * out_stride_sorted_axis] = tgp_vals[pos];
    }
  }
}
275};
276
// Kernel entry point for the single-threadgroup sort.
//
// Declares the threadgroup scratch arrays and forwards every argument to
// sort_kernel::block_sort. When ARG_SORT is set a second scratch array is
// declared for indices; otherwise nullptr is passed in its place.
//
// Buffer bindings:
//   0 inp, 1 out, 2 size_sorted_axis, 3 in_stride_sorted_axis,
//   4 out_stride_sorted_axis, 5 in_stride_segment_axis,
//   6 out_stride_segment_axis
277template <
278 typename T,
279 typename U,
280 bool ARG_SORT,
281 short BLOCK_THREADS,
282 short N_PER_THREAD>
283[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
284 const device T* inp [[buffer(0)]],
285 device U* out [[buffer(1)]],
286 const constant int& size_sorted_axis [[buffer(2)]],
287 const constant int& in_stride_sorted_axis [[buffer(3)]],
288 const constant int& out_stride_sorted_axis [[buffer(4)]],
289 const constant int& in_stride_segment_axis [[buffer(5)]],
290 const constant int& out_stride_segment_axis [[buffer(6)]],
291 uint3 tid [[threadgroup_position_in_grid]],
292 uint3 lid [[thread_position_in_threadgroup]]) {
// NOTE(review): the right-hand side of this alias (embedded line 294) is
// missing from this listing; it presumably names the single-block sort struct
// instantiated with <T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD> - restore it
// from the original header.
293 using sort_kernel =
295 using val_t = typename sort_kernel::val_t;
296 using idx_t = typename sort_kernel::idx_t;
297
298 if (ARG_SORT) {
// Arg-sort: values and indices both need threadgroup scratch.
299 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
300 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
301 sort_kernel::block_sort(
302 inp,
303 out,
304 size_sorted_axis,
305 in_stride_sorted_axis,
306 out_stride_sorted_axis,
307 in_stride_segment_axis,
308 out_stride_segment_axis,
309 tgp_vals,
310 tgp_idxs,
311 tid,
312 lid);
313 } else {
// Value sort: no index scratch is declared, nullptr is passed instead.
314 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
315 sort_kernel::block_sort(
316 inp,
317 out,
318 size_sorted_axis,
319 in_stride_sorted_axis,
320 out_stride_sorted_axis,
321 in_stride_segment_axis,
322 out_stride_segment_axis,
323 tgp_vals,
324 nullptr,
325 tid,
326 lid);
327 }
328}
329
330constant constexpr const int zero_helper = 0;
331
// Kernel entry point for the single-threadgroup sort where the per-segment
// offset is computed from a shape/strides description instead of a single
// segment stride.
//
// tid.y is mapped to an input and an output element offset via elem_to_loc
// (using nc_shape with in_nc_strides / out_nc_strides over nc_dim
// dimensions); both pointers are advanced before the call to
// sort_kernel::block_sort.
//
// Buffer bindings:
//   0 inp, 1 out, 2 size_sorted_axis, 3 in_stride_sorted_axis,
//   4 out_stride_sorted_axis, 5 nc_dim, 6 nc_shape, 7 in_nc_strides,
//   8 out_nc_strides
332template <
333 typename T,
334 typename U,
335 bool ARG_SORT,
336 short BLOCK_THREADS,
337 short N_PER_THREAD>
338[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
339 const device T* inp [[buffer(0)]],
340 device U* out [[buffer(1)]],
341 const constant int& size_sorted_axis [[buffer(2)]],
342 const constant int& in_stride_sorted_axis [[buffer(3)]],
343 const constant int& out_stride_sorted_axis [[buffer(4)]],
344 const constant int& nc_dim [[buffer(5)]],
345 const device int* nc_shape [[buffer(6)]],
346 const device size_t* in_nc_strides [[buffer(7)]],
347 const device size_t* out_nc_strides [[buffer(8)]],
348 uint3 tid [[threadgroup_position_in_grid]],
349 uint3 lid [[thread_position_in_threadgroup]]) {
// NOTE(review): the right-hand side of this alias (embedded line 351) is
// missing from this listing - restore it from the original header.
350 using sort_kernel =
352 using val_t = typename sort_kernel::val_t;
353 using idx_t = typename sort_kernel::idx_t;
354
// Resolve the segment index into element offsets for input and output.
355 auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
356 auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
357 inp += in_block_idx;
358 out += out_block_idx;
359
360 if (ARG_SORT) {
361 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
362 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
363 sort_kernel::block_sort(
364 inp,
365 out,
366 size_sorted_axis,
367 in_stride_sorted_axis,
368 out_stride_sorted_axis,
// NOTE(review): two arguments (embedded lines 369-370) are missing from this
// listing. They are presumably `zero_helper` passed for both segment strides,
// since the offsets were already applied above - confirm against the header.
371 tgp_vals,
372 tgp_idxs,
373 tid,
374 lid);
375 } else {
376 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
377 sort_kernel::block_sort(
378 inp,
379 out,
380 size_sorted_axis,
381 in_stride_sorted_axis,
382 out_stride_sorted_axis,
// NOTE(review): embedded lines 383-384 are missing here as well; see the note
// in the ARG_SORT branch.
385 tgp_vals,
386 nullptr,
387 tid,
388 lid);
389 }
390}
391
392template <
393 typename val_t,
394 typename idx_t,
395 bool ARG_SORT,
396 short BLOCK_THREADS,
397 short N_PER_THREAD,
398 typename CompareOp = LessThan<val_t>>
401 val_t,
402 idx_t,
403 ARG_SORT,
404 BLOCK_THREADS,
405 N_PER_THREAD,
406 CompareOp>;
407
408 MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
409
// First stage of the multi-block sort: sorts one N_PER_BLOCK-sized block of
// the axis inside a threadgroup and writes values and indices back out.
//
// inp                 device input, already offset to the current row
// out_vals / out_idxs device outputs, indexed directly by position along the
//                     axis
// size_sorted_axis    number of valid elements along the axis
// stride_sorted_axis  input element stride along the axis
// tgp_vals / tgp_idxs threadgroup scratch of N_PER_BLOCK elements each
static METAL_FUNC void block_sort(
    const device val_t* inp,
    device val_t* out_vals,
    device idx_t* out_idxs,
    const constant int& size_sorted_axis,
    const constant int& stride_sorted_axis,
    threadgroup val_t* tgp_vals,
    threadgroup idx_t* tgp_idxs,
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // tid.x selects the block along the sorted axis
  int base_idx = tid.x * N_PER_BLOCK;

  // Stage the block; positions past the end of the axis get CompareOp::init.
  // The index stored is the position along the whole axis, not the local
  // slot.
  for (short slot = lid.x; slot < N_PER_BLOCK; slot += BLOCK_THREADS) {
    int axis_pos = base_idx + slot;
    if (axis_pos < size_sorted_axis) {
      tgp_vals[slot] = inp[axis_pos * stride_sorted_axis];
    } else {
      tgp_vals[slot] = val_t(CompareOp::init);
    }
    tgp_idxs[slot] = axis_pos;
  }

  // Sort the staged block
  threadgroup_barrier(mem_flags::mem_threadgroup);

  block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Emit the in-range part of the block
  for (int slot = lid.x; slot < N_PER_BLOCK; slot += BLOCK_THREADS) {
    int axis_pos = base_idx + slot;
    if (axis_pos >= size_sorted_axis) {
      continue;
    }
    out_vals[axis_pos] = tgp_vals[slot];
    out_idxs[axis_pos] = tgp_idxs[slot];
  }
}
447
// Binary search for the split point used to merge two runs held in device
// memory.
//
// As / Bs     start of the two runs (expected to be sorted under CompareOp)
// A_sz / B_sz number of elements in each run
// sort_md     how many merged outputs precede the requested split
//
// Returns `p`, the number of elements to take from As; the remaining
// (sort_md - p) elements come from Bs.
static METAL_FUNC int merge_partition(
    const device val_t* As,
    const device val_t* Bs,
    int A_sz,
    int B_sz,
    int sort_md) {
  CompareOp op;

  // Search window for p
  int lo = max(0, sort_md - B_sz);
  int hi = min(sort_md, A_sz);

  for (; lo < hi;) {
    int mid = lo + (hi - lo) / 2;
    auto a_val = As[mid];
    auto b_val = Bs[sort_md - 1 - mid];

    if (!op(b_val, a_val)) {
      lo = mid + 1;
      continue;
    }
    hi = mid;
  }

  return hi;
}
473};
474
// Kernel entry point for the first stage of the multi-block sort.
//
// tid.y selects the row: the input offset is resolved through elem_to_loc
// from nc_shape / nc_strides, while both outputs are offset by
// tid.y * size_sorted_axis. The block along the axis (tid.x) is handled
// inside KernelMultiBlockMergeSort::block_sort.
//
// Buffer bindings:
//   0 inp, 1 out_vals, 2 out_idxs, 3 size_sorted_axis,
//   4 stride_sorted_axis, 5 nc_dim, 6 nc_shape, 7 nc_strides
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
    const device val_t* inp [[buffer(0)]],
    device val_t* out_vals [[buffer(1)]],
    device idx_t* out_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const device int* nc_shape [[buffer(6)]],
    const device size_t* nc_strides [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t,
      idx_t,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD>;

  // Threadgroup scratch for one block of values and indices
  threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];

  // Move the input to the start of this row
  inp += elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);

  // Both outputs use the same per-row offset
  const auto out_row_offset = tid.y * size_sorted_axis;
  out_vals += out_row_offset;
  out_idxs += out_row_offset;

  sort_kernel::block_sort(
      inp,
      out_vals,
      out_idxs,
      size_sorted_axis,
      stride_sorted_axis,
      tgp_vals,
      tgp_idxs,
      tid,
      lid);
}
518
// Partition stage of the multi-block sort.
//
// For every block boundary i in [0, n_blocks] of the row selected by tid.y,
// finds where that boundary falls inside run A of the pair of sorted runs
// being merged, and stores the position (as an index along the axis) in
// block_partitions[i].
//
// Buffer bindings:
//   0 block_partitions (out), 1 dev_vals, 2 dev_idxs, 3 size_sorted_axis,
//   4 merge_tiles (number of blocks per merged pair of runs),
//   5 n_blocks (number of blocks along the axis)
//
// dev_idxs is not read by this kernel.
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel]] void mb_block_partition(
    device idx_t* block_partitions [[buffer(0)]],
    const device val_t* dev_vals [[buffer(1)]],
    const device idx_t* dev_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& merge_tiles [[buffer(4)]],
    const constant int& n_blocks [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t,
      idx_t,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD>;

  // Each row receives n_blocks + 1 entries (the loop below covers
  // i = 0..n_blocks), so that is the row stride. Using the threadgroup width
  // here would make rows overlap whenever the threadgroup is narrower than
  // n_blocks + 1, which the strided loop explicitly allows for.
  // NOTE(review): mb_block_merge reads rows with stride (num_tiles + 1);
  // this matches provided the host passes the same block count to both -
  // confirm against the dispatch code.
  block_partitions += tid.y * (n_blocks + 1);
  dev_vals += tid.y * size_sorted_axis;
  dev_idxs += tid.y * size_sorted_axis;

  // Number of elements covered by one merged pair of runs (loop-invariant)
  const int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;

  for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
    // Which pair of runs boundary i belongs to, and its block offset in it
    int merge_group = i / merge_tiles;
    int merge_lane = i % merge_tiles;

    int sort_st = sort_sz * merge_group;

    // Run A and run B, clamped to the end of the axis
    int A_st = min(size_sorted_axis, sort_st);
    int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    int B_st = A_ed;
    int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);

    // Number of merged outputs that precede this boundary, clamped to the
    // total available in A and B
    int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
    int partition = sort_kernel::merge_partition(
        dev_vals + A_st,
        dev_vals + B_st,
        A_ed - A_st,
        B_ed - B_st,
        partition_at);

    block_partitions[i] = A_st + partition;
  }
}
570
// Merge stage of the multi-block sort.
//
// Each threadgroup (tid.x = output block, tid.y = row) produces one
// N_PER_BLOCK-sized block of the merged output. It reads the slice of run A
// and run B that feeds this block (bounds derived from block_partitions),
// stages both slices back-to-back in threadgroup memory, merges them
// cooperatively, and writes the in-range results to the output buffers.
//
// Buffer bindings:
//   0 block_partitions, 1 dev_vals_in, 2 dev_idxs_in, 3 dev_vals_out,
//   4 dev_idxs_out, 5 size_sorted_axis, 6 merge_tiles, 7 num_tiles
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
    const device idx_t* block_partitions [[buffer(0)]],
    const device val_t* dev_vals_in [[buffer(1)]],
    const device idx_t* dev_idxs_in [[buffer(2)]],
    device val_t* dev_vals_out [[buffer(3)]],
    device idx_t* dev_idxs_out [[buffer(4)]],
    const constant int& size_sorted_axis [[buffer(5)]],
    const constant int& merge_tiles [[buffer(6)]],
    const constant int& num_tiles [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t,
      idx_t,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  using block_sort_t = typename sort_kernel::block_merge_sort_t;

  // Move every buffer to the row selected by tid.y
  block_partitions += tid.y * (num_tiles + 1);
  dev_vals_in += tid.y * size_sorted_axis;
  dev_idxs_in += tid.y * size_sorted_axis;
  dev_vals_out += tid.y * size_sorted_axis;
  dev_idxs_out += tid.y * size_sorted_axis;

  // Locate this output block inside its pair of runs
  int block_idx = tid.x;
  int merge_group = block_idx / merge_tiles;
  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
  int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;

  // The partitions give the slice of run A (as positions along the axis).
  // The matching slice of run B follows from "outputs before this block =
  // elements taken from A + elements taken from B", with run B starting at
  // sort_st + sort_sz / 2.
  int A_st = block_partitions[block_idx + 0];
  int A_ed = block_partitions[block_idx + 1];
  int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
  int B_ed = min(
      size_sorted_axis,
      2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);

  // The last block of a pair consumes whatever is left of both runs
  if ((block_idx % merge_tiles) == merge_tiles - 1) {
    A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    B_ed = min(size_sorted_axis, sort_st + sort_sz);
  }

  int A_sz = A_ed - A_st;
  int B_sz = B_ed - B_st;

  // Load from global memory: slice A first, slice B right after it; slots
  // beyond A_sz + B_sz are padded with CompareOp::init / index 0
  thread val_t thread_vals[N_PER_THREAD];
  thread idx_t thread_idxs[N_PER_THREAD];
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    if (idx < (A_sz + B_sz)) {
      thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
                                    : dev_vals_in[B_st + idx - A_sz];
      thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
                                    : dev_idxs_in[B_st + idx - A_sz];
    } else {
      thread_vals[i] = CompareOp::init;
      thread_idxs[i] = 0;
    }
  }

  // Write to shared memory
  threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    tgp_vals[idx] = thread_vals[i];
    tgp_idxs[idx] = thread_idxs[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Merge: find where this thread's N_PER_THREAD outputs start in each slice
  int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));

  int A_st_local = block_sort_t::merge_partition(
      tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
  int A_ed_local = A_sz;

  int B_st_local = sort_md_local - A_st_local;
  int B_ed_local = B_sz;

  int A_sz_local = A_ed_local - A_st_local;
  int B_sz_local = B_ed_local - B_st_local;

  // Do merge (slice B lives at offset A_sz in the scratch arrays)
  block_sort_t::merge_step(
      tgp_vals + A_st_local,
      tgp_vals + A_ed_local + B_st_local,
      tgp_idxs + A_st_local,
      tgp_idxs + A_ed_local + B_st_local,
      A_sz_local,
      B_sz_local,
      thread_vals,
      thread_idxs);

  // Every thread must be done reading the scratch arrays before they are
  // overwritten with the merged results
  threadgroup_barrier(mem_flags::mem_threadgroup);
  const int out_slot = lid.x * N_PER_THREAD;
  for (int i = 0; i < N_PER_THREAD; ++i) {
    tgp_vals[out_slot + i] = thread_vals[i];
    tgp_idxs[out_slot + i] = thread_idxs[i];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Write output
  int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
  for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
    int idx = base_idx + i;
    if (idx < size_sorted_axis) {
      dev_vals_out[idx] = tgp_vals[i];
      dev_idxs_out[idx] = tgp_idxs[i];
    }
  }
}
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:87
Op op
Definition binary.h:141
Definition bf16.h:265
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
#define MLX_MTL_CONST
Definition sort.h:3
void block_sort_nc(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const device int *nc_shape, const device size_t *in_nc_strides, const device size_t *out_nc_strides, uint3 tid, uint3 lid)
Definition sort.h:338
void mb_block_sort(const device val_t *inp, device val_t *out_vals, device idx_t *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, const constant int &nc_dim, const device int *nc_shape, const device size_t *nc_strides, uint3 tid, uint3 lid)
Definition sort.h:481
void mb_block_partition(device idx_t *block_partitions, const device val_t *dev_vals, const device idx_t *dev_idxs, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &n_blocks, uint3 tid, uint3 lid, uint3 tgp_dims)
Definition sort.h:525
METAL_FUNC void thread_swap(thread T &a, thread T &b)
Definition sort.h:16
void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, uint3 tid, uint3 lid)
Definition sort.h:283
void mb_block_merge(const device idx_t *block_partitions, const device val_t *dev_vals_in, const device idx_t *dev_idxs_in, device val_t *dev_vals_out, device idx_t *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)
Definition sort.h:579
constant constexpr const int zero_helper
Definition sort.h:330
#define MLX_MTL_LOOP_UNROLL
Definition sort.h:4
Definition sort.h:67
static METAL_FUNC int merge_partition(const threadgroup val_t *As, const threadgroup val_t *Bs, short A_sz, short B_sz, short sort_md)
Definition sort.h:70
static METAL_FUNC void merge_step(const threadgroup val_t *As, const threadgroup val_t *Bs, const threadgroup idx_t *As_idx, const threadgroup idx_t *Bs_idx, short A_sz, short B_sz, thread val_t(&vals)[N_PER_THREAD], thread idx_t(&idxs)[N_PER_THREAD])
Definition sort.h:96
static METAL_FUNC void sort(threadgroup val_t *tgp_vals, threadgroup idx_t *tgp_idxs, int size_sorted_axis, uint3 lid)
Definition sort.h:122
Definition sort.h:221
uint idx_t
Definition sort.h:223
T val_t
Definition sort.h:222
static METAL_FUNC void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, threadgroup val_t *tgp_vals, threadgroup idx_t *tgp_idxs, uint3 tid, uint3 lid)
Definition sort.h:234
static constant constexpr const short N_PER_BLOCK
Definition sort.h:232
Definition sort.h:399
static METAL_FUNC void block_sort(const device val_t *inp, device val_t *out_vals, device idx_t *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, threadgroup val_t *tgp_vals, threadgroup idx_t *tgp_idxs, uint3 tid, uint3 lid)
Definition sort.h:410
static METAL_FUNC int merge_partition(const device val_t *As, const device val_t *Bs, int A_sz, int B_sz, int sort_md)
Definition sort.h:448
static constant constexpr const short N_PER_BLOCK
Definition sort.h:408
Definition sort.h:23
METAL_FUNC bool operator()(T a, T b)
Definition sort.h:26
static constexpr constant T init
Definition sort.h:24
Definition utils.h:17
Definition sort.h:37
static METAL_FUNC void sort(thread val_t(&vals)[N_PER_THREAD], thread idx_t(&idxs)[N_PER_THREAD])
Definition sort.h:38