From f55c4ed1d67c09df0a7b0702a35c0bfc4081088e Mon Sep 17 00:00:00 2001
From: Cheng
Date: Mon, 21 Jul 2025 23:30:27 +0900
Subject: [PATCH] Remove thrust iterators (#2396)

---
 mlx/backend/cuda/arg_reduce.cu                |   4 +-
 mlx/backend/cuda/device/utils.cuh             |  14 ++
 .../cuda/iterators/general_iterator.cuh       | 121 ------------------
 .../cuda/iterators/strided_iterator.cuh       |  60 ---------
 mlx/backend/cuda/layer_norm.cu                |   9 +-
 mlx/backend/cuda/rms_norm.cu                  |   7 +-
 mlx/backend/cuda/unary.cu                     |   1 -
 7 files changed, 23 insertions(+), 193 deletions(-)
 delete mode 100644 mlx/backend/cuda/iterators/general_iterator.cuh
 delete mode 100644 mlx/backend/cuda/iterators/strided_iterator.cuh

diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu
index 67ef5d968..74108e00b 100644
--- a/mlx/backend/cuda/arg_reduce.cu
+++ b/mlx/backend/cuda/arg_reduce.cu
@@ -1,8 +1,8 @@
 // Copyright © 2025 Apple Inc.
+
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/fp16_math.cuh"
-#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
@@ -115,7 +115,7 @@ __global__ void arg_reduce_general(
     T vals[N_READS];
     auto tid = r * BLOCK_DIM + block.thread_index().x;
     cub::LoadDirectBlocked(
-        tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
+        tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
     best = op.reduce_many(best, vals, tid * N_READS);
   }

diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index 73bc7ff63..3745637da 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -49,6 +49,20 @@ store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
   to[offset] = vec;
 }

+// Helper for accessing strided data.
+template <typename T>
+struct StridedIterator {
+  T it;
+  int64_t stride;
+
+  __host__ __device__ StridedIterator(T it, int64_t stride)
+      : it(it), stride(stride) {}
+
+  __host__ __device__ auto operator[](int i) const {
+    return it[i * stride];
+  }
+};
+
 ///////////////////////////////////////////////////////////////////////////////
 // Type limits utils
 ///////////////////////////////////////////////////////////////////////////////
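The `StridedIterator` added to `device/utils.cuh` above is the entire replacement for both thrust-based iterators. A minimal standalone sketch of its semantics (not part of the patch; the `__host__ __device__` qualifiers are dropped so it compiles as plain C++, and the 3x4 matrix is just an illustration):

```cpp
// Standalone sketch: StridedIterator(p, s)[i] reads p[i * s].
// Mirrors the struct added to device/utils.cuh in this patch.
#include <cassert>
#include <cstdint>

template <typename T>
struct StridedIterator {
  T it;
  int64_t stride;
  StridedIterator(T it, int64_t stride) : it(it), stride(stride) {}
  auto operator[](int i) const {
    return it[i * stride];
  }
};

int main() {
  // A 3x4 row-major matrix; column 1 is every 4th element starting at m[1].
  float m[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  StridedIterator col(m + 1, 4); // CTAD deduces StridedIterator<float*>
  assert(col[0] == 1 && col[1] == 5 && col[2] == 9);
  return 0;
}
```

The bare `StridedIterator(...)` spelling at the kernel call sites relies on C++17 class template argument deduction, exactly as the old `strided_iterator(...)` constructor calls did.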
diff --git a/mlx/backend/cuda/iterators/general_iterator.cuh b/mlx/backend/cuda/iterators/general_iterator.cuh
deleted file mode 100644
index 3c8c098c3..000000000
--- a/mlx/backend/cuda/iterators/general_iterator.cuh
+++ /dev/null
@@ -1,121 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#pragma once
-
-#include <cuda/std/utility>
-#include <thrust/iterator/iterator_adaptor.h>
-
-#include "mlx/backend/cuda/kernel_utils.cuh"
-
-namespace mlx::core::cu {
-
-// Iterating non-contiguous array.
-template <typename Iterator, typename IdxT = int64_t>
-class general_iterator
-    : public thrust::
-          iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
- public:
-  using super_t =
-      thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
-
-  using reference = typename super_t::reference;
-  using difference_type = typename super_t::difference_type;
-
-  __host__ __device__ general_iterator(
-      Iterator it,
-      IdxT index,
-      int ndim,
-      Shape shape,
-      Strides strides)
-      : super_t(it),
-        index_(index),
-        ndim_(ndim),
-        shape_(cuda::std::move(shape)),
-        strides_(cuda::std::move(strides)) {}
-
-  __host__ __device__ IdxT index() const {
-    return index_;
-  }
-
-  __host__ __device__ const Shape& shape() const {
-    return shape_;
-  }
-
-  __host__ __device__ const Strides& strides() const {
-    return strides_;
-  }
-
- private:
-  friend class thrust::iterator_core_access;
-
-  __host__ __device__ bool equal(const general_iterator& other) const {
-    return this->base() == other.base() && this->index() == other.index();
-  }
-
-  __host__ __device__ void advance(difference_type n) {
-    this->index_ += n;
-  }
-
-  __host__ __device__ void increment() {
-    this->index_ += 1;
-  }
-
-  __host__ __device__ void decrement() {
-    this->index_ -= 1;
-  }
-
-  __host__ __device__ difference_type
-  distance_to(const general_iterator& other) const {
-    _CCCL_ASSERT(
-        this->base() == other.base(),
-        "Underlying iterator must point to same base iterator");
-    return other.index() - this->index();
-  }
-
-  // The dereference is device-only to avoid accidental running in host.
-  __device__ typename super_t::reference dereference() const {
-    IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
-    return *(this->base() + offset);
-  }
-
-  IdxT index_;
-  int ndim_;
-  Shape shape_;
-  Strides strides_;
-};
-
-template <typename IdxT = int64_t, typename Iterator>
-__host__ __device__ auto make_general_iterator(
-    Iterator it,
-    IdxT index,
-    int ndim,
-    Shape shape,
-    Strides strides) {
-  return general_iterator<Iterator, IdxT>(
-      it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
-}
-
-template <typename IdxT, typename Iterator>
-auto make_general_iterator(
-    Iterator it,
-    const std::vector<int32_t>& shape,
-    const std::vector<int64_t>& strides) {
-  return make_general_iterator<IdxT>(
-      it, 0, shape.size(), const_param(shape), const_param(strides));
-}
-
-template <typename IdxT, typename Iterator>
-auto make_general_iterators(
-    Iterator it,
-    IdxT size,
-    const std::vector<int32_t>& shape,
-    const std::vector<int64_t>& strides) {
-  auto ndim = shape.size();
-  auto shape_arg = const_param(shape);
-  auto strides_arg = const_param(strides);
-  return std::make_pair(
-      make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
-      make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
-}
-
-} // namespace mlx::core::cu
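The only non-trivial logic in the deleted `general_iterator` was its device-only dereference, which mapped a flat element index to a memory offset through `elem_to_loc`. A simplified standalone sketch of that mapping (the real helper lives in MLX's device utilities and is templated; this version hardcodes the index types):

```cpp
#include <cstdint>

// Map a flat element index to a memory offset for a non-contiguous array,
// peeling off one dimension at a time from the innermost (last) axis.
int64_t elem_to_loc(
    int64_t elem, const int* shape, const int64_t* strides, int ndim) {
  int64_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}
```

For example, with shape {3, 4} and strides {1, 3} (a transposed view), flat index 5 maps to (5 % 4) * 3 + (5 / 4 % 3) * 1 = 4.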
diff --git a/mlx/backend/cuda/iterators/strided_iterator.cuh b/mlx/backend/cuda/iterators/strided_iterator.cuh
deleted file mode 100644
index 3ef8d66bd..000000000
--- a/mlx/backend/cuda/iterators/strided_iterator.cuh
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#pragma once
-
-#include <thrust/iterator/iterator_adaptor.h>
-#include <thrust/iterator/iterator_facade.h>
-
-namespace mlx::core::cu {
-
-// RandomAccessIterator for strided access to array entries.
-template <typename Iterator, typename Stride = int64_t>
-class strided_iterator
-    : public thrust::
-          iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
- public:
-  using super_t =
-      thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
-
-  using reference = typename super_t::reference;
-  using difference_type = typename super_t::difference_type;
-
-  __host__ __device__ strided_iterator(Iterator it, Stride stride)
-      : super_t(it), stride_(stride) {}
-
-  __host__ __device__ Stride stride() const {
-    return stride_;
-  }
-
- private:
-  friend class thrust::iterator_core_access;
-
-  __host__ __device__ bool equal(const strided_iterator& other) const {
-    return this->base() == other.base();
-  }
-
-  __host__ __device__ void advance(difference_type n) {
-    this->base_reference() += n * stride_;
-  }
-
-  __host__ __device__ void increment() {
-    this->base_reference() += stride_;
-  }
-
-  __host__ __device__ void decrement() {
-    this->base_reference() -= stride_;
-  }
-
-  __host__ __device__ difference_type
-  distance_to(const strided_iterator& other) const {
-    const difference_type dist = other.base() - this->base();
-    _CCCL_ASSERT(
-        dist % stride() == 0,
-        "Underlying iterator difference must be divisible by the stride");
-    return dist / stride();
-  }
-
-  Stride stride_;
-};
-
-} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu
index 83a9c2a67..fdb63d64c 100644
--- a/mlx/backend/cuda/layer_norm.cu
+++ b/mlx/backend/cuda/layer_norm.cu
@@ -1,7 +1,6 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
 #include "mlx/backend/gpu/copy.h"
@@ -105,8 +104,8 @@ __global__ void layer_norm(
     T wn[N_READS];
     T bn[N_READS];
     cub::LoadDirectBlocked(index, x, xn, axis_size);
-    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
-    cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
+    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
+    cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
     for (int i = 0; i < N_READS; ++i) {
       float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
       xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
     }
@@ -162,7 +161,7 @@ __global__ void layer_norm_vjp(
     auto index = r * BLOCK_DIM + block.thread_rank();
     cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
     cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
+    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
     for (int i = 0; i < N_READS; i++) {
       float t = static_cast<float>(xn[i]) - mean;
       float wi = wn[i];
@@ -185,7 +184,7 @@ __global__ void layer_norm_vjp(
     T gn[N_READS];
     cub::LoadDirectBlocked(index, x, xn, axis_size);
     cub::LoadDirectBlocked(index, g, gn, axis_size);
-    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
+    cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
     for (int i = 0; i < N_READS; i++) {
       float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
       float wi = wn[i];
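A bare `operator[]` suffices where a full RandomAccessIterator adaptor used to be because every call site touched by this patch goes through `cub::LoadDirectBlocked`, which only subscripts the iterator it is given. A hypothetical host-side model of that access pattern (illustrative names; the real CUB routine runs per thread on device):

```cpp
#include <cassert>

// Hypothetical host model of cub::LoadDirectBlocked's guarded loads:
// thread `tid` reads N consecutive logical elements starting at tid * N,
// substituting `oob` past `valid` items. Only it[idx] is ever required.
template <typename It, typename T, int N>
void load_direct_blocked(int tid, It it, T (&items)[N], int valid, T oob) {
  for (int i = 0; i < N; ++i) {
    int idx = tid * N + i;
    items[i] = (idx < valid) ? it[idx] : oob;
  }
}

int main() {
  float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float items[4];
  load_direct_blocked(1, src, items, 6, -1.0f); // "thread 1" covers 4..7
  assert(items[0] == 4 && items[1] == 5 && items[2] == -1.0f);
  return 0;
}
```

Under that access pattern, the equality, advance, and distance_to machinery of the thrust adaptor was dead weight at these call sites.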
#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/iterators/strided_iterator.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/gpu/copy.h" @@ -89,7 +88,7 @@ __global__ void rms_norm( T xn[N_READS]; T wn[N_READS]; cub::LoadDirectBlocked(index, x, xn, axis_size); - cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size); for (int i = 0; i < N_READS; ++i) { float norm = static_cast(xn[i]) * normalizer; xn[i] = wn[i] * static_cast(norm); @@ -132,7 +131,7 @@ __global__ void rms_norm_vjp( auto index = r * BLOCK_DIM + block.thread_rank(); cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to(0)); cub::LoadDirectBlocked(index, g, gn, axis_size); - cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size); for (int i = 0; i < N_READS; i++) { float t = static_cast(xn[i]); float wi = wn[i]; @@ -154,7 +153,7 @@ __global__ void rms_norm_vjp( T gn[N_READS]; cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size); - cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size); for (int i = 0; i < N_READS; i++) { float xi = xn[i]; float wi = wn[i]; diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index ddb32d05e..6b7c94bb8 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -3,7 +3,6 @@ #include "mlx/backend/common/unary.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/unary_ops.cuh" -#include "mlx/backend/cuda/iterators/general_iterator.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h"