mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
Use load_vector in arg_reduce (#2439)
This commit is contained in:
parent
a0ae49d397
commit
3628e5d497
@ -44,8 +44,11 @@ struct ArgMin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int N>
|
template <int N>
|
||||||
__device__ IndexValPair<T>
|
__device__ IndexValPair<T> reduce_many(
|
||||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
IndexValPair<T> best,
|
||||||
|
const AlignedVector<T, N>& vals,
|
||||||
|
uint32_t offset) {
|
||||||
|
#pragma unroll
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
if (vals[i] < best.val) {
|
if (vals[i] < best.val) {
|
||||||
best.val = vals[i];
|
best.val = vals[i];
|
||||||
@ -74,8 +77,11 @@ struct ArgMax {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int N>
|
template <int N>
|
||||||
__device__ IndexValPair<T>
|
__device__ IndexValPair<T> reduce_many(
|
||||||
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
|
IndexValPair<T> best,
|
||||||
|
const AlignedVector<T, N>& vals,
|
||||||
|
uint32_t offset) {
|
||||||
|
#pragma unroll
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
if (vals[i] > best.val) {
|
if (vals[i] > best.val) {
|
||||||
best.val = vals[i];
|
best.val = vals[i];
|
||||||
@ -106,16 +112,15 @@ __global__ void arg_reduce_general(
|
|||||||
|
|
||||||
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
||||||
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
||||||
|
in += in_idx;
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
T init = op.init();
|
T init = op.init();
|
||||||
IndexValPair<T> best{0, init};
|
IndexValPair<T> best{0, init};
|
||||||
|
|
||||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||||
T vals[N_READS];
|
|
||||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||||
cub::LoadDirectBlocked(
|
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
|
||||||
tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
|
|
||||||
best = op.reduce_many(best, vals, tid * N_READS);
|
best = op.reduce_many(best, vals, tid * N_READS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,20 +131,6 @@ inline __device__ void store_vector(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper for accessing strided data.
|
|
||||||
template <typename T>
|
|
||||||
struct StridedIterator {
|
|
||||||
T it;
|
|
||||||
int64_t stride;
|
|
||||||
|
|
||||||
__host__ __device__ StridedIterator(T it, int64_t stride)
|
|
||||||
: it(it), stride(stride) {}
|
|
||||||
|
|
||||||
__host__ __device__ auto operator[](int i) const {
|
|
||||||
return it[i * stride];
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Type limits utils
|
// Type limits utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
Loading…
Reference in New Issue
Block a user