Fix warp underflow

This commit is contained in:
Angelos Katharopoulos
2025-10-07 00:52:46 -07:00
parent cad47a32e2
commit 3f86389409

View File

@@ -257,9 +257,7 @@ void row_reduce_simple(
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
if (warps > 32) {
warps = 32;
}
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);
@@ -306,9 +304,7 @@ void row_reduce_looped(
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
warps /= 4;
if (warps > 32) {
warps = 32;
}
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1);