Use load_vector in arg_reduce (#2439)

This commit is contained in:
Cheng 2025-07-30 17:40:26 +09:00 committed by GitHub
parent a0ae49d397
commit 3628e5d497
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 21 deletions

View File

@ -44,8 +44,11 @@ struct ArgMin {
} }
template <int N> template <int N>
__device__ IndexValPair<T> __device__ IndexValPair<T> reduce_many(
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) { IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
if (vals[i] < best.val) { if (vals[i] < best.val) {
best.val = vals[i]; best.val = vals[i];
@ -74,8 +77,11 @@ struct ArgMax {
} }
template <int N> template <int N>
__device__ IndexValPair<T> __device__ IndexValPair<T> reduce_many(
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) { IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
if (vals[i] > best.val) { if (vals[i] > best.val) {
best.val = vals[i]; best.val = vals[i];
@ -106,16 +112,15 @@ __global__ void arg_reduce_general(
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim); int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim); int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
in += in_idx;
Op op; Op op;
T init = op.init(); T init = op.init();
IndexValPair<T> best{0, init}; IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x; auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked( auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS); best = op.reduce_many(best, vals, tid * N_READS);
} }

View File

@ -131,20 +131,6 @@ inline __device__ void store_vector(
} }
} }
// Helper for accessing strided data.
//
// Pairs an indexable object with a stride so that logical element `i` of the
// wrapper maps to element `i * stride` of the wrapped object.
//
// T: type of the wrapped object; it only needs to support `operator[]`
//    (presumably a raw pointer in practice -- confirm against call sites).
template <typename T>
struct StridedIterator {
// Wrapped indexable object that reads are forwarded to.
T it;
// Distance, in elements of `it`, between two consecutive logical elements.
int64_t stride;
// Stores `it` and `stride` unchanged; callable from both host and device code.
__host__ __device__ StridedIterator(T it, int64_t stride)
: it(it), stride(stride) {}
// Returns logical element `i`, i.e. `it[i * stride]`.
// `i` is promoted to int64_t for the multiplication because `stride` is
// int64_t. The return type is deduced with plain `auto`, so the element is
// returned by value rather than by reference. No bounds checking is done.
__host__ __device__ auto operator[](int i) const {
return it[i * stride];
}
};
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////