// Qualifier bundle for compile-time constants: declares a `static constexpr
// const` value placed in the Metal `constant` address space.
3#define MLX_MTL_CONST static constant constexpr const
// Expands to `#pragma clang loop unroll(full)` via the _Pragma operator (which,
// unlike a bare #pragma, is allowed inside a macro); requests that the compiler
// fully unroll the loop that immediately follows the macro.
4#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
38 static METAL_FUNC
void sort(
39 thread ValT (&vals)[N_PER_THREAD],
40 thread IdxT (&idxs)[N_PER_THREAD]) {
43 for (
short i = 0; i < N_PER_THREAD; ++i) {
45 for (
short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
46 if (op(vals[j + 1], vals[j])) {
70 const threadgroup ValT* As,
71 const threadgroup ValT* Bs,
77 short A_st =
max(0, sort_md - B_sz);
78 short A_ed =
min(sort_md, A_sz);
81 short md = A_st + (A_ed - A_st) / 2;
83 auto b = Bs[sort_md - 1 - md];
96 const threadgroup ValT* As,
97 const threadgroup ValT* Bs,
98 const threadgroup IdxT* As_idx,
99 const threadgroup IdxT* Bs_idx,
102 thread ValT (&vals)[N_PER_THREAD],
103 thread IdxT (&idxs)[N_PER_THREAD]) {
108 for (
int i = 0; i < N_PER_THREAD; ++i) {
111 bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
113 vals[i] = pred ? b : a;
114 idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
116 b_idx += short(pred);
117 a_idx += short(!pred);
122 threadgroup ValT* tgp_vals [[threadgroup(0)]],
123 threadgroup IdxT* tgp_idxs [[threadgroup(1)]],
124 int size_sorted_axis,
125 uint3 lid [[thread_position_in_threadgroup]]) {
127 int idx = lid.x * N_PER_THREAD;
130 thread ValT thread_vals[N_PER_THREAD];
131 thread IdxT thread_idxs[N_PER_THREAD];
132 for (
int i = 0; i < N_PER_THREAD; ++i) {
133 thread_vals[i] = tgp_vals[idx + i];
135 thread_idxs[i] = tgp_idxs[idx + i];
140 if (idx < size_sorted_axis) {
145 for (
int merge_threads = 2; merge_threads <= BLOCK_THREADS;
146 merge_threads *= 2) {
148 threadgroup_barrier(mem_flags::mem_threadgroup);
149 for (
int i = 0; i < N_PER_THREAD; ++i) {
150 tgp_vals[idx + i] = thread_vals[i];
152 tgp_idxs[idx + i] = thread_idxs[i];
155 threadgroup_barrier(mem_flags::mem_threadgroup);
158 int merge_group = lid.x / merge_threads;
159 int merge_lane = lid.x % merge_threads;
161 int sort_sz = N_PER_THREAD * merge_threads;
162 int sort_st = N_PER_THREAD * merge_threads * merge_group;
167 int A_ed = sort_st + sort_sz / 2;
168 int B_st = sort_st + sort_sz / 2;
169 int B_ed = sort_st + sort_sz;
171 const threadgroup ValT* As = tgp_vals + A_st;
172 const threadgroup ValT* Bs = tgp_vals + B_st;
173 int A_sz = A_ed - A_st;
174 int B_sz = B_ed - B_st;
180 int sort_md = N_PER_THREAD * merge_lane;
184 Bs += sort_md - partition;
187 B_sz -= sort_md - partition;
189 const threadgroup IdxT* As_idx =
190 ARG_SORT ? tgp_idxs + A_st + partition :
nullptr;
191 const threadgroup IdxT* Bs_idx =
192 ARG_SORT ? tgp_idxs + B_st + sort_md - partition :
nullptr;
195 merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
199 threadgroup_barrier(mem_flags::mem_threadgroup);
200 for (
int i = 0; i < N_PER_THREAD; ++i) {
201 tgp_vals[idx + i] = thread_vals[i];
203 tgp_idxs[idx + i] = thread_idxs[i];
236 const constant
int& size_sorted_axis,
237 const constant
int& in_stride_sorted_axis,
238 const constant
int& out_stride_sorted_axis,
239 const constant
int& in_stride_segment_axis,
240 const constant
int& out_stride_segment_axis,
241 threadgroup
ValT* tgp_vals,
242 threadgroup
IdxT* tgp_idxs,
243 uint3 tid [[threadgroup_position_in_grid]],
244 uint3 lid [[thread_position_in_threadgroup]]) {
246 inp += tid.y * in_stride_segment_axis;
247 out += tid.y * out_stride_segment_axis;
250 for (
short i = lid.x; i <
N_PER_BLOCK; i += BLOCK_THREADS) {
251 tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
252 :
ValT(CompareOp::init);
259 threadgroup_barrier(mem_flags::mem_threadgroup);
263 threadgroup_barrier(mem_flags::mem_threadgroup);
266 for (
int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
268 out[i * out_stride_sorted_axis] = tgp_idxs[i];
270 out[i * out_stride_sorted_axis] = tgp_vals[i];
282[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void block_sort(
283 const device T* inp [[buffer(0)]],
284 device U* out [[buffer(1)]],
285 const constant
int& size_sorted_axis [[buffer(2)]],
286 const constant
int& in_stride_sorted_axis [[buffer(3)]],
287 const constant
int& out_stride_sorted_axis [[buffer(4)]],
288 const constant
int& in_stride_segment_axis [[buffer(5)]],
289 const constant
int& out_stride_segment_axis [[buffer(6)]],
290 uint3 tid [[threadgroup_position_in_grid]],
291 uint3 lid [[thread_position_in_threadgroup]]) {
294 using ValT =
typename sort_kernel::ValT;
295 using IdxT =
typename sort_kernel::IdxT;
298 threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
299 threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
300 sort_kernel::block_sort(
304 in_stride_sorted_axis,
305 out_stride_sorted_axis,
306 in_stride_segment_axis,
307 out_stride_segment_axis,
313 threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
314 sort_kernel::block_sort(
318 in_stride_sorted_axis,
319 out_stride_sorted_axis,
320 in_stride_segment_axis,
321 out_stride_segment_axis,
337[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void block_sort_nc(
338 const device T* inp [[buffer(0)]],
339 device U* out [[buffer(1)]],
340 const constant
int& size_sorted_axis [[buffer(2)]],
341 const constant
int& in_stride_sorted_axis [[buffer(3)]],
342 const constant
int& out_stride_sorted_axis [[buffer(4)]],
343 const constant
int& nc_dim [[buffer(5)]],
344 const constant
int* nc_shape [[buffer(6)]],
345 const constant int64_t* in_nc_strides [[buffer(7)]],
346 const constant int64_t* out_nc_strides [[buffer(8)]],
347 uint3 tid [[threadgroup_position_in_grid]],
348 uint3 lid [[thread_position_in_threadgroup]]) {
351 using ValT =
typename sort_kernel::ValT;
352 using IdxT =
typename sort_kernel::IdxT;
354 auto in_block_idx =
elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
355 auto out_block_idx =
elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
357 out += out_block_idx;
360 threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
361 threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
362 sort_kernel::block_sort(
366 in_stride_sorted_axis,
367 out_stride_sorted_axis,
375 threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
376 sort_kernel::block_sort(
380 in_stride_sorted_axis,
381 out_stride_sorted_axis,
410 const device ValT* inp,
411 device ValT* out_vals,
412 device IdxT* out_idxs,
413 const constant
int& size_sorted_axis,
414 const constant
int& stride_sorted_axis,
415 threadgroup ValT* tgp_vals,
416 threadgroup IdxT* tgp_idxs,
417 uint3 tid [[threadgroup_position_in_grid]],
418 uint3 lid [[thread_position_in_threadgroup]]) {
423 for (
short i = lid.x; i <
N_PER_BLOCK; i += BLOCK_THREADS) {
424 int idx = base_idx + i;
425 tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
426 : ValT(CompareOp::init);
431 threadgroup_barrier(mem_flags::mem_threadgroup);
435 threadgroup_barrier(mem_flags::mem_threadgroup);
438 for (
int i = lid.x; i <
N_PER_BLOCK; i += BLOCK_THREADS) {
439 int idx = base_idx + i;
440 if (idx < size_sorted_axis) {
441 out_vals[idx] = tgp_vals[i];
442 out_idxs[idx] = tgp_idxs[i];
448 const device ValT* As,
449 const device ValT* Bs,
455 int A_st =
max(0, sort_md - B_sz);
456 int A_ed =
min(sort_md, A_sz);
458 while (A_st < A_ed) {
459 int md = A_st + (A_ed - A_st) / 2;
461 auto b = Bs[sort_md - 1 - md];
480[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void mb_block_sort(
481 const device ValT* inp [[buffer(0)]],
482 device ValT* out_vals [[buffer(1)]],
483 device IdxT* out_idxs [[buffer(2)]],
484 const constant
int& size_sorted_axis [[buffer(3)]],
485 const constant
int& stride_sorted_axis [[buffer(4)]],
486 const constant
int& nc_dim [[buffer(5)]],
487 const constant
int* nc_shape [[buffer(6)]],
488 const constant int64_t* nc_strides [[buffer(7)]],
489 uint3 tid [[threadgroup_position_in_grid]],
490 uint3 lid [[thread_position_in_threadgroup]]) {
498 auto block_idx =
elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
500 out_vals += tid.y * size_sorted_axis;
501 out_idxs += tid.y * size_sorted_axis;
503 threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
504 threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
506 sort_kernel::block_sort(
525 device IdxT* block_partitions [[buffer(0)]],
526 const device ValT* dev_vals [[buffer(1)]],
527 const device IdxT* dev_idxs [[buffer(2)]],
528 const constant
int& size_sorted_axis [[buffer(3)]],
529 const constant
int& merge_tiles [[buffer(4)]],
530 const constant
int& n_blocks [[buffer(5)]],
531 uint3 tid [[threadgroup_position_in_grid]],
532 uint3 lid [[thread_position_in_threadgroup]],
533 uint3 tgp_dims [[threads_per_threadgroup]]) {
541 block_partitions += tid.y * tgp_dims.x;
542 dev_vals += tid.y * size_sorted_axis;
543 dev_idxs += tid.y * size_sorted_axis;
545 for (
int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
547 int merge_group = i / merge_tiles;
548 int merge_lane = i % merge_tiles;
550 int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
551 int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
553 int A_st =
min(size_sorted_axis, sort_st);
554 int A_ed =
min(size_sorted_axis, sort_st + sort_sz / 2);
556 int B_ed =
min(size_sorted_axis, B_st + sort_sz / 2);
558 int partition_at =
min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
559 int partition = sort_kernel::merge_partition(
566 block_partitions[i] = A_st + partition;
577[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void
579 const device IdxT* block_partitions [[buffer(0)]],
580 const device ValT* dev_vals_in [[buffer(1)]],
581 const device IdxT* dev_idxs_in [[buffer(2)]],
582 device ValT* dev_vals_out [[buffer(3)]],
583 device IdxT* dev_idxs_out [[buffer(4)]],
584 const constant
int& size_sorted_axis [[buffer(5)]],
585 const constant
int& merge_tiles [[buffer(6)]],
586 const constant
int& num_tiles [[buffer(7)]],
587 uint3 tid [[threadgroup_position_in_grid]],
588 uint3 lid [[thread_position_in_threadgroup]]) {
597 using block_sort_t =
typename sort_kernel::block_merge_sort_t;
599 block_partitions += tid.y * (num_tiles + 1);
600 dev_vals_in += tid.y * size_sorted_axis;
601 dev_idxs_in += tid.y * size_sorted_axis;
602 dev_vals_out += tid.y * size_sorted_axis;
603 dev_idxs_out += tid.y * size_sorted_axis;
605 int block_idx = tid.x;
606 int merge_group = block_idx / merge_tiles;
607 int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
608 int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
609 int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
611 int A_st = block_partitions[block_idx + 0];
612 int A_ed = block_partitions[block_idx + 1];
613 int B_st =
min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
616 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
618 if ((block_idx % merge_tiles) == merge_tiles - 1) {
619 A_ed =
min(size_sorted_axis, sort_st + sort_sz / 2);
620 B_ed =
min(size_sorted_axis, sort_st + sort_sz);
623 int A_sz = A_ed - A_st;
624 int B_sz = B_ed - B_st;
627 thread ValT thread_vals[N_PER_THREAD];
628 thread IdxT thread_idxs[N_PER_THREAD];
629 for (
int i = 0; i < N_PER_THREAD; i++) {
630 int idx = BLOCK_THREADS * i + lid.x;
631 if (idx < (A_sz + B_sz)) {
632 thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
633 : dev_vals_in[B_st + idx - A_sz];
634 thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
635 : dev_idxs_in[B_st + idx - A_sz];
637 thread_vals[i] = CompareOp::init;
643 threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
644 threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
645 threadgroup_barrier(mem_flags::mem_threadgroup);
646 for (
int i = 0; i < N_PER_THREAD; i++) {
647 int idx = BLOCK_THREADS * i + lid.x;
648 tgp_vals[idx] = thread_vals[i];
649 tgp_idxs[idx] = thread_idxs[i];
651 threadgroup_barrier(mem_flags::mem_threadgroup);
654 int sort_md_local =
min(A_sz + B_sz, N_PER_THREAD *
int(lid.x));
656 int A_st_local = block_sort_t::merge_partition(
657 tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
658 int A_ed_local = A_sz;
660 int B_st_local = sort_md_local - A_st_local;
661 int B_ed_local = B_sz;
663 int A_sz_local = A_ed_local - A_st_local;
664 int B_sz_local = B_ed_local - B_st_local;
667 block_sort_t::merge_step(
668 tgp_vals + A_st_local,
669 tgp_vals + A_ed_local + B_st_local,
670 tgp_idxs + A_st_local,
671 tgp_idxs + A_ed_local + B_st_local,
677 threadgroup_barrier(mem_flags::mem_threadgroup);
678 for (
int i = 0; i < N_PER_THREAD; ++i) {
679 int idx = lid.x * N_PER_THREAD;
680 tgp_vals[idx + i] = thread_vals[i];
681 tgp_idxs[idx + i] = thread_idxs[i];
684 threadgroup_barrier(mem_flags::mem_threadgroup);
686 int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
687 for (
int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
688 int idx = base_idx + i;
689 if (idx < size_sorted_axis) {
690 dev_vals_out[idx] = tgp_vals[i];
691 dev_idxs_out[idx] = tgp_idxs[i];
#define MLX_MTL_CONST
Definition sort.h:3
METAL_FUNC void thread_swap(thread T &a, thread T &b)
Definition sort.h:16
void mb_block_partition(device IdxT *block_partitions, const device ValT *dev_vals, const device IdxT *dev_idxs, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &n_blocks, uint3 tid, uint3 lid, uint3 tgp_dims)
Definition sort.h:524
void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, uint3 tid, uint3 lid)
Definition sort.h:282
void mb_block_merge(const device IdxT *block_partitions, const device ValT *dev_vals_in, const device IdxT *dev_idxs_in, device ValT *dev_vals_out, device IdxT *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)
Definition sort.h:578
constant constexpr const int zero_helper
Definition sort.h:329
void mb_block_sort(const device ValT *inp, device ValT *out_vals, device IdxT *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant int64_t *nc_strides, uint3 tid, uint3 lid)
Definition sort.h:480
void block_sort_nc(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant int64_t *in_nc_strides, const constant int64_t *out_nc_strides, uint3 tid, uint3 lid)
Definition sort.h:337
#define MLX_MTL_LOOP_UNROLL
Definition sort.h:4
static METAL_FUNC void merge_step(const threadgroup ValT *As, const threadgroup ValT *Bs, const threadgroup IdxT *As_idx, const threadgroup IdxT *Bs_idx, short A_sz, short B_sz, thread ValT(&vals)[N_PER_THREAD], thread IdxT(&idxs)[N_PER_THREAD])
Definition sort.h:95
static METAL_FUNC void sort(threadgroup ValT *tgp_vals, threadgroup IdxT *tgp_idxs, int size_sorted_axis, uint3 lid)
Definition sort.h:121
static METAL_FUNC int merge_partition(const threadgroup ValT *As, const threadgroup ValT *Bs, short A_sz, short B_sz, short sort_md)
Definition sort.h:69
ThreadSort< ValT, IdxT, ARG_SORT, N_PER_THREAD, CompareOp > thread_sort_t
Definition sort.h:67
BlockMergeSort< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp > block_merge_sort_t
Definition sort.h:223
uint IdxT
Definition sort.h:222
static METAL_FUNC void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, threadgroup ValT *tgp_vals, threadgroup IdxT *tgp_idxs, uint3 tid, uint3 lid)
Definition sort.h:233
T ValT
Definition sort.h:221
static constant constexpr const short N_PER_BLOCK
Definition sort.h:231
static constant constexpr const short N_PER_BLOCK
Definition sort.h:407
static METAL_FUNC void block_sort(const device ValT *inp, device ValT *out_vals, device IdxT *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, threadgroup ValT *tgp_vals, threadgroup IdxT *tgp_idxs, uint3 tid, uint3 lid)
Definition sort.h:409
static METAL_FUNC int merge_partition(const device ValT *As, const device ValT *Bs, int A_sz, int B_sz, int sort_md)
Definition sort.h:447
BlockMergeSort< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp > block_merge_sort_t
Definition sort.h:399
METAL_FUNC bool operator()(T a, T b)
Definition sort.h:26
static constexpr constant T init
Definition sort.h:24
static const constant U max
Definition utils.h:24
static METAL_FUNC void sort(thread ValT(&vals)[N_PER_THREAD], thread IdxT(&idxs)[N_PER_THREAD])
Definition sort.h:38