From 7f914365fd8e45ede878cc3d0ef5b3d3599e597e Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Wed, 24 Jul 2024 14:37:10 -0700
Subject: [PATCH] Fix GPU sort for large arrays (#1285)

* Fix GPU sort for large arrays
---
 mlx/backend/metal/kernels/sort.h | 36 +++++++++++++++++++---------------
 mlx/backend/metal/sort.cpp       |  5 ++++-
 mlx/ops.cpp                      |  9 ---------
 python/tests/test_ops.py         |  9 +++++++++
 4 files changed, 34 insertions(+), 25 deletions(-)

diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h
index dca5106de..eb1f70d19 100644
--- a/mlx/backend/metal/kernels/sort.h
+++ b/mlx/backend/metal/kernels/sort.h
@@ -522,13 +522,13 @@ template <
     bool ARG_SORT,
     short BLOCK_THREADS,
     short N_PER_THREAD>
-[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
-mb_block_partition(
+[[kernel]] void mb_block_partition(
     device idx_t* block_partitions [[buffer(0)]],
     const device val_t* dev_vals [[buffer(1)]],
     const device idx_t* dev_idxs [[buffer(2)]],
     const constant int& size_sorted_axis [[buffer(3)]],
     const constant int& merge_tiles [[buffer(4)]],
+    const constant int& n_blocks [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]) {
@@ -543,23 +543,29 @@ mb_block_partition(
   dev_vals += tid.y * size_sorted_axis;
   dev_idxs += tid.y * size_sorted_axis;
 
-  // Find location in merge step
-  int merge_group = lid.x / merge_tiles;
-  int merge_lane = lid.x % merge_tiles;
+  for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
+    // Find location in merge step
+    int merge_group = i / merge_tiles;
+    int merge_lane = i % merge_tiles;
 
-  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
-  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
+    int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
+    int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
 
-  int A_st = min(size_sorted_axis, sort_st);
-  int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
-  int B_st = A_ed;
-  int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
+    int A_st = min(size_sorted_axis, sort_st);
+    int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
+    int B_st = A_ed;
+    int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
 
-  int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
-  int partition = sort_kernel::merge_partition(
-      dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
+    int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
+    int partition = sort_kernel::merge_partition(
+        dev_vals + A_st,
+        dev_vals + B_st,
+        A_ed - A_st,
+        B_ed - B_st,
+        partition_at);
 
-  block_partitions[lid.x] = A_st + partition;
+    block_partitions[i] = A_st + partition;
+  }
 }
 
 template <
diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp
index 457232654..824492789 100644
--- a/mlx/backend/metal/sort.cpp
+++ b/mlx/backend/metal/sort.cpp
@@ -177,6 +177,8 @@ void multi_block_sort(
   array dev_vals_out = dev_vals_1;
   array dev_idxs_out = dev_idxs_1;
 
+  int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;
+
   for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
     dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
     dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
@@ -199,8 +201,9 @@ void multi_block_sort(
       compute_encoder.set_input_array(dev_idxs_in, 2);
       compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
       compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
+      compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
 
-      MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
+      MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
       MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
 
       compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 5f17241b4..430fdc6f0 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1785,15 +1785,6 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
     throw std::invalid_argument(msg.str());
   }
 
-  // TODO: Fix GPU kernel
-  if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {
-    std::ostringstream msg;
-    msg << "[sort] GPU sort cannot handle sort axis of >= 2M elements,"
-        << " got array with sort axis size " << a.shape(axis) << "."
-        << " Please place this operation on the CPU instead.";
-    throw std::runtime_error(msg.str());
-  }
-
   return array(
       a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
 }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 5b83613f7..16799965f 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1840,6 +1840,15 @@ class TestOps(mlx_tests.MLXTestCase):
             self.assertTrue(np.array_equal(c_np, c_mx))
             self.assertEqual(b_mx.dtype, c_mx.dtype)
 
+        # Test very large array
+        if mx.default_device() == mx.gpu:
+            a_np = np.random.normal(20, 20, size=(2**22)).astype(np.float32)
+            a_mx = mx.array(a_np)
+
+            b_np = np.sort(a_np)
+            b_mx = mx.sort(a_mx)
+            self.assertTrue(np.array_equal(b_np, b_mx))
+
     def test_partition(self):
         shape = (3, 4, 5)
         for dtype in ("int32", "float32"):
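
Why the fix works: mb_block_partition used to launch one thread per partition
boundary, so the host dispatched threadgroups of size n_blocks + 1. Once the
sort axis reached 2M (1 << 21) elements, that exceeded the 1024-thread
threadgroup limit the new host code assumes, which is exactly the case the
removed check in mlx/ops.cpp used to reject. The patch instead passes
n_blocks into the kernel, caps the threadgroup at 1024 threads on the host,
and has each thread stride over partition indices. Below is a minimal
host-side C++ sketch (an illustration, not MLX code; the block count is a
made-up example, since sort_kernel::N_PER_BLOCK is defined elsewhere in
sort.h) checking that the strided loop covers every partition exactly once:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main() {
      // Hypothetical block count for a large sort axis; in the kernel,
      // n_blocks is derived from sort_kernel::N_PER_BLOCK.
      int n_blocks = 4096;
      // Host-side cap from the patch: at most 1024 threads per group.
      int n_thr_per_group = std::min(n_blocks + 1, 1024);
      std::vector<int> covered(n_blocks + 1, 0);
      // Simulate each thread lid running the kernel's new strided loop:
      //   for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) { ... }
      for (int lid = 0; lid < n_thr_per_group; ++lid) {
        for (int i = lid; i <= n_blocks; i += n_thr_per_group) {
          covered[i] += 1; // the kernel writes block_partitions[i] here
        }
      }
      for (int c : covered) {
        assert(c == 1); // every partition boundary is computed exactly once
      }
      return 0;
    }

The max_total_threads_per_threadgroup(BLOCK_THREADS) attribute is dropped
from the kernel, presumably because the partition threadgroup (now up to
1024 threads) can exceed BLOCK_THREADS.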