283[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] 
void block_sort(
 
  284    const device T* inp [[buffer(0)]],
 
  285    device U* out [[buffer(1)]],
 
  286    const constant 
int& size_sorted_axis [[buffer(2)]],
 
  287    const constant 
int& in_stride_sorted_axis [[buffer(3)]],
 
  288    const constant 
int& out_stride_sorted_axis [[buffer(4)]],
 
  289    const constant 
int& in_stride_segment_axis [[buffer(5)]],
 
  290    const constant 
int& out_stride_segment_axis [[buffer(6)]],
 
  291    uint3 tid [[threadgroup_position_in_grid]],
 
  292    uint3 lid [[thread_position_in_threadgroup]]) {
 
  295  using val_t = 
typename sort_kernel::val_t;
 
  296  using idx_t = 
typename sort_kernel::idx_t;
 
  299    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
 
  300    threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
 
  301    sort_kernel::block_sort(
 
  305        in_stride_sorted_axis,
 
  306        out_stride_sorted_axis,
 
  307        in_stride_segment_axis,
 
  308        out_stride_segment_axis,
 
  314    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
 
  315    sort_kernel::block_sort(
 
  319        in_stride_sorted_axis,
 
  320        out_stride_sorted_axis,
 
  321        in_stride_segment_axis,
 
  322        out_stride_segment_axis,
 
 
  338[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] 
void block_sort_nc(
 
  339    const device T* inp [[buffer(0)]],
 
  340    device U* out [[buffer(1)]],
 
  341    const constant 
int& size_sorted_axis [[buffer(2)]],
 
  342    const constant 
int& in_stride_sorted_axis [[buffer(3)]],
 
  343    const constant 
int& out_stride_sorted_axis [[buffer(4)]],
 
  344    const constant 
int& nc_dim [[buffer(5)]],
 
  345    const constant 
int* nc_shape [[buffer(6)]],
 
  346    const constant 
size_t* in_nc_strides [[buffer(7)]],
 
  347    const constant 
size_t* out_nc_strides [[buffer(8)]],
 
  348    uint3 tid [[threadgroup_position_in_grid]],
 
  349    uint3 lid [[thread_position_in_threadgroup]]) {
 
  352  using val_t = 
typename sort_kernel::val_t;
 
  353  using idx_t = 
typename sort_kernel::idx_t;
 
  355  auto in_block_idx = 
elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
 
  356  auto out_block_idx = 
elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
 
  358  out += out_block_idx;
 
  361    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
 
  362    threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
 
  363    sort_kernel::block_sort(
 
  367        in_stride_sorted_axis,
 
  368        out_stride_sorted_axis,
 
  376    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
 
  377    sort_kernel::block_sort(
 
  381        in_stride_sorted_axis,
 
  382        out_stride_sorted_axis,
 
 
  481[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] 
void mb_block_sort(
 
  482    const device val_t* inp [[buffer(0)]],
 
  483    device val_t* out_vals [[buffer(1)]],
 
  484    device idx_t* out_idxs [[buffer(2)]],
 
  485    const constant 
int& size_sorted_axis [[buffer(3)]],
 
  486    const constant 
int& stride_sorted_axis [[buffer(4)]],
 
  487    const constant 
int& nc_dim [[buffer(5)]],
 
  488    const constant 
int* nc_shape [[buffer(6)]],
 
  489    const constant 
size_t* nc_strides [[buffer(7)]],
 
  490    uint3 tid [[threadgroup_position_in_grid]],
 
  491    uint3 lid [[thread_position_in_threadgroup]]) {
 
  499  auto block_idx = 
elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
 
  501  out_vals += tid.y * size_sorted_axis;
 
  502  out_idxs += tid.y * size_sorted_axis;
 
  504  threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
 
  505  threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
 
  507  sort_kernel::block_sort(
 
 
  526    device idx_t* block_partitions [[buffer(0)]],
 
  527    const device val_t* dev_vals [[buffer(1)]],
 
  528    const device idx_t* dev_idxs [[buffer(2)]],
 
  529    const constant 
int& size_sorted_axis [[buffer(3)]],
 
  530    const constant 
int& merge_tiles [[buffer(4)]],
 
  531    const constant 
int& n_blocks [[buffer(5)]],
 
  532    uint3 tid [[threadgroup_position_in_grid]],
 
  533    uint3 lid [[thread_position_in_threadgroup]],
 
  534    uint3 tgp_dims [[threads_per_threadgroup]]) {
 
  542  block_partitions += tid.y * tgp_dims.x;
 
  543  dev_vals += tid.y * size_sorted_axis;
 
  544  dev_idxs += tid.y * size_sorted_axis;
 
  546  for (
int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
 
  548    int merge_group = i / merge_tiles;
 
  549    int merge_lane = i % merge_tiles;
 
  551    int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
 
  552    int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
 
  554    int A_st = 
min(size_sorted_axis, sort_st);
 
  555    int A_ed = 
min(size_sorted_axis, sort_st + sort_sz / 2);
 
  557    int B_ed = 
min(size_sorted_axis, B_st + sort_sz / 2);
 
  559    int partition_at = 
min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
 
  560    int partition = sort_kernel::merge_partition(
 
  567    block_partitions[i] = A_st + partition;
 
 
  580    const device idx_t* block_partitions [[buffer(0)]],
 
  581    const device val_t* dev_vals_in [[buffer(1)]],
 
  582    const device idx_t* dev_idxs_in [[buffer(2)]],
 
  583    device val_t* dev_vals_out [[buffer(3)]],
 
  584    device idx_t* dev_idxs_out [[buffer(4)]],
 
  585    const constant 
int& size_sorted_axis [[buffer(5)]],
 
  586    const constant 
int& merge_tiles [[buffer(6)]],
 
  587    const constant 
int& num_tiles [[buffer(7)]],
 
  588    uint3 tid [[threadgroup_position_in_grid]],
 
  589    uint3 lid [[thread_position_in_threadgroup]]) {
 
  598  using block_sort_t = 
typename sort_kernel::block_merge_sort_t;
 
  600  block_partitions += tid.y * (num_tiles + 1);
 
  601  dev_vals_in += tid.y * size_sorted_axis;
 
  602  dev_idxs_in += tid.y * size_sorted_axis;
 
  603  dev_vals_out += tid.y * size_sorted_axis;
 
  604  dev_idxs_out += tid.y * size_sorted_axis;
 
  606  int block_idx = tid.x;
 
  607  int merge_group = block_idx / merge_tiles;
 
  608  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
 
  609  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
 
  610  int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
 
  612  int A_st = block_partitions[block_idx + 0];
 
  613  int A_ed = block_partitions[block_idx + 1];
 
  614  int B_st = 
min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
 
  617      2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
 
  619  if ((block_idx % merge_tiles) == merge_tiles - 1) {
 
  620    A_ed = 
min(size_sorted_axis, sort_st + sort_sz / 2);
 
  621    B_ed = 
min(size_sorted_axis, sort_st + sort_sz);
 
  624  int A_sz = A_ed - A_st;
 
  625  int B_sz = B_ed - B_st;
 
  628  thread val_t thread_vals[N_PER_THREAD];
 
  629  thread idx_t thread_idxs[N_PER_THREAD];
 
  630  for (
int i = 0; i < N_PER_THREAD; i++) {
 
  631    int idx = BLOCK_THREADS * i + lid.x;
 
  632    if (idx < (A_sz + B_sz)) {
 
  633      thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
 
  634                                    : dev_vals_in[B_st + idx - A_sz];
 
  635      thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
 
  636                                    : dev_idxs_in[B_st + idx - A_sz];
 
  638      thread_vals[i] = CompareOp::init;
 
  644  threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
 
  645  threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
 
  646  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  647  for (
int i = 0; i < N_PER_THREAD; i++) {
 
  648    int idx = BLOCK_THREADS * i + lid.x;
 
  649    tgp_vals[idx] = thread_vals[i];
 
  650    tgp_idxs[idx] = thread_idxs[i];
 
  652  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  655  int sort_md_local = 
min(A_sz + B_sz, N_PER_THREAD * 
int(lid.x));
 
  657  int A_st_local = block_sort_t::merge_partition(
 
  658      tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
 
  659  int A_ed_local = A_sz;
 
  661  int B_st_local = sort_md_local - A_st_local;
 
  662  int B_ed_local = B_sz;
 
  664  int A_sz_local = A_ed_local - A_st_local;
 
  665  int B_sz_local = B_ed_local - B_st_local;
 
  668  block_sort_t::merge_step(
 
  669      tgp_vals + A_st_local,
 
  670      tgp_vals + A_ed_local + B_st_local,
 
  671      tgp_idxs + A_st_local,
 
  672      tgp_idxs + A_ed_local + B_st_local,
 
  678  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  679  for (
int i = 0; i < N_PER_THREAD; ++i) {
 
  680    int idx = lid.x * N_PER_THREAD;
 
  681    tgp_vals[idx + i] = thread_vals[i];
 
  682    tgp_idxs[idx + i] = thread_idxs[i];
 
  685  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  687  int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
 
  688  for (
int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
 
  689    int idx = base_idx + i;
 
  690    if (idx < size_sorted_axis) {
 
  691      dev_vals_out[idx] = tgp_vals[i];
 
  692      dev_idxs_out[idx] = tgp_idxs[i];
 
 
void block_sort_nc(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant size_t *in_nc_strides, const constant size_t *out_nc_strides, uint3 tid, uint3 lid)
Definition sort.h:338
 
void mb_block_sort(const device val_t *inp, device val_t *out_vals, device idx_t *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, const constant int &nc_dim, const constant int *nc_shape, const constant size_t *nc_strides, uint3 tid, uint3 lid)
Definition sort.h:481
 
void mb_block_merge(const device idx_t *block_partitions, const device val_t *dev_vals_in, const device idx_t *dev_idxs_in, device val_t *dev_vals_out, device idx_t *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)
Definition sort.h:579
 
static METAL_FUNC void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, threadgroup val_t *tgp_vals, threadgroup idx_t *tgp_idxs, uint3 tid, uint3 lid)
Definition sort.h:234