mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
@@ -9,7 +9,7 @@
|
|||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
#include <thrust/device_ptr.h>
|
#include <thrust/device_ptr.h>
|
||||||
#include <thrust/transform.h>
|
#include <thrust/transform.h>
|
||||||
#include <cub/device/device_segmented_sort.cuh>
|
#include <cub/device/device_segmented_radix_sort.cuh>
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
@@ -79,7 +79,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
encoder.add_temporary(discard);
|
encoder.add_temporary(discard);
|
||||||
|
|
||||||
size_t size;
|
size_t size;
|
||||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||||
nullptr,
|
nullptr,
|
||||||
size,
|
size,
|
||||||
in.data<Type>(),
|
in.data<Type>(),
|
||||||
@@ -90,6 +90,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
in.data_size() / nsort,
|
in.data_size() / nsort,
|
||||||
offsets,
|
offsets,
|
||||||
offsets + 1,
|
offsets + 1,
|
||||||
|
0,
|
||||||
|
sizeof(Type) * 8,
|
||||||
stream));
|
stream));
|
||||||
|
|
||||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||||
@@ -104,7 +106,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
thrust::device_pointer_cast(indices.data<uint32_t>()),
|
thrust::device_pointer_cast(indices.data<uint32_t>()),
|
||||||
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
||||||
|
|
||||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||||
temp.data<void>(),
|
temp.data<void>(),
|
||||||
size,
|
size,
|
||||||
in.data<Type>(),
|
in.data<Type>(),
|
||||||
@@ -115,10 +117,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
in.data_size() / nsort,
|
in.data_size() / nsort,
|
||||||
offsets,
|
offsets,
|
||||||
offsets + 1,
|
offsets + 1,
|
||||||
|
0,
|
||||||
|
sizeof(Type) * 8,
|
||||||
stream));
|
stream));
|
||||||
} else {
|
} else {
|
||||||
size_t size;
|
size_t size;
|
||||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
|
||||||
nullptr,
|
nullptr,
|
||||||
size,
|
size,
|
||||||
in.data<Type>(),
|
in.data<Type>(),
|
||||||
@@ -127,6 +131,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
in.data_size() / nsort,
|
in.data_size() / nsort,
|
||||||
offsets,
|
offsets,
|
||||||
offsets + 1,
|
offsets + 1,
|
||||||
|
0,
|
||||||
|
sizeof(Type) * 8,
|
||||||
stream));
|
stream));
|
||||||
|
|
||||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||||
@@ -134,7 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
|
|
||||||
// Start capturing after allocations
|
// Start capturing after allocations
|
||||||
auto capture = encoder.capture_context();
|
auto capture = encoder.capture_context();
|
||||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
|
||||||
temp.data<void>(),
|
temp.data<void>(),
|
||||||
size,
|
size,
|
||||||
in.data<Type>(),
|
in.data<Type>(),
|
||||||
@@ -143,6 +149,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
in.data_size() / nsort,
|
in.data_size() / nsort,
|
||||||
offsets,
|
offsets,
|
||||||
offsets + 1,
|
offsets + 1,
|
||||||
|
0,
|
||||||
|
sizeof(Type) * 8,
|
||||||
stream));
|
stream));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@@ -2191,6 +2191,12 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
y_mx = mx.sort(mx.array(x), axis=-2)
|
y_mx = mx.sort(mx.array(x), axis=-2)
|
||||||
self.assertTrue(np.array_equal(y_np, y_mx))
|
self.assertTrue(np.array_equal(y_np, y_mx))
|
||||||
|
|
||||||
|
# Test many segments
|
||||||
|
a = mx.random.uniform(shape=(512, 128))
|
||||||
|
y_mx = mx.sort(a, axis=-1)
|
||||||
|
y_np = np.sort(np.array(a), axis=-1)
|
||||||
|
self.assertTrue(np.array_equal(y_np, y_mx))
|
||||||
|
|
||||||
def test_partition(self):
|
def test_partition(self):
|
||||||
shape = (3, 4, 5)
|
shape = (3, 4, 5)
|
||||||
for dtype in ("int32", "float32"):
|
for dtype in ("int32", "float32"):
|
||||||
|
Reference in New Issue
Block a user