faster sort (#1831)

This commit is contained in:
Awni Hannun 2025-02-05 06:10:22 -08:00 committed by GitHub
parent a229c8cef0
commit fe5987b81d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 98 additions and 91 deletions

View File

@ -29,17 +29,16 @@ struct LessThan {
}; };
template < template <
typename val_t, typename ValT,
typename idx_t, typename IdxT,
bool ARG_SORT, bool ARG_SORT,
short N_PER_THREAD, short N_PER_THREAD,
typename CompareOp> typename CompareOp>
struct ThreadSort { struct ThreadSort {
static METAL_FUNC void sort( static METAL_FUNC void sort(
thread val_t (&vals)[N_PER_THREAD], thread ValT (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) { thread IdxT (&idxs)[N_PER_THREAD]) {
CompareOp op; CompareOp op;
MLX_MTL_LOOP_UNROLL MLX_MTL_LOOP_UNROLL
for (short i = 0; i < N_PER_THREAD; ++i) { for (short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL MLX_MTL_LOOP_UNROLL
@ -58,18 +57,18 @@ struct ThreadSort {
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template < template <
typename val_t, typename ValT,
typename idx_t, typename IdxT,
bool ARG_SORT, bool ARG_SORT,
short BLOCK_THREADS, short BLOCK_THREADS,
short N_PER_THREAD, short N_PER_THREAD,
typename CompareOp> typename CompareOp>
struct BlockMergeSort { struct BlockMergeSort {
using thread_sort_t = using thread_sort_t =
ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>; ThreadSort<ValT, IdxT, ARG_SORT, N_PER_THREAD, CompareOp>;
static METAL_FUNC int merge_partition( static METAL_FUNC int merge_partition(
const threadgroup val_t* As, const threadgroup ValT* As,
const threadgroup val_t* Bs, const threadgroup ValT* Bs,
short A_sz, short A_sz,
short B_sz, short B_sz,
short sort_md) { short sort_md) {
@ -94,14 +93,14 @@ struct BlockMergeSort {
} }
static METAL_FUNC void merge_step( static METAL_FUNC void merge_step(
const threadgroup val_t* As, const threadgroup ValT* As,
const threadgroup val_t* Bs, const threadgroup ValT* Bs,
const threadgroup idx_t* As_idx, const threadgroup IdxT* As_idx,
const threadgroup idx_t* Bs_idx, const threadgroup IdxT* Bs_idx,
short A_sz, short A_sz,
short B_sz, short B_sz,
thread val_t (&vals)[N_PER_THREAD], thread ValT (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) { thread IdxT (&idxs)[N_PER_THREAD]) {
CompareOp op; CompareOp op;
short a_idx = 0; short a_idx = 0;
short b_idx = 0; short b_idx = 0;
@ -120,16 +119,16 @@ struct BlockMergeSort {
} }
static METAL_FUNC void sort( static METAL_FUNC void sort(
threadgroup val_t* tgp_vals [[threadgroup(0)]], threadgroup ValT* tgp_vals [[threadgroup(0)]],
threadgroup idx_t* tgp_idxs [[threadgroup(1)]], threadgroup IdxT* tgp_idxs [[threadgroup(1)]],
int size_sorted_axis, int size_sorted_axis,
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
// Get thread location // Get thread location
int idx = lid.x * N_PER_THREAD; int idx = lid.x * N_PER_THREAD;
// Load from shared memory // Load from shared memory
thread val_t thread_vals[N_PER_THREAD]; thread ValT thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD]; thread IdxT thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; ++i) { for (int i = 0; i < N_PER_THREAD; ++i) {
thread_vals[i] = tgp_vals[idx + i]; thread_vals[i] = tgp_vals[idx + i];
if (ARG_SORT) { if (ARG_SORT) {
@ -169,8 +168,8 @@ struct BlockMergeSort {
int B_st = sort_st + sort_sz / 2; int B_st = sort_st + sort_sz / 2;
int B_ed = sort_st + sort_sz; int B_ed = sort_st + sort_sz;
const threadgroup val_t* As = tgp_vals + A_st; const threadgroup ValT* As = tgp_vals + A_st;
const threadgroup val_t* Bs = tgp_vals + B_st; const threadgroup ValT* Bs = tgp_vals + B_st;
int A_sz = A_ed - A_st; int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st; int B_sz = B_ed - B_st;
@ -187,9 +186,9 @@ struct BlockMergeSort {
A_sz -= partition; A_sz -= partition;
B_sz -= sort_md - partition; B_sz -= sort_md - partition;
const threadgroup idx_t* As_idx = const threadgroup IdxT* As_idx =
ARG_SORT ? tgp_idxs + A_st + partition : nullptr; ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx = const threadgroup IdxT* Bs_idx =
ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr; ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
// Merge starting at the partition and store results in thread registers // Merge starting at the partition and store results in thread registers
@ -219,11 +218,11 @@ template <
short N_PER_THREAD, short N_PER_THREAD,
typename CompareOp = LessThan<T>> typename CompareOp = LessThan<T>>
struct KernelMergeSort { struct KernelMergeSort {
using val_t = T; using ValT = T;
using idx_t = uint; using IdxT = uint;
using block_merge_sort_t = BlockMergeSort< using block_merge_sort_t = BlockMergeSort<
val_t, ValT,
idx_t, IdxT,
ARG_SORT, ARG_SORT,
BLOCK_THREADS, BLOCK_THREADS,
N_PER_THREAD, N_PER_THREAD,
@ -239,8 +238,8 @@ struct KernelMergeSort {
const constant int& out_stride_sorted_axis, const constant int& out_stride_sorted_axis,
const constant int& in_stride_segment_axis, const constant int& in_stride_segment_axis,
const constant int& out_stride_segment_axis, const constant int& out_stride_segment_axis,
threadgroup val_t* tgp_vals, threadgroup ValT* tgp_vals,
threadgroup idx_t* tgp_idxs, threadgroup IdxT* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index // tid.y tells us the segment index
@ -250,7 +249,7 @@ struct KernelMergeSort {
// Copy into threadgroup memory // Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis] tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
: val_t(CompareOp::init); : ValT(CompareOp::init);
if (ARG_SORT) { if (ARG_SORT) {
tgp_idxs[i] = i; tgp_idxs[i] = i;
} }
@ -292,12 +291,12 @@ template <
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t; using ValT = typename sort_kernel::ValT;
using idx_t = typename sort_kernel::idx_t; using IdxT = typename sort_kernel::IdxT;
if (ARG_SORT) { if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort( sort_kernel::block_sort(
inp, inp,
out, out,
@ -311,7 +310,7 @@ template <
tid, tid,
lid); lid);
} else { } else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort( sort_kernel::block_sort(
inp, inp,
out, out,
@ -349,8 +348,8 @@ template <
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t; using ValT = typename sort_kernel::ValT;
using idx_t = typename sort_kernel::idx_t; using IdxT = typename sort_kernel::IdxT;
auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim); auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim); auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
@ -358,8 +357,8 @@ template <
out += out_block_idx; out += out_block_idx;
if (ARG_SORT) { if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort( sort_kernel::block_sort(
inp, inp,
out, out,
@ -373,7 +372,7 @@ template <
tid, tid,
lid); lid);
} else { } else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort( sort_kernel::block_sort(
inp, inp,
out, out,
@ -390,16 +389,16 @@ template <
} }
template < template <
typename val_t, typename ValT,
typename idx_t, typename IdxT,
bool ARG_SORT, bool ARG_SORT,
short BLOCK_THREADS, short BLOCK_THREADS,
short N_PER_THREAD, short N_PER_THREAD,
typename CompareOp = LessThan<val_t>> typename CompareOp = LessThan<ValT>>
struct KernelMultiBlockMergeSort { struct KernelMultiBlockMergeSort {
using block_merge_sort_t = BlockMergeSort< using block_merge_sort_t = BlockMergeSort<
val_t, ValT,
idx_t, IdxT,
ARG_SORT, ARG_SORT,
BLOCK_THREADS, BLOCK_THREADS,
N_PER_THREAD, N_PER_THREAD,
@ -408,13 +407,13 @@ struct KernelMultiBlockMergeSort {
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort( static METAL_FUNC void block_sort(
const device val_t* inp, const device ValT* inp,
device val_t* out_vals, device ValT* out_vals,
device idx_t* out_idxs, device IdxT* out_idxs,
const constant int& size_sorted_axis, const constant int& size_sorted_axis,
const constant int& stride_sorted_axis, const constant int& stride_sorted_axis,
threadgroup val_t* tgp_vals, threadgroup ValT* tgp_vals,
threadgroup idx_t* tgp_idxs, threadgroup IdxT* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index // tid.y tells us the segment index
@ -424,7 +423,7 @@ struct KernelMultiBlockMergeSort {
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i; int idx = base_idx + i;
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
: val_t(CompareOp::init); : ValT(CompareOp::init);
tgp_idxs[i] = idx; tgp_idxs[i] = idx;
} }
@ -446,8 +445,8 @@ struct KernelMultiBlockMergeSort {
} }
static METAL_FUNC int merge_partition( static METAL_FUNC int merge_partition(
const device val_t* As, const device ValT* As,
const device val_t* Bs, const device ValT* Bs,
int A_sz, int A_sz,
int B_sz, int B_sz,
int sort_md) { int sort_md) {
@ -473,15 +472,15 @@ struct KernelMultiBlockMergeSort {
}; };
template < template <
typename val_t, typename ValT,
typename idx_t, typename IdxT,
bool ARG_SORT, bool ARG_SORT,
short BLOCK_THREADS, short BLOCK_THREADS,
short N_PER_THREAD> short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( [[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
const device val_t* inp [[buffer(0)]], const device ValT* inp [[buffer(0)]],
device val_t* out_vals [[buffer(1)]], device ValT* out_vals [[buffer(1)]],
device idx_t* out_idxs [[buffer(2)]], device IdxT* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]], const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]], const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]], const constant int& nc_dim [[buffer(5)]],
@ -490,8 +489,8 @@ template <
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort< using sort_kernel = KernelMultiBlockMergeSort<
val_t, ValT,
idx_t, IdxT,
ARG_SORT, ARG_SORT,
BLOCK_THREADS, BLOCK_THREADS,
N_PER_THREAD>; N_PER_THREAD>;
@ -501,8 +500,8 @@ template <
out_vals += tid.y * size_sorted_axis; out_vals += tid.y * size_sorted_axis;
out_idxs += tid.y * size_sorted_axis; out_idxs += tid.y * size_sorted_axis;
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort( sort_kernel::block_sort(
inp, inp,
@ -517,15 +516,15 @@ template <
} }
template < template <
typename val_t, typename ValT,
typename idx_t, typename IdxT,
bool ARG_SORT, bool ARG_SORT,
short BLOCK_THREADS, short BLOCK_THREADS,
short N_PER_THREAD> short N_PER_THREAD>
[[kernel]] void mb_block_partition( [[kernel]] void mb_block_partition(
device idx_t* block_partitions [[buffer(0)]], device IdxT* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]], const device ValT* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]], const device IdxT* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]], const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]], const constant int& merge_tiles [[buffer(4)]],
const constant int& n_blocks [[buffer(5)]], const constant int& n_blocks [[buffer(5)]],
@ -533,8 +532,8 @@ template <
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) { uint3 tgp_dims [[threads_per_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort< using sort_kernel = KernelMultiBlockMergeSort<
val_t, ValT,
idx_t, IdxT,
ARG_SORT, ARG_SORT,
BLOCK_THREADS, BLOCK_THREADS,
N_PER_THREAD>; N_PER_THREAD>;
@ -569,27 +568,27 @@ template <
} }
template < template <
typename val_t, typename ValT,
typename idx_t, typename IdxT,
bool ARG_SORT, bool ARG_SORT,
short BLOCK_THREADS, short BLOCK_THREADS,
short N_PER_THREAD, short N_PER_THREAD,
typename CompareOp = LessThan<val_t>> typename CompareOp = LessThan<ValT>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void [[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge( mb_block_merge(
const device idx_t* block_partitions [[buffer(0)]], const device IdxT* block_partitions [[buffer(0)]],
const device val_t* dev_vals_in [[buffer(1)]], const device ValT* dev_vals_in [[buffer(1)]],
const device idx_t* dev_idxs_in [[buffer(2)]], const device IdxT* dev_idxs_in [[buffer(2)]],
device val_t* dev_vals_out [[buffer(3)]], device ValT* dev_vals_out [[buffer(3)]],
device idx_t* dev_idxs_out [[buffer(4)]], device IdxT* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]], const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]], const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]], const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort< using sort_kernel = KernelMultiBlockMergeSort<
val_t, ValT,
idx_t, IdxT,
ARG_SORT, ARG_SORT,
BLOCK_THREADS, BLOCK_THREADS,
N_PER_THREAD, N_PER_THREAD,
@ -625,8 +624,8 @@ mb_block_merge(
int B_sz = B_ed - B_st; int B_sz = B_ed - B_st;
// Load from global memory // Load from global memory
thread val_t thread_vals[N_PER_THREAD]; thread ValT thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD]; thread IdxT thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; i++) { for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x; int idx = BLOCK_THREADS * i + lid.x;
if (idx < (A_sz + B_sz)) { if (idx < (A_sz + B_sz)) {
@ -641,8 +640,8 @@ mb_block_merge(
} }
// Write to shared memory // Write to shared memory
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; i++) { for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x; int idx = BLOCK_THREADS * i + lid.x;

View File

@ -22,10 +22,12 @@
_block_sort, itname, itype, itname, itype, false, bn, tn) _block_sort, itname, itype, itname, itype, false, bn, tn)
#define instantiate_block_sort_tn(itname, itype, bn) \ #define instantiate_block_sort_tn(itname, itype, bn) \
instantiate_block_sort_base(itname, itype, bn, 8) \ instantiate_block_sort_base(itname, itype, bn, 4) \
instantiate_arg_block_sort_base(itname, itype, bn, 8) instantiate_arg_block_sort_base(itname, itype, bn, 4)
#define instantiate_block_sort_bn(itname, itype) \ #define instantiate_block_sort_bn(itname, itype) \
instantiate_block_sort_tn(itname, itype, 32) \
instantiate_block_sort_tn(itname, itype, 64) \
instantiate_block_sort_tn(itname, itype, 128) \ instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \ instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512) instantiate_block_sort_tn(itname, itype, 512)
@ -41,6 +43,8 @@ instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t) instantiate_block_sort_bn(bfloat16, bfloat16_t)
#define instantiate_block_sort_long(itname, itype) \ #define instantiate_block_sort_long(itname, itype) \
instantiate_block_sort_tn(itname, itype, 32) \
instantiate_block_sort_tn(itname, itype, 64) \
instantiate_block_sort_tn(itname, itype, 128) \ instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) instantiate_block_sort_tn(itname, itype, 256)
@ -57,7 +61,7 @@ instantiate_block_sort_long(int64, int64_t)
mb_block_merge, vtype, itype, arg_sort, bn, tn) mb_block_merge, vtype, itype, arg_sort, bn, tn)
#define instantiate_multi_block_sort_base(vtname, vtype) \ #define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8) instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 4)
instantiate_multi_block_sort_base(uint8, uint8_t) instantiate_multi_block_sort_base(uint8, uint8_t)
instantiate_multi_block_sort_base(uint16, uint16_t) instantiate_multi_block_sort_base(uint16, uint16_t)
@ -70,7 +74,7 @@ instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
#define instantiate_multi_block_sort_long(vtname, vtype) \ #define instantiate_multi_block_sort_long(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8) instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 4)
instantiate_multi_block_sort_long(uint64, uint64_t) instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on instantiate_multi_block_sort_long(int64, int64_t) // clang-format on

View File

@ -274,16 +274,20 @@ void gpu_merge_sort(
int size_sorted_axis = in.shape(axis); int size_sorted_axis = in.shape(axis);
// Get kernel size // Get kernel size
int tn = 8; int tn = 4;
int bn = 128;
int potential_bn = (size_sorted_axis + tn - 1) / tn; int potential_bn = (size_sorted_axis + tn - 1) / tn;
int bn;
if (potential_bn > 256) { if (potential_bn > 256) {
bn = 512; bn = 512;
} else if (potential_bn > 128) { } else if (potential_bn > 128) {
bn = 256; bn = 256;
} else { } else if (potential_bn > 64) {
bn = 128; bn = 128;
} else if (potential_bn > 32) {
bn = 64;
} else {
bn = 32;
} }
if (bn == 512 && size_of(in.dtype()) > 4) { if (bn == 512 && size_of(in.dtype()) > 4) {