mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-26 16:25:19 +08:00
Remove thrust iterators (#2396)
This commit is contained in:
parent
93d70419e7
commit
f55c4ed1d6
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@ -115,7 +115,7 @@ __global__ void arg_reduce_general(
|
|||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||||
cub::LoadDirectBlocked(
|
cub::LoadDirectBlocked(
|
||||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||||
best = op.reduce_many(best, vals, tid * N_READS);
|
best = op.reduce_many(best, vals, tid * N_READS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,6 +49,20 @@ store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
|||||||
to[offset] = vec;
|
to[offset] = vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper for accessing strided data.
//
// Wraps any indexable handle `it` (for example a raw pointer) and maps the
// logical position `i` to `it[i * stride]`.
template <typename T>
struct StridedIterator {
  T it; // Underlying indexable handle.
  int64_t stride; // Step between consecutive logical entries of `it`.

  __host__ __device__ StridedIterator(T it, int64_t stride)
      : it(it), stride(stride) {}

  // Returns the entry at logical position `i`, i.e. `it[i * stride]`.
  //
  // The index is accepted as int64_t so a 64-bit position is not truncated
  // before it is multiplied by the 64-bit stride; `int` arguments still
  // convert implicitly. The return type is deduced with plain `auto`, so the
  // entry is handed back by value.
  __host__ __device__ auto operator[](int64_t i) const {
    return it[i * stride];
  }
};
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Type limits utils
|
// Type limits utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1,121 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <thrust/iterator/iterator_adaptor.h>
|
|
||||||
#include <cuda/std/utility>
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
|
||||||
|
|
||||||
// Iterator over a non-contiguous array.
//
// Stores a linear element index next to the array's shape and strides. The
// index is only translated into an offset from the wrapped base iterator
// (through elem_to_loc) when the iterator is dereferenced; moving the
// iterator just changes the stored index.
template <typename Iterator, typename IdxT = int64_t>
class general_iterator
    : public thrust::
          iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
 public:
  using super_t =
      thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;

  using reference = typename super_t::reference;
  using difference_type = typename super_t::difference_type;

  // Builds an iterator positioned at linear element `index` of an array with
  // `ndim` dimensions described by `shape` and `strides`.
  __host__ __device__ general_iterator(
      Iterator it,
      IdxT index,
      int ndim,
      Shape shape,
      Strides strides)
      : super_t(it),
        index_(index),
        ndim_(ndim),
        shape_(cuda::std::move(shape)),
        strides_(cuda::std::move(strides)) {}

  // Current linear element index.
  __host__ __device__ IdxT index() const {
    return index_;
  }

  // Shape used to translate the index on dereference.
  __host__ __device__ const Shape& shape() const {
    return shape_;
  }

  // Strides used to translate the index on dereference.
  __host__ __device__ const Strides& strides() const {
    return strides_;
  }

 private:
  // The hooks below are invoked by thrust through iterator_core_access.
  friend class thrust::iterator_core_access;

  // Equal when both the base iterator and the linear index match.
  __host__ __device__ bool equal(const general_iterator& other) const {
    const bool same_base = this->base() == other.base();
    return same_base && index_ == other.index_;
  }

  // Moving the iterator only touches the linear index, never the base.
  __host__ __device__ void advance(difference_type n) {
    index_ += n;
  }

  __host__ __device__ void increment() {
    index_ += 1;
  }

  __host__ __device__ void decrement() {
    index_ -= 1;
  }

  // Number of elements between this iterator and `other`.
  __host__ __device__ difference_type
  distance_to(const general_iterator& other) const {
    _CCCL_ASSERT(
        this->base() == other.base(),
        "Underlying iterator must point to same base iterator");
    return other.index_ - index_;
  }

  // The dereference is device-only to avoid accidental running in host.
  __device__ reference dereference() const {
    const IdxT loc =
        elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
    return *(this->base() + loc);
  }

  IdxT index_;
  int ndim_;
  Shape shape_;
  Strides strides_;
};
|
|
||||||
|
|
||||||
// Creates a general_iterator over `it` positioned at linear element `index`.
// `shape` and `strides` are taken by value and moved into the iterator.
template <typename IdxT, typename Iterator>
__host__ __device__ auto make_general_iterator(
    Iterator it,
    IdxT index,
    int ndim,
    Shape shape,
    Strides strides) {
  using iterator_t = general_iterator<Iterator, IdxT>;
  return iterator_t(
      it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
}
|
|
||||||
|
|
||||||
// Host-side convenience overload: creates a general_iterator positioned at
// element 0 from std::vector shape and strides, which are converted with
// const_param before being handed to the iterator.
template <typename IdxT, typename Iterator>
auto make_general_iterator(
    Iterator it,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  auto shape_arg = const_param(shape);
  auto strides_arg = const_param(strides);
  // The number of dimensions is the length of the shape vector.
  return make_general_iterator<IdxT>(
      it, 0, shape.size(), shape_arg, strides_arg);
}
|
|
||||||
|
|
||||||
// Creates a (first, last) pair of general_iterators over `it`: `first` sits
// at linear element 0 and `last` at linear element `size`. Both share the
// same converted shape and strides.
template <typename IdxT, typename Iterator>
auto make_general_iterators(
    Iterator it,
    IdxT size,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  auto ndim = shape.size();
  // Convert once and reuse the result for both ends of the range.
  auto shape_arg = const_param(shape);
  auto strides_arg = const_param(strides);
  auto first =
      make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg);
  auto last =
      make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg);
  return std::make_pair(first, last);
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
|
@ -1,60 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <thrust/iterator/iterator_adaptor.h>
|
|
||||||
#include <thrust/iterator/iterator_facade.h>
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
|
||||||
|
|
||||||
// RandomAccessIterator for strided access to array entries.
//
// Adapts a base iterator so that every step of this iterator moves the base
// by `stride` positions.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
    : public thrust::
          iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
 public:
  using super_t =
      thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;

  using reference = typename super_t::reference;
  using difference_type = typename super_t::difference_type;

  // Wraps `it`; each step of the adaptor advances `it` by `stride`.
  __host__ __device__ strided_iterator(Iterator it, Stride stride)
      : super_t(it), stride_(stride) {}

  // Step, in base-iterator positions, applied per increment.
  __host__ __device__ Stride stride() const {
    return stride_;
  }

 private:
  // The hooks below are invoked by thrust through iterator_core_access.
  friend class thrust::iterator_core_access;

  // Only the base positions are compared; the stride is not.
  __host__ __device__ bool equal(const strided_iterator& other) const {
    return this->base() == other.base();
  }

  __host__ __device__ void advance(difference_type n) {
    this->base_reference() += n * stride_;
  }

  __host__ __device__ void increment() {
    this->base_reference() += stride_;
  }

  __host__ __device__ void decrement() {
    this->base_reference() -= stride_;
  }

  // Number of strided steps from this iterator to `other`.
  __host__ __device__ difference_type
  distance_to(const strided_iterator& other) const {
    const difference_type dist = other.base() - this->base();
    // With a zero stride the `%` and `/` below would divide by zero. Such an
    // iterator never moves its base (advance adds n * 0), so the distance is
    // reported as 0 instead.
    if (stride() == 0) {
      _CCCL_ASSERT(
          dist == 0,
          "Zero-stride iterators must share the same underlying iterator");
      return 0;
    }
    _CCCL_ASSERT(
        dist % stride() == 0,
        "Underlying iterator difference must be divisible by the stride");
    return dist / stride();
  }

  Stride stride_;
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
|
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
@ -105,8 +104,8 @@ __global__ void layer_norm(
|
|||||||
T wn[N_READS];
|
T wn[N_READS];
|
||||||
T bn[N_READS];
|
T bn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||||
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||||
@ -162,7 +161,7 @@ __global__ void layer_norm_vjp(
|
|||||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float t = static_cast<float>(xn[i]) - mean;
|
float t = static_cast<float>(xn[i]) - mean;
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
@ -185,7 +184,7 @@ __global__ void layer_norm_vjp(
|
|||||||
T gn[N_READS];
|
T gn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
@ -89,7 +88,7 @@ __global__ void rms_norm(
|
|||||||
T xn[N_READS];
|
T xn[N_READS];
|
||||||
T wn[N_READS];
|
T wn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
float norm = static_cast<float>(xn[i]) * normalizer;
|
float norm = static_cast<float>(xn[i]) * normalizer;
|
||||||
xn[i] = wn[i] * static_cast<T>(norm);
|
xn[i] = wn[i] * static_cast<T>(norm);
|
||||||
@ -132,7 +131,7 @@ __global__ void rms_norm_vjp(
|
|||||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
|
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float t = static_cast<float>(xn[i]);
|
float t = static_cast<float>(xn[i]);
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
@ -154,7 +153,7 @@ __global__ void rms_norm_vjp(
|
|||||||
T gn[N_READS];
|
T gn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float xi = xn[i];
|
float xi = xn[i];
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/common/unary.h"
|
#include "mlx/backend/common/unary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
Loading…
Reference in New Issue
Block a user