[CUDA] fix sort (#2550)

* [CUDA] fix sort

* fix test
This commit is contained in:
Awni Hannun
2025-08-27 19:48:43 -07:00
committed by GitHub
parent 31c6f6e33f
commit 7ef8a6f2d5
2 changed files with 19 additions and 5 deletions

View File

@@ -9,7 +9,7 @@
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
#include <thrust/transform.h> #include <thrust/transform.h>
#include <cub/device/device_segmented_sort.cuh> #include <cub/device/device_segmented_radix_sort.cuh>
#include <cassert> #include <cassert>
@@ -79,7 +79,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
encoder.add_temporary(discard); encoder.add_temporary(discard);
size_t size; size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, nullptr,
size, size,
in.data<Type>(), in.data<Type>(),
@@ -90,6 +90,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort, in.data_size() / nsort,
offsets, offsets,
offsets + 1, offsets + 1,
0,
sizeof(Type) * 8,
stream)); stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8); array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
@@ -104,7 +106,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
thrust::device_pointer_cast(indices.data<uint32_t>()), thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)}); ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
temp.data<void>(), temp.data<void>(),
size, size,
in.data<Type>(), in.data<Type>(),
@@ -115,10 +117,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort, in.data_size() / nsort,
offsets, offsets,
offsets + 1, offsets + 1,
0,
sizeof(Type) * 8,
stream)); stream));
} else { } else {
size_t size; size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
nullptr, nullptr,
size, size,
in.data<Type>(), in.data<Type>(),
@@ -127,6 +131,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort, in.data_size() / nsort,
offsets, offsets,
offsets + 1, offsets + 1,
0,
sizeof(Type) * 8,
stream)); stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8); array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
@@ -134,7 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
// Start capturing after allocations // Start capturing after allocations
auto capture = encoder.capture_context(); auto capture = encoder.capture_context();
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
temp.data<void>(), temp.data<void>(),
size, size,
in.data<Type>(), in.data<Type>(),
@@ -143,6 +149,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort, in.data_size() / nsort,
offsets, offsets,
offsets + 1, offsets + 1,
0,
sizeof(Type) * 8,
stream)); stream));
} }
} else { } else {

View File

@@ -2191,6 +2191,12 @@ class TestOps(mlx_tests.MLXTestCase):
y_mx = mx.sort(mx.array(x), axis=-2) y_mx = mx.sort(mx.array(x), axis=-2)
self.assertTrue(np.array_equal(y_np, y_mx)) self.assertTrue(np.array_equal(y_np, y_mx))
# Test many segments
a = mx.random.uniform(shape=(512, 128))
y_mx = mx.sort(a, axis=-1)
y_np = np.sort(np.array(a), axis=-1)
self.assertTrue(np.array_equal(y_np, y_mx))
def test_partition(self): def test_partition(self):
shape = (3, 4, 5) shape = (3, 4, 5)
for dtype in ("int32", "float32"): for dtype in ("int32", "float32"):