From bae9a6b404aa21fa068faab626839051f9d610fe Mon Sep 17 00:00:00 2001
From: Cheng
Date: Wed, 11 Jun 2025 00:59:47 +0900
Subject: [PATCH] CUDA backend: sort (#2262)

Co-authored-by: Awni Hannun
---
 mlx/backend/cuda/CMakeLists.txt |   1 +
 mlx/backend/cuda/primitives.cu  |   2 -
 mlx/backend/cuda/sort.cu        | 180 ++++++++++++++++++++++++++++++++
 3 files changed, 181 insertions(+), 2 deletions(-)
 create mode 100644 mlx/backend/cuda/sort.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index c813f8fd4..23ae64cf6 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -16,6 +16,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 2c3a73c42..3f3674c07 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -73,7 +73,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
 
 NO_GPU(ArgPartition)
 NO_GPU(ArgReduce)
-NO_GPU(ArgSort)
 NO_GPU(BlockMaskedMM)
 NO_GPU_MULTI(Compiled)
 NO_GPU(Convolution)
@@ -100,7 +99,6 @@ NO_GPU(ScatterAxis)
 NO_GPU(Select)
 NO_GPU(SliceUpdate)
 NO_GPU(Softmax)
-NO_GPU(Sort)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu
new file mode 100644
index 000000000..e1c2e8530
--- /dev/null
+++ b/mlx/backend/cuda/sort.cu
@@ -0,0 +1,180 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <thrust/device_ptr.h>
+#include <thrust/transform.h>
+#include <cub/device/device_segmented_sort.cuh>
+#include <nvtx3/nvtx3.hpp>
+
+#include <cassert>
+#include <numeric>
+
+namespace mlx::core {
+
+namespace {
+
+template <typename T>
+struct ModOp {
+  T divisor;
+  __device__ T operator()(T x) {
+    return x % divisor;
+  }
+};
+
+// We cannot use any op in eval, so make a utility.
+array swapaxes_in_eval(const array& in, int axis1, int axis2) {
+  std::vector<int> axes(in.ndim());
+  std::iota(axes.begin(), axes.end(), 0);
+  std::swap(axes[axis1], axes[axis2]);
+  // TODO: Share the code with Transpose::eval.
+  Shape shape(axes.size());
+  Strides strides(in.ndim());
+  for (size_t ax = 0; ax < axes.size(); ++ax) {
+    shape[ax] = in.shape()[axes[ax]];
+    strides[ax] = in.strides()[axes[ax]];
+  }
+  auto flags = in.flags();
+  if (flags.contiguous) {
+    auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
+    flags.row_contiguous = row_contiguous;
+    flags.col_contiguous = col_contiguous;
+  }
+  array out(shape, in.dtype(), nullptr, {});
+  out.copy_shared_buffer(in, strides, flags, in.data_size());
+  return out;
+}
+
+template <typename... Args>
+void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) {
+  // Allocate temporary storage.
+  size_t size;
+  CHECK_CUDA_ERROR(
+      cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...));
+  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+  encoder.add_temporary(temp);
+  // Run op.
+  CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
+      temp.data<void>(), size, args...));
+}
+
+template <typename... Args>
+void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
+  // Allocate temporary storage.
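+  // As in segmented_sort_pairs above, the cub op is invoked twice: the first
+  // call, given a null temp-storage pointer, only computes the scratch size;
+  // the second call performs the actual sort in that scratch buffer.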
+  size_t size;
+  CHECK_CUDA_ERROR(
+      cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...));
+  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
+  encoder.add_temporary(temp);
+  // Run op.
+  CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
+      temp.data<void>(), size, args...));
+}
+
+void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
+  array out = out_;
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+
+  if (axis < 0) {
+    axis += in.ndim();
+  }
+  int nsort = in.shape(axis);
+  int nsegments = in.data_size() / nsort;
+  int last_dim = in.ndim() - 1;
+
+  // If we are not sorting the innermost dimension of a contiguous array,
+  // transpose and make a copy.
+  bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
+  if (!is_segmented_sort) {
+    array trans = swapaxes_in_eval(in, axis, last_dim);
+    in = array(trans.shape(), trans.dtype(), nullptr, {});
+    copy_gpu(trans, in, CopyType::General, s);
+    encoder.add_temporary(in);
+    out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+    encoder.add_temporary(out);
+  } else {
+    out.set_data(allocator::malloc(out.nbytes()));
+  }
+
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+      if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
+        using Type = cuda_type_t<CTYPE>;
+        auto offsets = thrust::make_transform_iterator(
+            thrust::make_counting_iterator(0),
+            [nsort] __device__(int i) { return i * nsort; });
+        if (argsort) {
+          // Indices in the sorted dimension.
+          array indices(
+              allocator::malloc(out.nbytes()), in.shape(), out.dtype());
+          encoder.add_temporary(indices);
+          thrust::transform(
+              cu::thrust_policy(stream),
+              thrust::counting_iterator<uint32_t>(0),
+              thrust::counting_iterator<uint32_t>(indices.data_size()),
+              thrust::device_pointer_cast(indices.data<uint32_t>()),
+              ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
+
+          // Although argsort does not need the sorted values themselves, the
+          // API requires us to provide an array to store them.
+          array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype());
+          encoder.add_temporary(discard);
+
+          segmented_sort_pairs(
+              encoder,
+              in.data<Type>(),
+              discard.data<Type>(),
+              indices.data<uint32_t>(),
+              out.data<uint32_t>(),
+              in.data_size(),
+              nsegments,
+              offsets,
+              offsets + 1,
+              stream);
+        } else {
+          segmented_sort(
+              encoder,
+              in.data<Type>(),
+              out.data<Type>(),
+              in.data_size(),
+              nsegments,
+              offsets,
+              offsets + 1,
+              stream);
+        }
+      } else {
+        throw std::runtime_error(
+            "CUDA backend does not support sorting complex numbers");
+      }
+    });
+  });
+
+  if (!is_segmented_sort) {
+    // Swap the sorted axis back.
+    // TODO: Do in-place transpose instead of using a temporary out array.
+    copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
+  }
+}
+
+} // namespace
+
+void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("ArgSort::eval_gpu");
+  assert(inputs.size() == 1);
+  gpu_sort(stream(), inputs[0], out, axis_, true);
+}
+
+void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Sort::eval_gpu");
+  assert(inputs.size() == 1);
+  gpu_sort(stream(), inputs[0], out, axis_, false);
+}
+
+} // namespace mlx::core
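
Note: both helpers in the patch rely on CUB's standard two-phase protocol (a size-query call followed by the real call), with segment boundaries produced lazily by a transform iterator instead of a materialized offsets array. Below is a minimal standalone sketch of that same pattern outside MLX, sorting each row of a row-major 2x4 float array. The file name and host-side setup are illustrative only, not part of the patch; it assumes cub::DeviceSegmentedSort (CUDA 11.5+) and nvcc's --extended-lambda for the device lambda, and error checking is omitted for brevity.

// sort_rows.cu: sort each length-4 row of a 2x4 row-major float array
// with cub::DeviceSegmentedSort, mirroring the pattern used in sort.cu.
#include <cub/device/device_segmented_sort.cuh>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

int main() {
  const int nsegments = 2;
  const int nsort = 4;
  const int n = nsegments * nsort;
  std::vector<float> h_in = {3, 1, 4, 1, 5, 9, 2, 6};

  float* d_in;
  float* d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // Segment i spans [i * nsort, (i + 1) * nsort); the offsets are computed
  // on the fly by an iterator instead of being stored in device memory.
  auto offsets = thrust::make_transform_iterator(
      thrust::make_counting_iterator(0),
      [nsort] __device__(int i) { return i * nsort; });

  // Phase 1: with a null temp-storage pointer, cub only reports the
  // scratch size it needs into temp_bytes.
  size_t temp_bytes = 0;
  cub::DeviceSegmentedSort::StableSortKeys(
      nullptr, temp_bytes, d_in, d_out, n, nsegments, offsets, offsets + 1);

  void* d_temp;
  cudaMalloc(&d_temp, temp_bytes);

  // Phase 2: the same call with real scratch storage performs the sort.
  cub::DeviceSegmentedSort::StableSortKeys(
      d_temp, temp_bytes, d_in, d_out, n, nsegments, offsets, offsets + 1);

  std::vector<float> h_out(n);
  cudaMemcpy(h_out.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (float v : h_out) {
    std::printf("%g ", v); // prints: 1 1 3 4 2 5 6 9
  }
  std::printf("\n");

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_temp);
  return 0;
}

Compiled with `nvcc --extended-lambda sort_rows.cu`, this prints `1 1 3 4 2 5 6 9`: each 4-element segment is sorted independently, which is exactly how gpu_sort treats the rows of the (possibly transposed) contiguous input.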