This commit is contained in:
Angelos Katharopoulos
2025-10-06 23:28:29 -07:00
parent ed61cb2802
commit cad47a32e2

View File

@@ -163,14 +163,23 @@ __global__ void row_reduce_looped(
U init = ReduceInit<Op, T>::value(); U init = ReduceInit<Op, T>::value();
total[0] = init; total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim); LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (block.size() * N_READS); const size_t full_blocks = args.row_size / (block.size() * N_READS);
size_t final_offset = full_blocks * (block.size() * N_READS); const size_t final_offset = full_blocks * (block.size() * N_READS);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
in += block.thread_rank() * N_READS; in += block.thread_rank() * N_READS;
// Unaligned reduce
if (final_offset < args.row_size) {
bool mask[N_READS];
for (int i = 0; i < N_READS; i++) {
mask[i] =
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
}
for (size_t n = 0; n < args.non_row_reductions; n++) { for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location(); const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) { for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0); auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -178,21 +187,37 @@ __global__ void row_reduce_looped(
} }
inlocal += block.size() * N_READS; inlocal += block.size() * N_READS;
} }
if (final_offset < args.row_size) {
{
T vals[N_READS]; T vals[N_READS];
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
vals[i] = vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
((final_offset + block.thread_rank() * N_READS + i) < args.row_size)
? inlocal[i]
: cast_to<T>(init);
} }
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i])); total[0] = op(total[0], cast_to<U>(vals[i]));
} }
} }
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data()); loop.next(args.reduce_shape.data(), args.reduce_strides.data());
} }
}
// Aligned case
else {
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
}
__shared__ U shared_accumulators[32]; __shared__ U shared_accumulators[32];
block_reduce(block, warp, total, shared_accumulators, op, init); block_reduce(block, warp, total, shared_accumulators, op, init);
@@ -231,18 +256,9 @@ void row_reduce_simple(
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
if (warps > 128) { warps /= 4;
if (warps > 32) {
warps = 32; warps = 32;
} else {
warps = 16;
}
int best = reductions;
for (int j = warps; j >= 1; j /= 2) {
int t = reductions % (j * WARP_SIZE);
if (t < best) {
warps = j;
best = t;
}
} }
int threads = warps * WARP_SIZE; int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1); dim3 block(threads, 1, 1);
@@ -289,18 +305,9 @@ void row_reduce_looped(
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS; size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
if (warps > 128) { warps /= 4;
if (warps > 32) {
warps = 32; warps = 32;
} else {
warps = 16;
}
int best = reductions;
for (int j = warps; j >= 1; j /= 2) {
int t = reductions % (j * WARP_SIZE);
if (t < best) {
warps = j;
best = t;
}
} }
int threads = warps * WARP_SIZE; int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1); dim3 block(threads, 1, 1);