From fe5987b81d3b0c35b5f26717535290f3821669f4 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 5 Feb 2025 06:10:22 -0800 Subject: [PATCH] faster sort (#1831) --- mlx/backend/metal/kernels/sort.h | 167 +++++++++++++-------------- mlx/backend/metal/kernels/sort.metal | 12 +- mlx/backend/metal/sort.cpp | 10 +- 3 files changed, 98 insertions(+), 91 deletions(-) diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h index e9a00c777..b067150d8 100644 --- a/mlx/backend/metal/kernels/sort.h +++ b/mlx/backend/metal/kernels/sort.h @@ -29,17 +29,16 @@ struct LessThan { }; template < - typename val_t, - typename idx_t, + typename ValT, + typename IdxT, bool ARG_SORT, short N_PER_THREAD, typename CompareOp> struct ThreadSort { static METAL_FUNC void sort( - thread val_t (&vals)[N_PER_THREAD], - thread idx_t (&idxs)[N_PER_THREAD]) { + thread ValT (&vals)[N_PER_THREAD], + thread IdxT (&idxs)[N_PER_THREAD]) { CompareOp op; - MLX_MTL_LOOP_UNROLL for (short i = 0; i < N_PER_THREAD; ++i) { MLX_MTL_LOOP_UNROLL @@ -58,18 +57,18 @@ struct ThreadSort { /////////////////////////////////////////////////////////////////////////////// template < - typename val_t, - typename idx_t, + typename ValT, + typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, typename CompareOp> struct BlockMergeSort { using thread_sort_t = - ThreadSort; + ThreadSort; static METAL_FUNC int merge_partition( - const threadgroup val_t* As, - const threadgroup val_t* Bs, + const threadgroup ValT* As, + const threadgroup ValT* Bs, short A_sz, short B_sz, short sort_md) { @@ -94,14 +93,14 @@ struct BlockMergeSort { } static METAL_FUNC void merge_step( - const threadgroup val_t* As, - const threadgroup val_t* Bs, - const threadgroup idx_t* As_idx, - const threadgroup idx_t* Bs_idx, + const threadgroup ValT* As, + const threadgroup ValT* Bs, + const threadgroup IdxT* As_idx, + const threadgroup IdxT* Bs_idx, short A_sz, short B_sz, - thread val_t (&vals)[N_PER_THREAD], - thread idx_t (&idxs)[N_PER_THREAD]) { + thread ValT (&vals)[N_PER_THREAD], + thread IdxT (&idxs)[N_PER_THREAD]) { CompareOp op; short a_idx = 0; short b_idx = 0; @@ -120,16 +119,16 @@ struct BlockMergeSort { } static METAL_FUNC void sort( - threadgroup val_t* tgp_vals [[threadgroup(0)]], - threadgroup idx_t* tgp_idxs [[threadgroup(1)]], + threadgroup ValT* tgp_vals [[threadgroup(0)]], + threadgroup IdxT* tgp_idxs [[threadgroup(1)]], int size_sorted_axis, uint3 lid [[thread_position_in_threadgroup]]) { // Get thread location int idx = lid.x * N_PER_THREAD; // Load from shared memory - thread val_t thread_vals[N_PER_THREAD]; - thread idx_t thread_idxs[N_PER_THREAD]; + thread ValT thread_vals[N_PER_THREAD]; + thread IdxT thread_idxs[N_PER_THREAD]; for (int i = 0; i < N_PER_THREAD; ++i) { thread_vals[i] = tgp_vals[idx + i]; if (ARG_SORT) { @@ -169,8 +168,8 @@ struct BlockMergeSort { int B_st = sort_st + sort_sz / 2; int B_ed = sort_st + sort_sz; - const threadgroup val_t* As = tgp_vals + A_st; - const threadgroup val_t* Bs = tgp_vals + B_st; + const threadgroup ValT* As = tgp_vals + A_st; + const threadgroup ValT* Bs = tgp_vals + B_st; int A_sz = A_ed - A_st; int B_sz = B_ed - B_st; @@ -187,9 +186,9 @@ struct BlockMergeSort { A_sz -= partition; B_sz -= sort_md - partition; - const threadgroup idx_t* As_idx = + const threadgroup IdxT* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr; - const threadgroup idx_t* Bs_idx = + const threadgroup IdxT* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr; // Merge starting at the partition and store results in thread registers @@ -219,11 +218,11 @@ template < short N_PER_THREAD, typename CompareOp = LessThan> struct KernelMergeSort { - using val_t = T; - using idx_t = uint; + using ValT = T; + using IdxT = uint; using block_merge_sort_t = BlockMergeSort< - val_t, - idx_t, + ValT, + IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, @@ -239,8 +238,8 @@ struct KernelMergeSort { const constant int& out_stride_sorted_axis, const constant int& in_stride_segment_axis, const constant int& out_stride_segment_axis, - threadgroup val_t* tgp_vals, - threadgroup idx_t* tgp_idxs, + threadgroup ValT* tgp_vals, + threadgroup IdxT* tgp_idxs, uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { // tid.y tells us the segment index @@ -250,7 +249,7 @@ struct KernelMergeSort { // Copy into threadgroup memory for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis] - : val_t(CompareOp::init); + : ValT(CompareOp::init); if (ARG_SORT) { tgp_idxs[i] = i; } @@ -292,12 +291,12 @@ template < uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = KernelMergeSort; - using val_t = typename sort_kernel::val_t; - using idx_t = typename sort_kernel::idx_t; + using ValT = typename sort_kernel::ValT; + using IdxT = typename sort_kernel::IdxT; if (ARG_SORT) { - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, out, @@ -311,7 +310,7 @@ template < tid, lid); } else { - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, out, @@ -349,8 +348,8 @@ template < uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = KernelMergeSort; - using val_t = typename sort_kernel::val_t; - using idx_t = typename sort_kernel::idx_t; + using ValT = typename sort_kernel::ValT; + using IdxT = typename sort_kernel::IdxT; auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim); auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim); @@ -358,8 +357,8 @@ template < out += out_block_idx; if (ARG_SORT) { - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, out, @@ -373,7 +372,7 @@ template < tid, lid); } else { - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, out, @@ -390,16 +389,16 @@ template < } template < - typename val_t, - typename idx_t, + typename ValT, + typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, - typename CompareOp = LessThan> + typename CompareOp = LessThan> struct KernelMultiBlockMergeSort { using block_merge_sort_t = BlockMergeSort< - val_t, - idx_t, + ValT, + IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, @@ -408,13 +407,13 @@ struct KernelMultiBlockMergeSort { MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; static METAL_FUNC void block_sort( - const device val_t* inp, - device val_t* out_vals, - device idx_t* out_idxs, + const device ValT* inp, + device ValT* out_vals, + device IdxT* out_idxs, const constant int& size_sorted_axis, const constant int& stride_sorted_axis, - threadgroup val_t* tgp_vals, - threadgroup idx_t* tgp_idxs, + threadgroup ValT* tgp_vals, + threadgroup IdxT* tgp_idxs, uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { // tid.y tells us the segment index @@ -424,7 +423,7 @@ struct KernelMultiBlockMergeSort { for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { int idx = base_idx + i; tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] - : val_t(CompareOp::init); + : ValT(CompareOp::init); tgp_idxs[i] = idx; } @@ -446,8 +445,8 @@ struct KernelMultiBlockMergeSort { } static METAL_FUNC int merge_partition( - const device val_t* As, - const device val_t* Bs, + const device ValT* As, + const device ValT* Bs, int A_sz, int B_sz, int sort_md) { @@ -473,15 +472,15 @@ struct KernelMultiBlockMergeSort { }; template < - typename val_t, - typename idx_t, + typename ValT, + typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD> [[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( - const device val_t* inp [[buffer(0)]], - device val_t* out_vals [[buffer(1)]], - device idx_t* out_idxs [[buffer(2)]], + const device ValT* inp [[buffer(0)]], + device ValT* out_vals [[buffer(1)]], + device IdxT* out_idxs [[buffer(2)]], const constant int& size_sorted_axis [[buffer(3)]], const constant int& stride_sorted_axis [[buffer(4)]], const constant int& nc_dim [[buffer(5)]], @@ -490,8 +489,8 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = KernelMultiBlockMergeSort< - val_t, - idx_t, + ValT, + IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; @@ -501,8 +500,8 @@ template < out_vals += tid.y * size_sorted_axis; out_idxs += tid.y * size_sorted_axis; - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; sort_kernel::block_sort( inp, @@ -517,15 +516,15 @@ template < } template < - typename val_t, - typename idx_t, + typename ValT, + typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD> [[kernel]] void mb_block_partition( - device idx_t* block_partitions [[buffer(0)]], - const device val_t* dev_vals [[buffer(1)]], - const device idx_t* dev_idxs [[buffer(2)]], + device IdxT* block_partitions [[buffer(0)]], + const device ValT* dev_vals [[buffer(1)]], + const device IdxT* dev_idxs [[buffer(2)]], const constant int& size_sorted_axis [[buffer(3)]], const constant int& merge_tiles [[buffer(4)]], const constant int& n_blocks [[buffer(5)]], @@ -533,8 +532,8 @@ template < uint3 lid [[thread_position_in_threadgroup]], uint3 tgp_dims [[threads_per_threadgroup]]) { using sort_kernel = KernelMultiBlockMergeSort< - val_t, - idx_t, + ValT, + IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; @@ -569,27 +568,27 @@ template < } template < - typename val_t, - typename idx_t, + typename ValT, + typename IdxT, bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD, - typename CompareOp = LessThan> + typename CompareOp = LessThan> [[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge( - const device idx_t* block_partitions [[buffer(0)]], - const device val_t* dev_vals_in [[buffer(1)]], - const device idx_t* dev_idxs_in [[buffer(2)]], - device val_t* dev_vals_out [[buffer(3)]], - device idx_t* dev_idxs_out [[buffer(4)]], + const device IdxT* block_partitions [[buffer(0)]], + const device ValT* dev_vals_in [[buffer(1)]], + const device IdxT* dev_idxs_in [[buffer(2)]], + device ValT* dev_vals_out [[buffer(3)]], + device IdxT* dev_idxs_out [[buffer(4)]], const constant int& size_sorted_axis [[buffer(5)]], const constant int& merge_tiles [[buffer(6)]], const constant int& num_tiles [[buffer(7)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = KernelMultiBlockMergeSort< - val_t, - idx_t, + ValT, + IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, @@ -625,8 +624,8 @@ mb_block_merge( int B_sz = B_ed - B_st; // Load from global memory - thread val_t thread_vals[N_PER_THREAD]; - thread idx_t thread_idxs[N_PER_THREAD]; + thread ValT thread_vals[N_PER_THREAD]; + thread IdxT thread_idxs[N_PER_THREAD]; for (int i = 0; i < N_PER_THREAD; i++) { int idx = BLOCK_THREADS * i + lid.x; if (idx < (A_sz + B_sz)) { @@ -641,8 +640,8 @@ mb_block_merge( } // Write to shared memory - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; threadgroup_barrier(mem_flags::mem_threadgroup); for (int i = 0; i < N_PER_THREAD; i++) { int idx = BLOCK_THREADS * i + lid.x; diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal index 68248484d..7f198d55d 100644 --- a/mlx/backend/metal/kernels/sort.metal +++ b/mlx/backend/metal/kernels/sort.metal @@ -22,10 +22,12 @@ _block_sort, itname, itype, itname, itype, false, bn, tn) #define instantiate_block_sort_tn(itname, itype, bn) \ - instantiate_block_sort_base(itname, itype, bn, 8) \ - instantiate_arg_block_sort_base(itname, itype, bn, 8) + instantiate_block_sort_base(itname, itype, bn, 4) \ + instantiate_arg_block_sort_base(itname, itype, bn, 4) #define instantiate_block_sort_bn(itname, itype) \ + instantiate_block_sort_tn(itname, itype, 32) \ + instantiate_block_sort_tn(itname, itype, 64) \ instantiate_block_sort_tn(itname, itype, 128) \ instantiate_block_sort_tn(itname, itype, 256) \ instantiate_block_sort_tn(itname, itype, 512) @@ -41,6 +43,8 @@ instantiate_block_sort_bn(float32, float) instantiate_block_sort_bn(bfloat16, bfloat16_t) #define instantiate_block_sort_long(itname, itype) \ + instantiate_block_sort_tn(itname, itype, 32) \ + instantiate_block_sort_tn(itname, itype, 64) \ instantiate_block_sort_tn(itname, itype, 128) \ instantiate_block_sort_tn(itname, itype, 256) @@ -57,7 +61,7 @@ instantiate_block_sort_long(int64, int64_t) mb_block_merge, vtype, itype, arg_sort, bn, tn) #define instantiate_multi_block_sort_base(vtname, vtype) \ - instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8) + instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 4) instantiate_multi_block_sort_base(uint8, uint8_t) instantiate_multi_block_sort_base(uint16, uint16_t) @@ -70,7 +74,7 @@ instantiate_multi_block_sort_base(float32, float) instantiate_multi_block_sort_base(bfloat16, bfloat16_t) #define instantiate_multi_block_sort_long(vtname, vtype) \ - instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8) + instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 4) instantiate_multi_block_sort_long(uint64, uint64_t) instantiate_multi_block_sort_long(int64, int64_t) // clang-format on diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 2863ca938..39036828e 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -274,16 +274,20 @@ void gpu_merge_sort( int size_sorted_axis = in.shape(axis); // Get kernel size - int tn = 8; - int bn = 128; + int tn = 4; int potential_bn = (size_sorted_axis + tn - 1) / tn; + int bn; if (potential_bn > 256) { bn = 512; } else if (potential_bn > 128) { bn = 256; - } else { + } else if (potential_bn > 64) { bn = 128; + } else if (potential_bn > 32) { + bn = 64; + } else { + bn = 32; } if (bn == 512 && size_of(in.dtype()) > 4) {