diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 2e841cfcd..1ae46d0a3 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -257,9 +257,7 @@ void row_reduce_simple( dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; warps /= 4; - if (warps > 32) { - warps = 32; - } + warps = std::max(std::min(warps, 32), 1); int threads = warps * WARP_SIZE; dim3 block(threads, 1, 1); @@ -306,9 +304,7 @@ void row_reduce_looped( size_t reductions = (args.row_size + N_READS - 1) / N_READS; int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE; warps /= 4; - if (warps > 32) { - warps = 32; - } + warps = std::max(std::min(warps, 32), 1); int threads = warps * WARP_SIZE; dim3 block(threads, 1, 1);