From 333ffea2730bb789a6eecdcaa78f1e808607c4be Mon Sep 17 00:00:00 2001
From: Cheng
Date: Sun, 24 Aug 2025 16:22:36 +0900
Subject: [PATCH] [CUDA] Remove thrust in arange (#2535)

---
 mlx/backend/cuda/arange.cu          | 52 ++++++++++++++++++-----------
 mlx/backend/cuda/device/cast_op.cuh | 12 -------
 2 files changed, 33 insertions(+), 31 deletions(-)

diff --git a/mlx/backend/cuda/arange.cu b/mlx/backend/cuda/arange.cu
index 6190ac272..a28a245db 100644
--- a/mlx/backend/cuda/arange.cu
+++ b/mlx/backend/cuda/arange.cu
@@ -6,23 +6,33 @@
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
 
+#include <cooperative_groups.h>
 #include <nvtx3/nvtx3.hpp>
-#include <thrust/device_ptr.h>
-#include <thrust/transform.h>
 
 namespace mlx::core {
 
 namespace cu {
 
-template <typename T>
-struct Arange {
-  const T start;
-  const T step;
+namespace cg = cooperative_groups;
 
-  __device__ T operator()(uint32_t i) const {
-    return start + i * step;
+template <typename T, typename IdxT, int N_WRITES>
+__global__ void arange(T* out, IdxT size, T start, T step) {
+  IdxT index = cg::this_grid().thread_rank();
+
+  if ((index + 1) * N_WRITES > size) {
+    for (IdxT i = index * N_WRITES; i < size; ++i) {
+      out[i] = start + i * step;
+    }
+  } else {
+    AlignedVector<T, N_WRITES> out_vec;
+#pragma unroll
+    for (int i = 0; i < N_WRITES; ++i) {
+      out_vec[i] = start + (index * N_WRITES + i) * step;
+    }
+
+    store_vector(out, index, out_vec);
   }
-};
+}
 
 } // namespace cu
 
@@ -36,19 +46,23 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& encoder = cu::get_command_encoder(stream());
   encoder.set_output_array(out);
 
-  auto capture = encoder.capture_context();
   dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
     using CTYPE = MLX_GET_TYPE(type_tag);
     using OutType = cuda_type_t<CTYPE>;
-    CTYPE step =
-        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
-    thrust::transform(
-        cu::thrust_policy(encoder.stream()),
-        thrust::counting_iterator<uint32_t>(0),
-        thrust::counting_iterator<uint32_t>(out.data_size()),
-        thrust::device_pointer_cast(out.data<OutType>()),
-        cu::Arange<OutType>{
-            static_cast<OutType>(start_), static_cast<OutType>(step)});
+    constexpr int N_WRITES = 16 / sizeof(OutType);
+    dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+      using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+      auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
+      encoder.add_kernel_node(
+          cu::arange<OutType, IdxT, N_WRITES>,
+          num_blocks,
+          block_dims,
+          0,
+          out.data<OutType>(),
+          out.data_size(),
+          static_cast<OutType>(start_),
+          static_cast<OutType>(start_ + step_) - static_cast<OutType>(start_));
+    });
   });
 }
 
diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh
index e10fde6dc..b85a40c83 100644
--- a/mlx/backend/cuda/device/cast_op.cuh
+++ b/mlx/backend/cuda/device/cast_op.cuh
@@ -6,7 +6,6 @@
 
 #include <cuComplex.h>
 #include <cuda/std/type_traits>
-#include <thrust/iterator/transform_iterator.h>
 
 namespace mlx::core::cu {
 
@@ -116,15 +115,4 @@ inline __host__ __device__ auto cast_to(SrcT x) {
   return CastOp<SrcT, DstT>{}(x);
 }
 
-// Return an iterator that cast the value to DstT using CastOp.
-template <typename DstT, typename Iterator>
-inline __host__ __device__ auto make_cast_iterator(Iterator it) {
-  using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
-  if constexpr (std::is_same_v<SrcT, DstT>) {
-    return it;
-  } else {
-    return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
-  }
-}
-
 } // namespace mlx::core::cu
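
Note: the new kernel leans on MLX's internal AlignedVector and store_vector
helpers, so it is not self-contained as shown. Below is a minimal sketch of the
same vectorized-write pattern, assuming a hypothetical Vec type in place of
AlignedVector and a plain <<<...>>> launch in place of encoder.add_kernel_node;
it is an illustration of the technique, not MLX code. Note also that the step
passed to the kernel is computed as static_cast<OutType>(start_ + step_) -
static_cast<OutType>(start_) rather than static_cast<OutType>(step_), so the
per-element increment matches what is actually representable in low-precision
output types.

// Standalone sketch (not MLX code) of the vectorized arange pattern above.
// Vec is a hypothetical stand-in for MLX's AlignedVector; the launch below
// stands in for get_launch_args + encoder.add_kernel_node.
#include <cstdio>
#include <cuda_runtime.h>

// N elements packed so the struct is as aligned as one wide store needs.
template <typename T, int N>
struct alignas(sizeof(T) * N) Vec {
  T v[N];
};

template <typename T, int N>
__global__ void arange_sketch(T* out, int size, T start, T step) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if ((index + 1) * N > size) {
    // Tail thread (or fully out of range): fall back to scalar stores.
    for (int i = index * N; i < size; ++i) {
      out[i] = start + i * step;
    }
  } else {
    // Full chunk: compute N values, emit one aligned vector store.
    Vec<T, N> chunk;
#pragma unroll
    for (int i = 0; i < N; ++i) {
      chunk.v[i] = start + (index * N + i) * step;
    }
    reinterpret_cast<Vec<T, N>*>(out)[index] = chunk;
  }
}

int main() {
  constexpr int size = 1000;
  constexpr int n = 16 / sizeof(float); // 4 floats per 16-byte store
  float* d_out = nullptr;
  cudaMalloc(&d_out, size * sizeof(float));

  // One thread per chunk of n elements.
  int chunks = (size + n - 1) / n;
  int threads = 256;
  int blocks = (chunks + threads - 1) / threads;
  arange_sketch<float, n><<<blocks, threads>>>(d_out, size, 0.0f, 0.5f);

  float tail[4];
  cudaMemcpy(tail, d_out + size - 4, sizeof(tail), cudaMemcpyDeviceToHost);
  printf("%g %g %g %g\n", tail[0], tail[1], tail[2], tail[3]); // 498 498.5 499 499.5
  cudaFree(d_out);
  return 0;
}

Since cudaMalloc returns sufficiently aligned pointers, the full-chunk path can
compile to a single 128-bit store per thread, which is the wide-store behavior
the kernel in this patch is after; the tail branch keeps the last partial chunk
correct without over-writing past size.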