sort.h
// Copyright © 2023-2024 Apple Inc.

#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")

using namespace metal;

// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
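//
// The kernels are organized in three levels: ThreadSort sorts the N_PER_THREAD
// elements held in a single thread's registers, BlockMergeSort merges those
// per-thread runs within one threadgroup, and the kernels below either sort an
// entire row in a single threadgroup (block_sort / block_sort_nc) or sort tiles
// and then merge them across threadgroups (mb_block_sort, mb_block_partition,
// mb_block_merge).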

///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
///////////////////////////////////////////////////////////////////////////////

template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
  T w = a;
  a = b;
  b = w;
}

template <typename T>
struct LessThan {
  static constexpr constant T init = Limits<T>::max;

  METAL_FUNC bool operator()(T a, T b) {
    return a < b;
  }
};
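
// LessThan is the default comparator; its init value (Limits<T>::max) is used to
// pad threadgroup tiles past the end of the sorted axis, so padding slots compare
// as largest and collect at the back of each sorted run.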

template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short N_PER_THREAD,
    typename CompareOp>
struct ThreadSort {
  static METAL_FUNC void sort(
      thread ValT (&vals)[N_PER_THREAD],
      thread IdxT (&idxs)[N_PER_THREAD]) {
    CompareOp op;
    MLX_MTL_LOOP_UNROLL
    for (short i = 0; i < N_PER_THREAD; ++i) {
      MLX_MTL_LOOP_UNROLL
      for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
        if (op(vals[j + 1], vals[j])) {
          thread_swap(vals[j + 1], vals[j]);
          thread_swap(idxs[j + 1], idxs[j]);
        }
      }
    }
  }
};
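
// ThreadSort::sort above is an odd-even transposition sort over the register
// array: pass i compare-swaps adjacent pairs starting at index (i & 1), so the
// array is sorted after N_PER_THREAD passes. Illustrative trace for
// N_PER_THREAD = 4:
//   [3, 1, 4, 2] -> pass 0 (even pairs) -> [1, 3, 2, 4]
//                -> pass 1 (odd pairs)  -> [1, 2, 3, 4]
//                -> passes 2 and 3 make no further swaps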

///////////////////////////////////////////////////////////////////////////////
// Threadgroup-level sort
///////////////////////////////////////////////////////////////////////////////

template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp>
struct BlockMergeSort {
  using thread_sort_t =
      ThreadSort<ValT, IdxT, ARG_SORT, N_PER_THREAD, CompareOp>;
  static METAL_FUNC int merge_partition(
      const threadgroup ValT* As,
      const threadgroup ValT* Bs,
      short A_sz,
      short B_sz,
      short sort_md) {
    CompareOp op;

    short A_st = max(0, sort_md - B_sz);
    short A_ed = min(sort_md, A_sz);

    while (A_st < A_ed) {
      short md = A_st + (A_ed - A_st) / 2;
      auto a = As[md];
      auto b = Bs[sort_md - 1 - md];

      if (op(b, a)) {
        A_ed = md;
      } else {
        A_st = md + 1;
      }
    }

    return A_ed;
  }
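
  // merge_partition binary-searches the merge path: it returns the count
  // `partition` of elements taken from As such that As[0:partition] together
  // with Bs[0:sort_md - partition] are exactly the first sort_md elements of
  // merge(As, Bs). Illustrative trace: As = {1, 3, 5}, Bs = {2, 4, 6},
  // sort_md = 3 returns 2, since the first three merged elements {1, 2, 3}
  // come from As[0:2] and Bs[0:1].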

  static METAL_FUNC void merge_step(
      const threadgroup ValT* As,
      const threadgroup ValT* Bs,
      const threadgroup IdxT* As_idx,
      const threadgroup IdxT* Bs_idx,
      short A_sz,
      short B_sz,
      thread ValT (&vals)[N_PER_THREAD],
      thread IdxT (&idxs)[N_PER_THREAD]) {
    CompareOp op;
    short a_idx = 0;
    short b_idx = 0;

    for (int i = 0; i < N_PER_THREAD; ++i) {
      auto a = As[a_idx];
      auto b = Bs[b_idx];
      bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));

      vals[i] = pred ? b : a;
      idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];

      b_idx += short(pred);
      a_idx += short(!pred);
    }
  }
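
  // merge_step emits the next N_PER_THREAD outputs of a serial two-way merge of
  // As and Bs into the calling thread's registers, consuming one input per step.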

  static METAL_FUNC void sort(
      threadgroup ValT* tgp_vals [[threadgroup(0)]],
      threadgroup IdxT* tgp_idxs [[threadgroup(1)]],
      int size_sorted_axis,
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Get thread location
    int idx = lid.x * N_PER_THREAD;

    // Load from shared memory
    thread ValT thread_vals[N_PER_THREAD];
    thread IdxT thread_idxs[N_PER_THREAD];
    for (int i = 0; i < N_PER_THREAD; ++i) {
      thread_vals[i] = tgp_vals[idx + i];
      if (ARG_SORT) {
        thread_idxs[i] = tgp_idxs[idx + i];
      }
    }

    // Per thread sort
    if (idx < size_sorted_axis) {
      thread_sort_t::sort(thread_vals, thread_idxs);
    }

    // Do merges using threadgroup memory
    for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
         merge_threads *= 2) {
      // Update threadgroup memory
      threadgroup_barrier(mem_flags::mem_threadgroup);
      for (int i = 0; i < N_PER_THREAD; ++i) {
        tgp_vals[idx + i] = thread_vals[i];
        if (ARG_SORT) {
          tgp_idxs[idx + i] = thread_idxs[i];
        }
      }
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Find location in merge step
      int merge_group = lid.x / merge_threads;
      int merge_lane = lid.x % merge_threads;

      int sort_sz = N_PER_THREAD * merge_threads;
      int sort_st = N_PER_THREAD * merge_threads * merge_group;

      // As = tgp_vals[A_st:A_ed] is sorted
      // Bs = tgp_vals[B_st:B_ed] is sorted
      int A_st = sort_st;
      int A_ed = sort_st + sort_sz / 2;
      int B_st = sort_st + sort_sz / 2;
      int B_ed = sort_st + sort_sz;

      const threadgroup ValT* As = tgp_vals + A_st;
      const threadgroup ValT* Bs = tgp_vals + B_st;
      int A_sz = A_ed - A_st;
      int B_sz = B_ed - B_st;

      // Find a partition of merge elements
      // Ci = merge(As[partition:], Bs[sort_md - partition:])
      // of size N_PER_THREAD for each merge lane i
      // C = [Ci] is sorted
      int sort_md = N_PER_THREAD * merge_lane;
      int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);

      As += partition;
      Bs += sort_md - partition;

      A_sz -= partition;
      B_sz -= sort_md - partition;

      const threadgroup IdxT* As_idx =
          ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
      const threadgroup IdxT* Bs_idx =
          ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;

      // Merge starting at the partition and store results in thread registers
      merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
    }

    // Write out to shared memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int i = 0; i < N_PER_THREAD; ++i) {
      tgp_vals[idx + i] = thread_vals[i];
      if (ARG_SORT) {
        tgp_idxs[idx + i] = thread_idxs[i];
      }
    }
  }
};
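
// BlockMergeSort::sort performs log2(BLOCK_THREADS) merge rounds: in each round,
// groups of merge_threads threads cooperatively combine two adjacent sorted runs
// of N_PER_THREAD * merge_threads / 2 elements, with every thread locating its
// slice of the merge path via merge_partition and producing it via merge_step.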

///////////////////////////////////////////////////////////////////////////////
// Kernel sort
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<T>>
struct KernelMergeSort {
  using ValT = T;
  using IdxT = uint;
  using block_merge_sort_t = BlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;

  static METAL_FUNC void block_sort(
      const device T* inp,
      device U* out,
      const constant int& size_sorted_axis,
      const constant int& in_stride_sorted_axis,
      const constant int& out_stride_sorted_axis,
      const constant int& in_stride_segment_axis,
      const constant int& out_stride_segment_axis,
      threadgroup ValT* tgp_vals,
      threadgroup IdxT* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.y tells us the segment index
    inp += tid.y * in_stride_segment_axis;
    out += tid.y * out_stride_segment_axis;

    // Copy into threadgroup memory
    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
                                         : ValT(CompareOp::init);
      if (ARG_SORT) {
        tgp_idxs[i] = i;
      }
    }

    // Sort elements within the block
    threadgroup_barrier(mem_flags::mem_threadgroup);

    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write output
    for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
      if (ARG_SORT) {
        out[i * out_stride_sorted_axis] = tgp_idxs[i];
      } else {
        out[i * out_stride_sorted_axis] = tgp_vals[i];
      }
    }
  }
};

template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& in_stride_sorted_axis [[buffer(3)]],
    const constant int& out_stride_sorted_axis [[buffer(4)]],
    const constant int& in_stride_segment_axis [[buffer(5)]],
    const constant int& out_stride_segment_axis [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
  using ValT = typename sort_kernel::ValT;
  using IdxT = typename sort_kernel::IdxT;

  if (ARG_SORT) {
    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
    threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        in_stride_segment_axis,
        out_stride_segment_axis,
        tgp_vals,
        tgp_idxs,
        tid,
        lid);
  } else {
    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        in_stride_segment_axis,
        out_stride_segment_axis,
        tgp_vals,
        nullptr,
        tid,
        lid);
  }
}
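
// block_sort sorts one segment (one row along the sorted axis) per threadgroup:
// tid.y selects the segment and the whole axis must fit in a single tile, i.e.
// size_sorted_axis <= BLOCK_THREADS * N_PER_THREAD. Longer axes go through the
// multi-block kernels further below.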

constant constexpr const int zero_helper = 0;
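
// zero_helper supplies the segment-stride arguments in block_sort_nc below,
// where the segment offset has already been folded into the pointers via
// elem_to_loc.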

template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& in_stride_sorted_axis [[buffer(3)]],
    const constant int& out_stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const constant int* nc_shape [[buffer(6)]],
    const constant int64_t* in_nc_strides [[buffer(7)]],
    const constant int64_t* out_nc_strides [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
  using ValT = typename sort_kernel::ValT;
  using IdxT = typename sort_kernel::IdxT;

  auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
  auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
  inp += in_block_idx;
  out += out_block_idx;

  if (ARG_SORT) {
    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
    threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        zero_helper,
        zero_helper,
        tgp_vals,
        tgp_idxs,
        tid,
        lid);
  } else {
    threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        zero_helper,
        zero_helper,
        tgp_vals,
        nullptr,
        tid,
        lid);
  }
}

template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<ValT>>
struct KernelMultiBlockMergeSort {
  using block_merge_sort_t = BlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;

  static METAL_FUNC void block_sort(
      const device ValT* inp,
      device ValT* out_vals,
      device IdxT* out_idxs,
      const constant int& size_sorted_axis,
      const constant int& stride_sorted_axis,
      threadgroup ValT* tgp_vals,
      threadgroup IdxT* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.x tells us the block (tile) index along the sorted axis
    int base_idx = tid.x * N_PER_BLOCK;

    // Copy into threadgroup memory
    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      int idx = base_idx + i;
      tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
                                           : ValT(CompareOp::init);
      tgp_idxs[i] = idx;
    }

    // Sort elements within the block
    threadgroup_barrier(mem_flags::mem_threadgroup);

    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write output
    for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      int idx = base_idx + i;
      if (idx < size_sorted_axis) {
        out_vals[idx] = tgp_vals[i];
        out_idxs[idx] = tgp_idxs[i];
      }
    }
  }

  static METAL_FUNC int merge_partition(
      const device ValT* As,
      const device ValT* Bs,
      int A_sz,
      int B_sz,
      int sort_md) {
    CompareOp op;

    int A_st = max(0, sort_md - B_sz);
    int A_ed = min(sort_md, A_sz);

    while (A_st < A_ed) {
      int md = A_st + (A_ed - A_st) / 2;
      auto a = As[md];
      auto b = Bs[sort_md - 1 - md];

      if (op(b, a)) {
        A_ed = md;
      } else {
        A_st = md + 1;
      }
    }

    return A_ed;
  }
};
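
// Multi-block path: mb_block_sort sorts each tile of N_PER_BLOCK elements into
// (value, index) runs, mb_block_partition computes merge-path split points at
// every tile boundary, and mb_block_merge merges pairs of sorted runs using
// those splits. The host side is expected to repeat the partition/merge passes
// with merge_tiles doubling, ping-ponging buffers, until a single run remains.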

template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
    const device ValT* inp [[buffer(0)]],
    device ValT* out_vals [[buffer(1)]],
    device IdxT* out_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const constant int* nc_shape [[buffer(6)]],
    const constant int64_t* nc_strides [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD>;

  auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
  inp += block_idx;
  out_vals += tid.y * size_sorted_axis;
  out_idxs += tid.y * size_sorted_axis;

  threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];

  sort_kernel::block_sort(
      inp,
      out_vals,
      out_idxs,
      size_sorted_axis,
      stride_sorted_axis,
      tgp_vals,
      tgp_idxs,
      tid,
      lid);
}

template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel]] void mb_block_partition(
    device IdxT* block_partitions [[buffer(0)]],
    const device ValT* dev_vals [[buffer(1)]],
    const device IdxT* dev_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& merge_tiles [[buffer(4)]],
    const constant int& n_blocks [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD>;

  block_partitions += tid.y * tgp_dims.x;
  dev_vals += tid.y * size_sorted_axis;
  dev_idxs += tid.y * size_sorted_axis;

  for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
    // Find location in merge step
    int merge_group = i / merge_tiles;
    int merge_lane = i % merge_tiles;

    int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
    int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;

    int A_st = min(size_sorted_axis, sort_st);
    int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    int B_st = A_ed;
    int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);

    int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
    int partition = sort_kernel::merge_partition(
        dev_vals + A_st,
        dev_vals + B_st,
        A_ed - A_st,
        B_ed - B_st,
        partition_at);

    block_partitions[i] = A_st + partition;
  }
}
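
// For output tile i, block_partitions[i] holds the absolute index into dev_vals
// at which that tile starts consuming the left (A) half of its merge pair;
// mb_block_merge then reads A elements from
// [block_partitions[i], block_partitions[i + 1]) and derives the matching B
// range from the merge-path arithmetic.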

template <
    typename ValT,
    typename IdxT,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<ValT>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
    const device IdxT* block_partitions [[buffer(0)]],
    const device ValT* dev_vals_in [[buffer(1)]],
    const device IdxT* dev_idxs_in [[buffer(2)]],
    device ValT* dev_vals_out [[buffer(3)]],
    device IdxT* dev_idxs_out [[buffer(4)]],
    const constant int& size_sorted_axis [[buffer(5)]],
    const constant int& merge_tiles [[buffer(6)]],
    const constant int& num_tiles [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      ValT,
      IdxT,
      ARG_SORT,
      BLOCK_THREADS,
      N_PER_THREAD,
      CompareOp>;

  using block_sort_t = typename sort_kernel::block_merge_sort_t;

  block_partitions += tid.y * (num_tiles + 1);
  dev_vals_in += tid.y * size_sorted_axis;
  dev_idxs_in += tid.y * size_sorted_axis;
  dev_vals_out += tid.y * size_sorted_axis;
  dev_idxs_out += tid.y * size_sorted_axis;

  int block_idx = tid.x;
  int merge_group = block_idx / merge_tiles;
  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
  int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;

  int A_st = block_partitions[block_idx + 0];
  int A_ed = block_partitions[block_idx + 1];
  int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
  int B_ed = min(
      size_sorted_axis,
      2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);

  if ((block_idx % merge_tiles) == merge_tiles - 1) {
    A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    B_ed = min(size_sorted_axis, sort_st + sort_sz);
  }

  int A_sz = A_ed - A_st;
  int B_sz = B_ed - B_st;

  // Load from global memory
  thread ValT thread_vals[N_PER_THREAD];
  thread IdxT thread_idxs[N_PER_THREAD];
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    if (idx < (A_sz + B_sz)) {
      thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
                                    : dev_vals_in[B_st + idx - A_sz];
      thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
                                    : dev_idxs_in[B_st + idx - A_sz];
    } else {
      thread_vals[i] = CompareOp::init;
      thread_idxs[i] = 0;
    }
  }

  // Write to shared memory
  threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int i = 0; i < N_PER_THREAD; i++) {
    int idx = BLOCK_THREADS * i + lid.x;
    tgp_vals[idx] = thread_vals[i];
    tgp_idxs[idx] = thread_idxs[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Merge
  int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));

  int A_st_local = block_sort_t::merge_partition(
      tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
  int A_ed_local = A_sz;

  int B_st_local = sort_md_local - A_st_local;
  int B_ed_local = B_sz;

  int A_sz_local = A_ed_local - A_st_local;
  int B_sz_local = B_ed_local - B_st_local;

  // Do merge
  block_sort_t::merge_step(
      tgp_vals + A_st_local,
      tgp_vals + A_ed_local + B_st_local,
      tgp_idxs + A_st_local,
      tgp_idxs + A_ed_local + B_st_local,
      A_sz_local,
      B_sz_local,
      thread_vals,
      thread_idxs);

  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int i = 0; i < N_PER_THREAD; ++i) {
    int idx = lid.x * N_PER_THREAD;
    tgp_vals[idx + i] = thread_vals[i];
    tgp_idxs[idx + i] = thread_idxs[i];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Write output
  int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
  for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
    int idx = base_idx + i;
    if (idx < size_sorted_axis) {
      dev_vals_out[idx] = tgp_vals[i];
      dev_idxs_out[idx] = tgp_idxs[i];
    }
  }
}