// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <cub/device/device_segmented_radix_sort.cuh>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

#include <cassert>

namespace mlx::core {

namespace {

template <typename T>
struct ModOp {
  T divisor;
  __device__ T operator()(T x) {
    return x % divisor;
  }
};

// Maps segment index i to the offset of its first element; used to build the
// begin/end offset iterators for cub's segmented sort.
struct OffsetTransform {
  int nsort;

  int __device__ operator()(int i) {
    return i * nsort;
  }
};

void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
  array out = out_;
  auto& encoder = cu::get_command_encoder(s);

  if (axis < 0) {
    axis += in.ndim();
  }
  int nsort = in.shape(axis);
  int last_dim = in.ndim() - 1;

  // If we are not sorting the innermost dimension of a contiguous array,
  // transpose and make a copy.
  bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
  if (!is_segmented_sort) {
    array trans = swapaxes_in_eval(in, axis, last_dim);
    in = contiguous_copy_gpu(trans, s);
    encoder.add_temporary(in);
    out = array(
        cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
    encoder.add_temporary(out);
  } else {
    out.set_data(
        cu::malloc_async(in.data_size() * out.itemsize(), encoder),
        in.data_size(),
        in.strides(),
        in.flags());
  }

  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_all_types(in.dtype(), [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    auto& stream = encoder.stream();
    if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
      using Type = cuda_type_t<CTYPE>;
      auto offsets = thrust::make_transform_iterator(
          thrust::make_counting_iterator(0), OffsetTransform{nsort});
      if (argsort) {
        // Indices in the sorted dimension.
        array indices(
            cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
        encoder.add_temporary(indices);

        // Though in argsort we don't need the sorted values, the API
        // requires us to provide an array to store them.
        array discard(
            cu::malloc_async(in.nbytes(), encoder), in.shape(), in.dtype());
        encoder.add_temporary(discard);

        // The first call with a null temp pointer only queries the required
        // temporary storage size.
        size_t size;
        CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
            nullptr,
            size,
            gpu_ptr<Type>(in),
            gpu_ptr<Type>(discard),
            gpu_ptr<uint32_t>(indices),
            gpu_ptr<uint32_t>(out),
            in.data_size(),
            in.data_size() / nsort,
            offsets,
            offsets + 1,
            0,
            sizeof(Type) * 8,
            stream));

        array temp(
            cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
        encoder.add_temporary(temp);

        // Start capturing after allocations.
        auto capture = encoder.capture_context();
        // Initialize the indices to 0, 1, ..., nsort - 1 within each segment.
        thrust::transform(
            cu::thrust_policy(stream),
            thrust::counting_iterator<uint32_t>(0),
            thrust::counting_iterator<uint32_t>(indices.data_size()),
            thrust::device_pointer_cast(gpu_ptr<uint32_t>(indices)),
            ModOp<uint32_t>{static_cast<uint32_t>(nsort)});

        CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
            gpu_ptr<void>(temp),
            size,
            gpu_ptr<Type>(in),
            gpu_ptr<Type>(discard),
            gpu_ptr<uint32_t>(indices),
            gpu_ptr<uint32_t>(out),
            in.data_size(),
            in.data_size() / nsort,
            offsets,
            offsets + 1,
            0,
            sizeof(Type) * 8,
            stream));
      } else {
        // The first call with a null temp pointer only queries the required
        // temporary storage size.
        size_t size;
        CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
            nullptr,
            size,
            gpu_ptr<Type>(in),
            gpu_ptr<Type>(out),
            in.data_size(),
            in.data_size() / nsort,
            offsets,
            offsets + 1,
            0,
            sizeof(Type) * 8,
            stream));

        array temp(
            cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
        encoder.add_temporary(temp);

        // Start capturing after allocations.
        auto capture = encoder.capture_context();
        CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
            gpu_ptr<void>(temp),
            size,
            gpu_ptr<Type>(in),
            gpu_ptr<Type>(out),
            in.data_size(),
            in.data_size() / nsort,
            offsets,
            offsets + 1,
            0,
            sizeof(Type) * 8,
            stream));
      }
    } else {
      throw std::runtime_error(
          "CUDA backend does not support sorting complex numbers");
    }
  });

  if (!is_segmented_sort) {
    // Swap the sorted axis back.
    // TODO: Do in-place transpose instead of using a temporary out array.
    copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s);
  }
}

} // namespace

void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgSort::eval_gpu");
  assert(inputs.size() == 1);
  gpu_sort(stream(), inputs[0], out, axis_, true);
}

void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Sort::eval_gpu");
  assert(inputs.size() == 1);
  gpu_sort(stream(), inputs[0], out, axis_, false);
}

void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgPartition::eval_gpu");
  gpu_sort(stream(), inputs[0], out, axis_, true);
}

void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Partition::eval_gpu");
  gpu_sort(stream(), inputs[0], out, axis_, false);
}

} // namespace mlx::core