From 3628e5d4979b634c009f4fbbebfc7ffa17bf5052 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 30 Jul 2025 17:40:26 +0900 Subject: [PATCH] Use load_vector in arg_reduce (#2439) --- mlx/backend/cuda/arg_reduce.cu | 19 ++++++++++++------- mlx/backend/cuda/device/utils.cuh | 14 -------------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 74108e00b..321bd66b4 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -44,8 +44,11 @@ struct ArgMin { } template - __device__ IndexValPair - reduce_many(IndexValPair best, T (&vals)[N], uint32_t offset) { + __device__ IndexValPair reduce_many( + IndexValPair best, + const AlignedVector& vals, + uint32_t offset) { +#pragma unroll for (int i = 0; i < N; i++) { if (vals[i] < best.val) { best.val = vals[i]; @@ -74,8 +77,11 @@ struct ArgMax { } template - __device__ IndexValPair - reduce_many(IndexValPair best, T (&vals)[N], uint32_t offset) { + __device__ IndexValPair reduce_many( + IndexValPair best, + const AlignedVector& vals, + uint32_t offset) { +#pragma unroll for (int i = 0; i < N; i++) { if (vals[i] > best.val) { best.val = vals[i]; @@ -106,16 +112,15 @@ __global__ void arg_reduce_general( int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim); int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim); + in += in_idx; Op op; T init = op.init(); IndexValPair best{0, init}; for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { - T vals[N_READS]; auto tid = r * BLOCK_DIM + block.thread_index().x; - cub::LoadDirectBlocked( - tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init); + auto vals = load_vector(in, tid, axis_size, axis_stride, init); best = op.reduce_many(best, vals, tid * N_READS); } diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index 8dd8f0276..f9a5c4e06 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -131,20 +131,6 @@ inline __device__ void store_vector( } } -// Helper for accessing strided data. -template -struct StridedIterator { - T it; - int64_t stride; - - __host__ __device__ StridedIterator(T it, int64_t stride) - : it(it), stride(stride) {} - - __host__ __device__ auto operator[](int i) const { - return it[i * stride]; - } -}; - /////////////////////////////////////////////////////////////////////////////// // Type limits utils ///////////////////////////////////////////////////////////////////////////////