This commit is contained in:
Angelos Katharopoulos
2025-10-06 23:28:29 -07:00
parent ed61cb2802
commit cad47a32e2

View File

@@ -163,35 +163,60 @@ __global__ void row_reduce_looped(
U init = ReduceInit<Op, T>::value();
total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (block.size() * N_READS);
size_t final_offset = full_blocks * (block.size() * N_READS);
const size_t full_blocks = args.row_size / (block.size() * N_READS);
const size_t final_offset = full_blocks * (block.size() * N_READS);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
in += block.thread_rank() * N_READS;
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
// Unaligned reduce
if (final_offset < args.row_size) {
bool mask[N_READS];
for (int i = 0; i < N_READS; i++) {
mask[i] =
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
}
if (final_offset < args.row_size) {
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] =
((final_offset + block.thread_rank() * N_READS + i) < args.row_size)
? inlocal[i]
: cast_to<T>(init);
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
{
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
}
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
}
// Aligned case
else {
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
__shared__ U shared_accumulators[32];
@@ -231,18 +256,9 @@ void row_reduce_simple(
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
if (warps > 128) {
warps /= 4;
if (warps > 32) {
warps = 32;
} else {
warps = 16;
}
int best = reductions;
for (int j = warps; j >= 1; j /= 2) {
int t = reductions % (j * WARP_SIZE);
if (t < best) {
warps = j;
best = t;
}
}
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);
@@ -289,18 +305,9 @@ void row_reduce_looped(
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
if (warps > 128) {
warps /= 4;
if (warps > 32) {
warps = 32;
} else {
warps = 16;
}
int best = reductions;
for (int j = warps; j >= 1; j /= 2) {
int t = reductions % (j * WARP_SIZE);
if (t < best) {
warps = j;
best = t;
}
}
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);