mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
@@ -9,7 +9,7 @@
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
#include <cub/device/device_segmented_sort.cuh>
|
||||
#include <cub/device/device_segmented_radix_sort.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
@@ -79,7 +79,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
encoder.add_temporary(discard);
|
||||
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||
nullptr,
|
||||
size,
|
||||
in.data<Type>(),
|
||||
@@ -90,6 +90,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
0,
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
@@ -104,7 +106,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
thrust::device_pointer_cast(indices.data<uint32_t>()),
|
||||
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
||||
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||
temp.data<void>(),
|
||||
size,
|
||||
in.data<Type>(),
|
||||
@@ -115,10 +117,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
0,
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
} else {
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
|
||||
nullptr,
|
||||
size,
|
||||
in.data<Type>(),
|
||||
@@ -127,6 +131,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
0,
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
@@ -134,7 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
|
||||
// Start capturing after allocations
|
||||
auto capture = encoder.capture_context();
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
|
||||
temp.data<void>(),
|
||||
size,
|
||||
in.data<Type>(),
|
||||
@@ -143,6 +149,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
0,
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
}
|
||||
} else {
|
||||
|
Reference in New Issue
Block a user