// Kernel entry point `block_sort`. The attribute caps a threadgroup at
// BLOCK_THREADS threads.
//
// The leading integers on each line are line numbers carried over from the
// original sort.h listing; gaps in that numbering (e.g. 292 -> 295,
// 308 -> 314) mean lines are omitted from this view.
//
// Buffer bindings: 0 = inp, 1 = out, 2 = size_sorted_axis,
// 3/4 = input/output stride along the sorted axis,
// 5/6 = input/output stride along the segment axis.
// `tid` is the threadgroup's position in the grid, `lid` the thread's
// position inside its threadgroup.
283[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void block_sort(
284 const device T* inp [[buffer(0)]],
285 device U* out [[buffer(1)]],
286 const constant
int& size_sorted_axis [[buffer(2)]],
287 const constant
int& in_stride_sorted_axis [[buffer(3)]],
288 const constant
int& out_stride_sorted_axis [[buffer(4)]],
289 const constant
int& in_stride_segment_axis [[buffer(5)]],
290 const constant
int& out_stride_segment_axis [[buffer(6)]],
291 uint3 tid [[threadgroup_position_in_grid]],
292 uint3 lid [[thread_position_in_threadgroup]]) {
// Value/index element types are taken from `sort_kernel` (its definition is
// on lines not shown here).
295 using val_t =
typename sort_kernel::val_t;
296 using idx_t =
typename sort_kernel::idx_t;
// First visible path: threadgroup arrays for both values and indices, each
// holding N_PER_BLOCK elements, are declared before delegating to
// sort_kernel::block_sort.
299 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
300 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
301 sort_kernel::block_sort(
305 in_stride_sorted_axis,
306 out_stride_sorted_axis,
307 in_stride_segment_axis,
308 out_stride_segment_axis,
// Second visible path: only the value array is declared.
// NOTE(review): the condition selecting between the two paths sits on the
// omitted lines 309-313 — presumably an arg-sort switch; confirm against the
// full file.
314 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
315 sort_kernel::block_sort(
319 in_stride_sorted_axis,
320 out_stride_sorted_axis,
321 in_stride_segment_axis,
322 out_stride_segment_axis,
// Kernel entry point `block_sort_nc`. Compared with `block_sort`, the two
// segment-axis stride arguments are replaced by a shape/stride description
// (nc_dim, nc_shape, in_nc_strides, out_nc_strides) that is turned into
// per-threadgroup offsets with elem_to_loc.
//
// Buffer bindings: 0 = inp, 1 = out, 2 = size_sorted_axis,
// 3/4 = input/output stride along the sorted axis, 5 = nc_dim,
// 6 = nc_shape, 7 = in_nc_strides, 8 = out_nc_strides.
// The leading integers are line numbers from the original listing; gaps in
// them mark lines omitted from this view.
338[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void block_sort_nc(
339 const device T* inp [[buffer(0)]],
340 device U* out [[buffer(1)]],
341 const constant
int& size_sorted_axis [[buffer(2)]],
342 const constant
int& in_stride_sorted_axis [[buffer(3)]],
343 const constant
int& out_stride_sorted_axis [[buffer(4)]],
344 const constant
int& nc_dim [[buffer(5)]],
345 const device
int* nc_shape [[buffer(6)]],
346 const device
size_t* in_nc_strides [[buffer(7)]],
347 const device
size_t* out_nc_strides [[buffer(8)]],
348 uint3 tid [[threadgroup_position_in_grid]],
349 uint3 lid [[thread_position_in_threadgroup]]) {
352 using val_t =
typename sort_kernel::val_t;
353 using idx_t =
typename sort_kernel::idx_t;
// Map tid.y to an input offset and an output offset using the same shape but
// separate stride arrays.
355 auto in_block_idx =
elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
356 auto out_block_idx =
elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
// Advance the output pointer to this threadgroup's segment.
// NOTE(review): listing line 357 is omitted — presumably the matching
// `inp += in_block_idx`; confirm against the full file.
358 out += out_block_idx;
// First visible path: value and index threadgroup arrays (N_PER_BLOCK each).
361 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
362 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
363 sort_kernel::block_sort(
367 in_stride_sorted_axis,
368 out_stride_sorted_axis,
// Second visible path: value array only. The branch choosing between the two
// paths is on omitted lines.
376 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
377 sort_kernel::block_sort(
381 in_stride_sorted_axis,
382 out_stride_sorted_axis,
// Kernel entry point `mb_block_sort`. It writes to two output buffers, one
// of values (`out_vals`) and one of indices (`out_idxs`).
//
// Buffer bindings: 0 = inp, 1 = out_vals, 2 = out_idxs,
// 3 = size_sorted_axis, 4 = stride_sorted_axis, 5 = nc_dim,
// 6 = nc_shape, 7 = nc_strides.
// The leading integers are line numbers from the original listing; gaps in
// them mark lines omitted from this view.
481[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]]
void mb_block_sort(
482 const device val_t* inp [[buffer(0)]],
483 device val_t* out_vals [[buffer(1)]],
484 device idx_t* out_idxs [[buffer(2)]],
485 const constant
int& size_sorted_axis [[buffer(3)]],
486 const constant
int& stride_sorted_axis [[buffer(4)]],
487 const constant
int& nc_dim [[buffer(5)]],
488 const device
int* nc_shape [[buffer(6)]],
489 const device
size_t* nc_strides [[buffer(7)]],
490 uint3 tid [[threadgroup_position_in_grid]],
491 uint3 lid [[thread_position_in_threadgroup]]) {
// Offset derived from tid.y through the shape/stride description.
// NOTE(review): its use is on the omitted line 500 — presumably it offsets
// `inp`; confirm against the full file.
499 auto block_idx =
elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
// Each tid.y owns a contiguous output slice of size_sorted_axis elements in
// both output buffers.
// NOTE(review): `tid.y * size_sorted_axis` is evaluated in 32-bit arithmetic
// (uint * int); verify the product cannot exceed that range for the sizes
// callers pass.
501 out_vals += tid.y * size_sorted_axis;
502 out_idxs += tid.y * size_sorted_axis;
// Threadgroup arrays of N_PER_BLOCK values and indices, declared before the
// call to sort_kernel::block_sort (the call's argument list is truncated in
// this view).
504 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
505 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
507 sort_kernel::block_sort(
// Parameter list and body of a kernel whose header line (listing line 525)
// is not visible here. It fills `block_partitions` with one entry per index
// i in [0, n_blocks], each computed by sort_kernel::merge_partition.
//
// Buffer bindings: 0 = block_partitions (written), 1 = dev_vals,
// 2 = dev_idxs, 3 = size_sorted_axis, 4 = merge_tiles, 5 = n_blocks.
// `tgp_dims` is the number of threads per threadgroup.
// The leading integers are line numbers from the original listing; gaps in
// them mark lines omitted from this view.
526 device idx_t* block_partitions [[buffer(0)]],
527 const device val_t* dev_vals [[buffer(1)]],
528 const device idx_t* dev_idxs [[buffer(2)]],
529 const constant
int& size_sorted_axis [[buffer(3)]],
530 const constant
int& merge_tiles [[buffer(4)]],
531 const constant
int& n_blocks [[buffer(5)]],
532 uint3 tid [[threadgroup_position_in_grid]],
533 uint3 lid [[thread_position_in_threadgroup]],
534 uint3 tgp_dims [[threads_per_threadgroup]]) {
// Select the row for this tid.y in each buffer.
// NOTE(review): the partition row stride is tgp_dims.x, but the loop below
// writes indices 0..n_blocks and the merge kernel in this file offsets the
// same buffer by `num_tiles + 1` per row (listing line 600). These agree
// only if the kernel is dispatched with exactly n_blocks + 1 threads per
// threadgroup — confirm against the host dispatch code.
542 block_partitions += tid.y * tgp_dims.x;
543 dev_vals += tid.y * size_sorted_axis;
544 dev_idxs += tid.y * size_sorted_axis;
// Threads stride over the partition indices; the bound is inclusive, so
// n_blocks + 1 entries are produced per row.
546 for (
int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
// Which group of `merge_tiles` blocks index i belongs to, and its position
// inside that group.
548 int merge_group = i / merge_tiles;
549 int merge_lane = i % merge_tiles;
// Size of one merge group in elements and the element offset where this
// group starts.
551 int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
552 int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
// Range A is the first half of the group, clamped to the axis length.
554 int A_st =
min(size_sorted_axis, sort_st);
555 int A_ed =
min(size_sorted_axis, sort_st + sort_sz / 2);
// `B_st` is defined on the omitted listing line 556.
557 int B_ed =
min(size_sorted_axis, B_st + sort_sz / 2);
// Target position inside the A..B span for this lane, clamped to the span
// length.
559 int partition_at =
min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
// Arguments of this call are on omitted lines 561-566.
560 int partition = sort_kernel::merge_partition(
// Stored value is relative to the start of the row (A_st is added back).
567 block_partitions[i] = A_st + partition;
// Parameter list and body of `mb_block_merge` (its header, listing line 579,
// is not visible here). Each threadgroup reads a slice of range A and a
// slice of range B from the input buffers, merges them through threadgroup
// memory, and writes N_PER_BLOCK merged elements to the output buffers.
//
// Buffer bindings: 0 = block_partitions (read), 1/2 = input values/indices,
// 3/4 = output values/indices, 5 = size_sorted_axis, 6 = merge_tiles,
// 7 = num_tiles.
// The leading integers are line numbers from the original listing; gaps in
// them mark lines omitted from this view.
580 const device idx_t* block_partitions [[buffer(0)]],
581 const device val_t* dev_vals_in [[buffer(1)]],
582 const device idx_t* dev_idxs_in [[buffer(2)]],
583 device val_t* dev_vals_out [[buffer(3)]],
584 device idx_t* dev_idxs_out [[buffer(4)]],
585 const constant
int& size_sorted_axis [[buffer(5)]],
586 const constant
int& merge_tiles [[buffer(6)]],
587 const constant
int& num_tiles [[buffer(7)]],
588 uint3 tid [[threadgroup_position_in_grid]],
589 uint3 lid [[thread_position_in_threadgroup]]) {
598 using block_sort_t =
typename sort_kernel::block_merge_sort_t;
// Select the row for this tid.y: num_tiles + 1 partition entries per row,
// size_sorted_axis elements per row in the value/index buffers.
600 block_partitions += tid.y * (num_tiles + 1);
601 dev_vals_in += tid.y * size_sorted_axis;
602 dev_idxs_in += tid.y * size_sorted_axis;
603 dev_vals_out += tid.y * size_sorted_axis;
604 dev_idxs_out += tid.y * size_sorted_axis;
// tid.x picks the output block; derive its merge group, the group's start
// offset and size, and this block's offset inside the group.
606 int block_idx = tid.x;
607 int merge_group = block_idx / merge_tiles;
608 int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
609 int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
610 int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
// A's slice is bounded by two consecutive partition entries; B's start is
// derived from A's start and clamped to the axis length.
612 int A_st = block_partitions[block_idx + 0];
613 int A_ed = block_partitions[block_idx + 1];
614 int B_st =
min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
// Tail of the B_ed computation; its opening is on omitted lines 615-616.
617 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
// Last block of a merge group: take A and B up to the (clamped) ends of
// their halves instead of the values computed above.
619 if ((block_idx % merge_tiles) == merge_tiles - 1) {
620 A_ed =
min(size_sorted_axis, sort_st + sort_sz / 2);
621 B_ed =
min(size_sorted_axis, sort_st + sort_sz);
624 int A_sz = A_ed - A_st;
625 int B_sz = B_ed - B_st;
// Load into per-thread arrays. Element i of a thread maps to position
// BLOCK_THREADS * i + lid.x of the concatenated A|B sequence: positions
// below A_sz come from A, the rest from B.
628 thread val_t thread_vals[N_PER_THREAD];
629 thread idx_t thread_idxs[N_PER_THREAD];
630 for (
int i = 0; i < N_PER_THREAD; i++) {
631 int idx = BLOCK_THREADS * i + lid.x;
632 if (idx < (A_sz + B_sz)) {
633 thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
634 : dev_vals_in[B_st + idx - A_sz];
635 thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
636 : dev_idxs_in[B_st + idx - A_sz];
// Listing line 637 is omitted; the assignment below is presumably the else
// branch that pads out-of-range positions with CompareOp::init.
638 thread_vals[i] = CompareOp::init;
// Stage the loaded elements in threadgroup memory at the same positions,
// with a barrier before and after the stores.
644 threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
645 threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
646 threadgroup_barrier(mem_flags::mem_threadgroup);
647 for (
int i = 0; i < N_PER_THREAD; i++) {
648 int idx = BLOCK_THREADS * i + lid.x;
649 tgp_vals[idx] = thread_vals[i];
650 tgp_idxs[idx] = thread_idxs[i];
652 threadgroup_barrier(mem_flags::mem_threadgroup);
// Per-thread split: each thread targets position N_PER_THREAD * lid.x of the
// merged sequence (clamped to its length) and asks merge_partition how many
// elements of A precede it. In threadgroup memory B starts at
// tgp_vals + A_sz.
655 int sort_md_local =
min(A_sz + B_sz, N_PER_THREAD *
int(lid.x));
657 int A_st_local = block_sort_t::merge_partition(
658 tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
659 int A_ed_local = A_sz;
661 int B_st_local = sort_md_local - A_st_local;
662 int B_ed_local = B_sz;
664 int A_sz_local = A_ed_local - A_st_local;
665 int B_sz_local = B_ed_local - B_st_local;
// Merge from this thread's A and B starting points; the remaining arguments
// are on omitted lines 673-677.
668 block_sort_t::merge_step(
669 tgp_vals + A_st_local,
670 tgp_vals + A_ed_local + B_st_local,
671 tgp_idxs + A_st_local,
672 tgp_idxs + A_ed_local + B_st_local,
// After a barrier, each thread writes its N_PER_THREAD entries of
// thread_vals/thread_idxs back to a contiguous run of threadgroup memory
// starting at lid.x * N_PER_THREAD.
678 threadgroup_barrier(mem_flags::mem_threadgroup);
679 for (
int i = 0; i < N_PER_THREAD; ++i) {
680 int idx = lid.x * N_PER_THREAD;
681 tgp_vals[idx + i] = thread_vals[i];
682 tgp_idxs[idx + i] = thread_idxs[i];
685 threadgroup_barrier(mem_flags::mem_threadgroup);
// Copy the block from threadgroup memory to the output buffers at
// tid.x * N_PER_BLOCK, skipping positions at or beyond size_sorted_axis.
687 int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
688 for (
int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
689 int idx = base_idx + i;
690 if (idx < size_sorted_axis) {
691 dev_vals_out[idx] = tgp_vals[i];
692 dev_idxs_out[idx] = tgp_idxs[i];
void block_sort_nc(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const device int *nc_shape, const device size_t *in_nc_strides, const device size_t *out_nc_strides, uint3 tid, uint3 lid)
Definition sort.h:338
void mb_block_merge(const device idx_t *block_partitions, const device val_t *dev_vals_in, const device idx_t *dev_idxs_in, device val_t *dev_vals_out, device idx_t *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)
Definition sort.h:579
static METAL_FUNC void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, threadgroup val_t *tgp_vals, threadgroup idx_t *tgp_idxs, uint3 tid, uint3 lid)
Definition sort.h:234