3#define MLX_MTL_CONST static constant constexpr const
4#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
38 static METAL_FUNC
void sort(
39 thread val_t (&vals)[N_PER_THREAD],
40 thread idx_t (&idxs)[N_PER_THREAD]) {
44 for (
short i = 0; i < N_PER_THREAD; ++i) {
46 for (
short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
47 if (op(vals[j + 1], vals[j])) {
71 const threadgroup val_t* As,
72 const threadgroup val_t* Bs,
78 short A_st =
max(0, sort_md - B_sz);
79 short A_ed =
min(sort_md, A_sz);
82 short md = A_st + (A_ed - A_st) / 2;
84 auto b = Bs[sort_md - 1 - md];
97 const threadgroup val_t* As,
98 const threadgroup val_t* Bs,
99 const threadgroup idx_t* As_idx,
100 const threadgroup idx_t* Bs_idx,
103 thread val_t (&vals)[N_PER_THREAD],
104 thread idx_t (&idxs)[N_PER_THREAD]) {
109 for (
int i = 0; i < N_PER_THREAD; ++i) {
112 bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
114 vals[i] = pred ? b : a;
115 idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
117 b_idx += short(pred);
118 a_idx += short(!pred);
123 threadgroup val_t* tgp_vals [[threadgroup(0)]],
124 threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
125 int size_sorted_axis,
126 uint3 lid [[thread_position_in_threadgroup]]) {
128 int idx = lid.x * N_PER_THREAD;
131 thread val_t thread_vals[N_PER_THREAD];
132 thread idx_t thread_idxs[N_PER_THREAD];
133 for (
int i = 0; i < N_PER_THREAD; ++i) {
134 thread_vals[i] = tgp_vals[idx + i];
136 thread_idxs[i] = tgp_idxs[idx + i];
141 if (idx < size_sorted_axis) {
146 for (
int merge_threads = 2; merge_threads <= BLOCK_THREADS;
147 merge_threads *= 2) {
149 threadgroup_barrier(mem_flags::mem_threadgroup);
150 for (
int i = 0; i < N_PER_THREAD; ++i) {
151 tgp_vals[idx + i] = thread_vals[i];
153 tgp_idxs[idx + i] = thread_idxs[i];
156 threadgroup_barrier(mem_flags::mem_threadgroup);
159 int merge_group = lid.x / merge_threads;
160 int merge_lane = lid.x % merge_threads;
162 int sort_sz = N_PER_THREAD * merge_threads;
163 int sort_st = N_PER_THREAD * merge_threads * merge_group;
168 int A_ed = sort_st + sort_sz / 2;
169 int B_st = sort_st + sort_sz / 2;
170 int B_ed = sort_st + sort_sz;
172 const threadgroup val_t* As = tgp_vals + A_st;
173 const threadgroup val_t* Bs = tgp_vals + B_st;
174 int A_sz = A_ed - A_st;
175 int B_sz = B_ed - B_st;
181 int sort_md = N_PER_THREAD * merge_lane;
185 Bs += sort_md - partition;
188 B_sz -= sort_md - partition;
190 const threadgroup idx_t* As_idx =
191 ARG_SORT ? tgp_idxs + A_st + partition :
nullptr;
192 const threadgroup idx_t* Bs_idx =
193 ARG_SORT ? tgp_idxs + B_st + sort_md - partition :
nullptr;
196 merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
200 threadgroup_barrier(mem_flags::mem_threadgroup);
201 for (
int i = 0; i < N_PER_THREAD; ++i) {
202 tgp_vals[idx + i] = thread_vals[i];
204 tgp_idxs[idx + i] = thread_idxs[i];
237 const constant
int& size_sorted_axis,
238 const constant
int& in_stride_sorted_axis,
239 const constant
int& out_stride_sorted_axis,
240 const constant
int& in_stride_segment_axis,
241 const constant
int& out_stride_segment_axis,
242 threadgroup
val_t* tgp_vals,
243 threadgroup
idx_t* tgp_idxs,
244 uint3 tid [[threadgroup_position_in_grid]],
245 uint3 lid [[thread_position_in_threadgroup]]) {
247 inp += tid.y * in_stride_segment_axis;
248 out += tid.y * out_stride_segment_axis;
251 for (
short i = lid.x; i <
N_PER_BLOCK; i += BLOCK_THREADS) {
252 tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
253 :
val_t(CompareOp::init);
260 threadgroup_barrier(mem_flags::mem_threadgroup);
264 threadgroup_barrier(mem_flags::mem_threadgroup);
267 for (
int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
269 out[i * out_stride_sorted_axis] = tgp_idxs[i];
271 out[i * out_stride_sorted_axis] = tgp_vals[i];
283[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void block_sort(
284 const device T* inp [[buffer(0)]],
285 device U* out [[buffer(1)]],
286 const constant
int& size_sorted_axis [[buffer(2)]],
287 const constant
int& in_stride_sorted_axis [[buffer(3)]],
288 const constant
int& out_stride_sorted_axis [[buffer(4)]],
289 const constant
int& in_stride_segment_axis [[buffer(5)]],
290 const constant
int& out_stride_segment_axis [[buffer(6)]],
291 uint3 tid [[threadgroup_position_in_grid]],
292 uint3 lid [[thread_position_in_threadgroup]]) {
295 using val_t =
typename sort_kernel::val_t;
296 using idx_t =
typename sort_kernel::idx_t;
299 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
300 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
301 sort_kernel::block_sort(
305 in_stride_sorted_axis,
306 out_stride_sorted_axis,
307 in_stride_segment_axis,
308 out_stride_segment_axis,
314 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
315 sort_kernel::block_sort(
319 in_stride_sorted_axis,
320 out_stride_sorted_axis,
321 in_stride_segment_axis,
322 out_stride_segment_axis,
338[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void block_sort_nc(
339 const device T* inp [[buffer(0)]],
340 device U* out [[buffer(1)]],
341 const constant
int& size_sorted_axis [[buffer(2)]],
342 const constant
int& in_stride_sorted_axis [[buffer(3)]],
343 const constant
int& out_stride_sorted_axis [[buffer(4)]],
344 const constant
int& nc_dim [[buffer(5)]],
345 const constant
int* nc_shape [[buffer(6)]],
346 const constant int64_t* in_nc_strides [[buffer(7)]],
347 const constant int64_t* out_nc_strides [[buffer(8)]],
348 uint3 tid [[threadgroup_position_in_grid]],
349 uint3 lid [[thread_position_in_threadgroup]]) {
352 using val_t =
typename sort_kernel::val_t;
353 using idx_t =
typename sort_kernel::idx_t;
355 auto in_block_idx =
elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
356 auto out_block_idx =
elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
358 out += out_block_idx;
361 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
362 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
363 sort_kernel::block_sort(
367 in_stride_sorted_axis,
368 out_stride_sorted_axis,
376 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
377 sort_kernel::block_sort(
381 in_stride_sorted_axis,
382 out_stride_sorted_axis,
411 const device val_t* inp,
412 device val_t* out_vals,
413 device idx_t* out_idxs,
414 const constant
int& size_sorted_axis,
415 const constant
int& stride_sorted_axis,
416 threadgroup val_t* tgp_vals,
417 threadgroup idx_t* tgp_idxs,
418 uint3 tid [[threadgroup_position_in_grid]],
419 uint3 lid [[thread_position_in_threadgroup]]) {
424 for (
short i = lid.x; i <
N_PER_BLOCK; i += BLOCK_THREADS) {
425 int idx = base_idx + i;
426 tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
427 : val_t(CompareOp::init);
432 threadgroup_barrier(mem_flags::mem_threadgroup);
436 threadgroup_barrier(mem_flags::mem_threadgroup);
439 for (
int i = lid.x; i <
N_PER_BLOCK; i += BLOCK_THREADS) {
440 int idx = base_idx + i;
441 if (idx < size_sorted_axis) {
442 out_vals[idx] = tgp_vals[i];
443 out_idxs[idx] = tgp_idxs[i];
449 const device val_t* As,
450 const device val_t* Bs,
456 int A_st =
max(0, sort_md - B_sz);
457 int A_ed =
min(sort_md, A_sz);
459 while (A_st < A_ed) {
460 int md = A_st + (A_ed - A_st) / 2;
462 auto b = Bs[sort_md - 1 - md];
481[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void mb_block_sort(
482 const device val_t* inp [[buffer(0)]],
483 device val_t* out_vals [[buffer(1)]],
484 device idx_t* out_idxs [[buffer(2)]],
485 const constant
int& size_sorted_axis [[buffer(3)]],
486 const constant
int& stride_sorted_axis [[buffer(4)]],
487 const constant
int& nc_dim [[buffer(5)]],
488 const constant
int* nc_shape [[buffer(6)]],
489 const constant int64_t* nc_strides [[buffer(7)]],
490 uint3 tid [[threadgroup_position_in_grid]],
491 uint3 lid [[thread_position_in_threadgroup]]) {
499 auto block_idx =
elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
501 out_vals += tid.y * size_sorted_axis;
502 out_idxs += tid.y * size_sorted_axis;
504 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
505 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
507 sort_kernel::block_sort(
526 device idx_t* block_partitions [[buffer(0)]],
527 const device val_t* dev_vals [[buffer(1)]],
528 const device idx_t* dev_idxs [[buffer(2)]],
529 const constant
int& size_sorted_axis [[buffer(3)]],
530 const constant
int& merge_tiles [[buffer(4)]],
531 const constant
int& n_blocks [[buffer(5)]],
532 uint3 tid [[threadgroup_position_in_grid]],
533 uint3 lid [[thread_position_in_threadgroup]],
534 uint3 tgp_dims [[threads_per_threadgroup]]) {
542 block_partitions += tid.y * tgp_dims.x;
543 dev_vals += tid.y * size_sorted_axis;
544 dev_idxs += tid.y * size_sorted_axis;
546 for (
int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
548 int merge_group = i / merge_tiles;
549 int merge_lane = i % merge_tiles;
551 int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
552 int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
554 int A_st =
min(size_sorted_axis, sort_st);
555 int A_ed =
min(size_sorted_axis, sort_st + sort_sz / 2);
557 int B_ed =
min(size_sorted_axis, B_st + sort_sz / 2);
559 int partition_at =
min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
560 int partition = sort_kernel::merge_partition(
567 block_partitions[i] = A_st + partition;
578[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void
580 const device idx_t* block_partitions [[buffer(0)]],
581 const device val_t* dev_vals_in [[buffer(1)]],
582 const device idx_t* dev_idxs_in [[buffer(2)]],
583 device val_t* dev_vals_out [[buffer(3)]],
584 device idx_t* dev_idxs_out [[buffer(4)]],
585 const constant
int& size_sorted_axis [[buffer(5)]],
586 const constant
int& merge_tiles [[buffer(6)]],
587 const constant
int& num_tiles [[buffer(7)]],
588 uint3 tid [[threadgroup_position_in_grid]],
589 uint3 lid [[thread_position_in_threadgroup]]) {
598 using block_sort_t =
typename sort_kernel::block_merge_sort_t;
600 block_partitions += tid.y * (num_tiles + 1);
601 dev_vals_in += tid.y * size_sorted_axis;
602 dev_idxs_in += tid.y * size_sorted_axis;
603 dev_vals_out += tid.y * size_sorted_axis;
604 dev_idxs_out += tid.y * size_sorted_axis;
606 int block_idx = tid.x;
607 int merge_group = block_idx / merge_tiles;
608 int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
609 int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
610 int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
612 int A_st = block_partitions[block_idx + 0];
613 int A_ed = block_partitions[block_idx + 1];
614 int B_st =
min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
617 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
619 if ((block_idx % merge_tiles) == merge_tiles - 1) {
620 A_ed =
min(size_sorted_axis, sort_st + sort_sz / 2);
621 B_ed =
min(size_sorted_axis, sort_st + sort_sz);
624 int A_sz = A_ed - A_st;
625 int B_sz = B_ed - B_st;
628 thread val_t thread_vals[N_PER_THREAD];
629 thread idx_t thread_idxs[N_PER_THREAD];
630 for (
int i = 0; i < N_PER_THREAD; i++) {
631 int idx = BLOCK_THREADS * i + lid.x;
632 if (idx < (A_sz + B_sz)) {
633 thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
634 : dev_vals_in[B_st + idx - A_sz];
635 thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
636 : dev_idxs_in[B_st + idx - A_sz];
638 thread_vals[i] = CompareOp::init;
644 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
645 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
646 threadgroup_barrier(mem_flags::mem_threadgroup);
647 for (
int i = 0; i < N_PER_THREAD; i++) {
648 int idx = BLOCK_THREADS * i + lid.x;
649 tgp_vals[idx] = thread_vals[i];
650 tgp_idxs[idx] = thread_idxs[i];
652 threadgroup_barrier(mem_flags::mem_threadgroup);
655 int sort_md_local =
min(A_sz + B_sz, N_PER_THREAD *
int(lid.x));
657 int A_st_local = block_sort_t::merge_partition(
658 tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
659 int A_ed_local = A_sz;
661 int B_st_local = sort_md_local - A_st_local;
662 int B_ed_local = B_sz;
664 int A_sz_local = A_ed_local - A_st_local;
665 int B_sz_local = B_ed_local - B_st_local;
668 block_sort_t::merge_step(
669 tgp_vals + A_st_local,
670 tgp_vals + A_ed_local + B_st_local,
671 tgp_idxs + A_st_local,
672 tgp_idxs + A_ed_local + B_st_local,
678 threadgroup_barrier(mem_flags::mem_threadgroup);
679 for (
int i = 0; i < N_PER_THREAD; ++i) {
680 int idx = lid.x * N_PER_THREAD;
681 tgp_vals[idx + i] = thread_vals[i];
682 tgp_idxs[idx + i] = thread_idxs[i];
685 threadgroup_barrier(mem_flags::mem_threadgroup);
687 int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
688 for (
int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
689 int idx = base_idx + i;
690 if (idx < size_sorted_axis) {
691 dev_vals_out[idx] = tgp_vals[i];
692 dev_idxs_out[idx] = tgp_idxs[i];
#define MLX_MTL_CONST
Definition sort.h:3
void mb_block_partition(device idx_t *block_partitions, const device val_t *dev_vals, const device idx_t *dev_idxs, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &n_blocks, uint3 tid, uint3 lid, uint3 tgp_dims)
Definition sort.h:525
METAL_FUNC void thread_swap(thread T &a, thread T &b)
Definition sort.h:16
void mb_block_sort(const device val_t *inp, device val_t *out_vals, device idx_t *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant int64_t *nc_strides, uint3 tid, uint3 lid)
Definition sort.h:481
void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, uint3 tid, uint3 lid)
Definition sort.h:283
void mb_block_merge(const device idx_t *block_partitions, const device val_t *dev_vals_in, const device idx_t *dev_idxs_in, device val_t *dev_vals_out, device idx_t *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)
Definition sort.h:579
constant constexpr const int zero_helper
Definition sort.h:330
void block_sort_nc(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant int64_t *in_nc_strides, const constant int64_t *out_nc_strides, uint3 tid, uint3 lid)
Definition sort.h:338
#define MLX_MTL_LOOP_UNROLL
Definition sort.h:4
static METAL_FUNC int merge_partition(const threadgroup val_t *As, const threadgroup val_t *Bs, short A_sz, short B_sz, short sort_md)
Definition sort.h:70
static METAL_FUNC void merge_step(const threadgroup val_t *As, const threadgroup val_t *Bs, const threadgroup idx_t *As_idx, const threadgroup idx_t *Bs_idx, short A_sz, short B_sz, thread val_t(&vals)[N_PER_THREAD], thread idx_t(&idxs)[N_PER_THREAD])
Definition sort.h:96
static METAL_FUNC void sort(threadgroup val_t *tgp_vals, threadgroup idx_t *tgp_idxs, int size_sorted_axis, uint3 lid)
Definition sort.h:122
ThreadSort< val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp > thread_sort_t
Definition sort.h:68
uint idx_t
Definition sort.h:223
T val_t
Definition sort.h:222
static METAL_FUNC void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, threadgroup val_t *tgp_vals, threadgroup idx_t *tgp_idxs, uint3 tid, uint3 lid)
Definition sort.h:234
static constant constexpr const short N_PER_BLOCK
Definition sort.h:232
BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp > block_merge_sort_t
Definition sort.h:224
static METAL_FUNC void block_sort(const device val_t *inp, device val_t *out_vals, device idx_t *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, threadgroup val_t *tgp_vals, threadgroup idx_t *tgp_idxs, uint3 tid, uint3 lid)
Definition sort.h:410
static METAL_FUNC int merge_partition(const device val_t *As, const device val_t *Bs, int A_sz, int B_sz, int sort_md)
Definition sort.h:448
static constant constexpr const short N_PER_BLOCK
Definition sort.h:408
BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp > block_merge_sort_t
Definition sort.h:400
METAL_FUNC bool operator()(T a, T b)
Definition sort.h:26
static constexpr constant T init
Definition sort.h:24
static const constant U max
Definition utils.h:24
static METAL_FUNC void sort(thread val_t(&vals)[N_PER_THREAD], thread idx_t(&idxs)[N_PER_THREAD])
Definition sort.h:38