// NOTE(review): gapped source listing. The leading integers are original
// sort.h line numbers, and lines are omitted (e.g. 293-294, 297-298,
// 302-304, 309-313, 316-318, 323+). The template header, the condition that
// selects between the two variants below, some call arguments and the
// closing braces are therefore not visible here.
//
// block_sort: Metal kernel (threadgroup size capped at BLOCK_THREADS) that
// declares threadgroup scratch arrays of sort_kernel::N_PER_BLOCK elements
// and delegates the work to sort_kernel::block_sort.
//
// Parameters (Metal argument table):
//   inp  (buffer 0)  - input elements of type T
//   out  (buffer 1)  - output elements of type U
//   size_sorted_axis (buffer 2) - presumably the length of the sorted axis
//   in/out_stride_sorted_axis  (buffers 3-4) - strides along the sorted axis
//   in/out_stride_segment_axis (buffers 5-6) - presumably strides between
//                                independent segments; confirm in sort.h
//   tid - threadgroup position in the grid
//   lid - thread position within the threadgroup
283[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void block_sort(
284 const device T* inp [[buffer(0)]],
285 device U* out [[buffer(1)]],
286 const constant
int& size_sorted_axis [[buffer(2)]],
287 const constant
int& in_stride_sorted_axis [[buffer(3)]],
288 const constant
int& out_stride_sorted_axis [[buffer(4)]],
289 const constant
int& in_stride_segment_axis [[buffer(5)]],
290 const constant
int& out_stride_segment_axis [[buffer(6)]],
291 uint3 tid [[threadgroup_position_in_grid]],
292 uint3 lid [[thread_position_in_threadgroup]]) {
295 using val_t =
typename sort_kernel::val_t;
296 using idx_t =
typename sort_kernel::idx_t;
// Variant 1: threadgroup scratch for both values and indices.
299 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
300 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
// Arguments on lines 302-304 and 309+ are omitted from this listing.
301 sort_kernel::block_sort(
305 in_stride_sorted_axis,
306 out_stride_sorted_axis,
307 in_stride_segment_axis,
308 out_stride_segment_axis,
// Variant 2: threadgroup scratch for values only. The condition choosing
// between the two variants is on omitted lines (presumably a compile-time
// flag for whether indices are produced; confirm against sort.h).
314 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
315 sort_kernel::block_sort(
319 in_stride_sorted_axis,
320 out_stride_sorted_axis,
321 in_stride_segment_axis,
322 out_stride_segment_axis,
// NOTE(review): gapped source listing. The leading integers are original
// sort.h line numbers, and lines are omitted (e.g. 350-351, 354, 357,
// 359-360, 364-366, 369-375, 378-380, 383+), including the template header,
// the variant-selection condition and the closing braces.
//
// block_sort_nc: variant of block_sort for inputs whose non-sorted
// dimensions are described by an explicit shape and strides. The row index
// tid.y is mapped to element offsets with elem_to_loc (defined elsewhere)
// over nc_shape / in_nc_strides / out_nc_strides of rank nc_dim. The sort is
// then delegated to sort_kernel::block_sort with threadgroup scratch arrays
// of sort_kernel::N_PER_BLOCK elements.
//
// Parameters (Metal argument table):
//   inp, out (buffers 0-1)      - input (T) and output (U) elements
//   size_sorted_axis (buffer 2) - presumably the length of the sorted axis
//   in/out_stride_sorted_axis (buffers 3-4) - strides along the sorted axis
//   nc_dim (buffer 5)           - number of entries in nc_shape / *_nc_strides
//   nc_shape (buffer 6)         - shape used to decompose tid.y
//   in/out_nc_strides (buffers 7-8) - strides for that decomposition
//   tid - threadgroup position in the grid (tid.y selects the row)
//   lid - thread position within the threadgroup
338[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void block_sort_nc(
339 const device T* inp [[buffer(0)]],
340 device U* out [[buffer(1)]],
341 const constant
int& size_sorted_axis [[buffer(2)]],
342 const constant
int& in_stride_sorted_axis [[buffer(3)]],
343 const constant
int& out_stride_sorted_axis [[buffer(4)]],
344 const constant
int& nc_dim [[buffer(5)]],
345 const device
int* nc_shape [[buffer(6)]],
346 const device
size_t* in_nc_strides [[buffer(7)]],
347 const device
size_t* out_nc_strides [[buffer(8)]],
348 uint3 tid [[threadgroup_position_in_grid]],
349 uint3 lid [[thread_position_in_threadgroup]]) {
352 using val_t =
typename sort_kernel::val_t;
353 using idx_t =
typename sort_kernel::idx_t;
// Map row tid.y to its starting element offset in the input and the output.
355 auto in_block_idx =
elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
356 auto out_block_idx =
elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
// Only the output shift is visible. The matching input shift by in_block_idx
// is presumably on omitted line 357; confirm.
358 out += out_block_idx;
// Variant 1: threadgroup scratch for both values and indices.
361 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
362 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
// Remaining arguments (lines 364-366, 369+) are omitted from this listing.
363 sort_kernel::block_sort(
367 in_stride_sorted_axis,
368 out_stride_sorted_axis,
// Variant 2: threadgroup scratch for values only. The selecting condition
// is on omitted lines.
376 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
377 sort_kernel::block_sort(
381 in_stride_sorted_axis,
382 out_stride_sorted_axis,
// NOTE(review): gapped source listing. The leading integers are original
// sort.h line numbers, and lines are omitted (e.g. 492-498, 500, 503, 506,
// 508+), including the template header, the call arguments and the closing
// brace.
//
// mb_block_sort: Metal kernel (threadgroup size capped at BLOCK_THREADS)
// that sorts with sort_kernel::block_sort using threadgroup scratch of
// sort_kernel::N_PER_BLOCK values and indices. It writes values and indices
// to two separate output buffers.
//
// Parameters (Metal argument table):
//   inp (buffer 0)      - input values
//   out_vals (buffer 1) - output values, row stride size_sorted_axis
//   out_idxs (buffer 2) - output indices, row stride size_sorted_axis
//   size_sorted_axis (buffer 3)   - presumably the length of the sorted axis
//   stride_sorted_axis (buffer 4) - input stride along the sorted axis
//   nc_dim, nc_shape, nc_strides (buffers 5-7) - shape/strides used to map
//                                  row tid.y to an input offset
//   tid - threadgroup position in the grid (tid.y selects the row)
//   lid - thread position within the threadgroup
481[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void mb_block_sort(
482 const device val_t* inp [[buffer(0)]],
483 device val_t* out_vals [[buffer(1)]],
484 device idx_t* out_idxs [[buffer(2)]],
485 const constant
int& size_sorted_axis [[buffer(3)]],
486 const constant
int& stride_sorted_axis [[buffer(4)]],
487 const constant
int& nc_dim [[buffer(5)]],
488 const device
int* nc_shape [[buffer(6)]],
489 const device
size_t* nc_strides [[buffer(7)]],
490 uint3 tid [[threadgroup_position_in_grid]],
491 uint3 lid [[thread_position_in_threadgroup]]) {
// Input offset of row tid.y (elem_to_loc is defined elsewhere). Its
// application to inp is presumably on omitted line 500; confirm.
499 auto block_idx =
elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
// Outputs use a dense layout: row tid.y starts at tid.y * size_sorted_axis,
// independent of the input strides.
501 out_vals += tid.y * size_sorted_axis;
502 out_idxs += tid.y * size_sorted_axis;
504 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
505 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
// Arguments (lines 508+) are omitted from this listing.
507 sort_kernel::block_sort(
527 device idx_t* block_partitions [[buffer(0)]],
528 const device val_t* dev_vals [[buffer(1)]],
529 const device idx_t* dev_idxs [[buffer(2)]],
530 const constant
int& size_sorted_axis [[buffer(3)]],
531 const constant
int& merge_tiles [[buffer(4)]],
532 uint3 tid [[threadgroup_position_in_grid]],
533 uint3 lid [[thread_position_in_threadgroup]],
534 uint3 tgp_dims [[threads_per_threadgroup]]) {
542 block_partitions += tid.y * tgp_dims.x;
543 dev_vals += tid.y * size_sorted_axis;
544 dev_idxs += tid.y * size_sorted_axis;
547 int merge_group = lid.x / merge_tiles;
548 int merge_lane = lid.x % merge_tiles;
550 int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
551 int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
553 int A_st =
min(size_sorted_axis, sort_st);
554 int A_ed =
min(size_sorted_axis, sort_st + sort_sz / 2);
556 int B_ed =
min(size_sorted_axis, B_st + sort_sz / 2);
558 int partition_at =
min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
559 int partition = sort_kernel::merge_partition(
560 dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
562 block_partitions[lid.x] = A_st + partition;
// NOTE(review): gapped source listing. The kernel header (sort.h:573, named
// mb_block_merge per the Doxygen tooltip), the template parameters and
// several body lines (e.g. 584-591, 609-610, 631, 633-637, 667-671, 687+)
// are omitted, as are most closing braces.
//
// mb_block_merge: each threadgroup (row tid.y, block tid.x) produces up to
// sort_kernel::N_PER_BLOCK output elements at offset tid.x * N_PER_BLOCK of
// its row. It does this by merging a slice of the first ("A") half and a
// slice of the second ("B") half of a merge group spanning merge_tiles
// blocks. The A-side slice bounds come from block_partitions.
//
// Parameters (Metal argument table):
//   block_partitions (buffer 0) - per-row A-side start offsets, stored
//                                 num_tiles + 1 entries per row
//   dev_vals_in / dev_idxs_in (buffers 1-2)   - input values / indices,
//                                 row stride size_sorted_axis
//   dev_vals_out / dev_idxs_out (buffers 3-4) - merged output, same layout
//   size_sorted_axis (buffer 5) - row length; used to clamp ranges and guard
//                                 the final stores
//   merge_tiles (buffer 6)      - number of N_PER_BLOCK blocks per merge group
//   num_tiles (buffer 7)        - presumably blocks per row; confirm with the
//                                 host-side launch code
//   tid - threadgroup position in the grid
//   lid - thread position within the threadgroup
574 const device idx_t* block_partitions [[buffer(0)]],
575 const device val_t* dev_vals_in [[buffer(1)]],
576 const device idx_t* dev_idxs_in [[buffer(2)]],
577 device val_t* dev_vals_out [[buffer(3)]],
578 device idx_t* dev_idxs_out [[buffer(4)]],
579 const constant
int& size_sorted_axis [[buffer(5)]],
580 const constant
int& merge_tiles [[buffer(6)]],
581 const constant
int& num_tiles [[buffer(7)]],
582 uint3 tid [[threadgroup_position_in_grid]],
583 uint3 lid [[thread_position_in_threadgroup]]) {
592 using block_sort_t =
typename sort_kernel::block_merge_sort_t;
// Move every pointer to row tid.y. Partitions are stored num_tiles + 1 per
// row; values and indices use row stride size_sorted_axis.
594 block_partitions += tid.y * (num_tiles + 1);
595 dev_vals_in += tid.y * size_sorted_axis;
596 dev_idxs_in += tid.y * size_sorted_axis;
597 dev_vals_out += tid.y * size_sorted_axis;
598 dev_idxs_out += tid.y * size_sorted_axis;
// Locate this block within its merge group:
//   sort_st - first element of the group
//   sort_sz - group size in elements (A half + B half)
//   sort_md - this block's output offset from sort_st
600 int block_idx = tid.x;
601 int merge_group = block_idx / merge_tiles;
602 int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
603 int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
604 int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
// The A range comes from the precomputed partitions. B_st is chosen so that
// (A_st - sort_st) elements from A plus the elements skipped in B (which
// starts at sort_st + sort_sz / 2) add up to sort_md, clamped to the row
// length. B_ed is computed the same way on lines 609-611, which are only
// partly shown.
606 int A_st = block_partitions[block_idx + 0];
607 int A_ed = block_partitions[block_idx + 1];
608 int B_st =
min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
611 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
// The last block of a merge group takes whatever remains of both halves.
// Its end points are therefore the clamped ends of the A and B halves
// rather than partition-derived values.
613 if ((block_idx % merge_tiles) == merge_tiles - 1) {
614 A_ed =
min(size_sorted_axis, sort_st + sort_sz / 2);
615 B_ed =
min(size_sorted_axis, sort_st + sort_sz);
618 int A_sz = A_ed - A_st;
619 int B_sz = B_ed - B_st;
// Gather the A slice followed by the B slice into per-thread registers, with
// element i of thread lid.x at position BLOCK_THREADS * i + lid.x. Positions
// at or beyond A_sz + B_sz receive CompareOp::init (presumably a sentinel
// that orders after real values). The enclosing else and any index fill are
// on omitted lines 631/633.
622 thread val_t thread_vals[N_PER_THREAD];
623 thread idx_t thread_idxs[N_PER_THREAD];
624 for (
int i = 0; i < N_PER_THREAD; i++) {
625 int idx = BLOCK_THREADS * i + lid.x;
626 if (idx < (A_sz + B_sz)) {
627 thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
628 : dev_vals_in[B_st + idx - A_sz];
629 thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
630 : dev_idxs_in[B_st + idx - A_sz];
632 thread_vals[i] = CompareOp::init;
// Stage the gathered data in threadgroup memory (A slice at [0, A_sz),
// B slice right after it). The second barrier makes the staged data visible
// to every thread before partitioning.
638 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
639 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
640 threadgroup_barrier(mem_flags::mem_threadgroup);
641 for (
int i = 0; i < N_PER_THREAD; i++) {
642 int idx = BLOCK_THREADS * i + lid.x;
643 tgp_vals[idx] = thread_vals[i];
644 tgp_idxs[idx] = thread_idxs[i];
646 threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread handles the outputs starting at position
// N_PER_THREAD * lid.x (clamped to A_sz + B_sz) of the merged block.
// merge_partition (defined elsewhere) returns A_st_local, the number of A
// elements before that position; the rest (B_st_local) come from B. Each
// thread's ranges extend to the end of the staged A and B slices.
649 int sort_md_local =
min(A_sz + B_sz, N_PER_THREAD *
int(lid.x));
651 int A_st_local = block_sort_t::merge_partition(
652 tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
653 int A_ed_local = A_sz;
655 int B_st_local = sort_md_local - A_st_local;
656 int B_ed_local = B_sz;
658 int A_sz_local = A_ed_local - A_st_local;
659 int B_sz_local = B_ed_local - B_st_local;
// Merge this thread's sub-ranges (remaining arguments on omitted lines
// 667-671). The result presumably lands in thread_vals / thread_idxs, which
// are stored back below.
662 block_sort_t::merge_step(
663 tgp_vals + A_st_local,
664 tgp_vals + A_ed_local + B_st_local,
665 tgp_idxs + A_st_local,
666 tgp_idxs + A_ed_local + B_st_local,
// Wait until every thread has finished reading the staged slices, then
// overwrite the scratch with each thread's N_PER_THREAD merged results,
// stored contiguously at lid.x * N_PER_THREAD.
672 threadgroup_barrier(mem_flags::mem_threadgroup);
673 for (
int i = 0; i < N_PER_THREAD; ++i) {
674 int idx = lid.x * N_PER_THREAD;
675 tgp_vals[idx + i] = thread_vals[i];
676 tgp_idxs[idx + i] = thread_idxs[i];
679 threadgroup_barrier(mem_flags::mem_threadgroup);
// Copy this block's merged results to offset tid.x * N_PER_BLOCK of the
// output row, skipping positions at or beyond size_sorted_axis.
681 int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
682 for (
int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
683 int idx = base_idx + i;
684 if (idx < size_sorted_axis) {
685 dev_vals_out[idx] = tgp_vals[i];
686 dev_idxs_out[idx] = tgp_idxs[i];
void block_sort_nc(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const device int *nc_shape, const device size_t *in_nc_strides, const device size_t *out_nc_strides, uint3 tid, uint3 lid)
Definition sort.h:338
void mb_block_merge(const device idx_t *block_partitions, const device val_t *dev_vals_in, const device idx_t *dev_idxs_in, device val_t *dev_vals_out, device idx_t *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)
Definition sort.h:573
static METAL_FUNC void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, threadgroup val_t *tgp_vals, threadgroup idx_t *tgp_idxs, uint3 tid, uint3 lid)
Definition sort.h:234