From 3f8638940974ac588db77b1a2faed7866dba3ecc Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 7 Oct 2025 00:52:46 -0700
Subject: [PATCH] Fix warp underflow

---
 mlx/backend/cuda/reduce/row_reduce.cu | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 2e841cfcd..1ae46d0a3 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -257,9 +257,7 @@ void row_reduce_simple(
   dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
   int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
   warps /= 4;
-  if (warps > 32) {
-    warps = 32;
-  }
+  warps = std::max(std::min(warps, 32), 1);
   int threads = warps * WARP_SIZE;
   dim3 block(threads, 1, 1);
 
@@ -306,9 +304,7 @@ void row_reduce_looped(
   size_t reductions = (args.row_size + N_READS - 1) / N_READS;
   int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
   warps /= 4;
-  if (warps > 32) {
-    warps = 32;
-  }
+  warps = std::max(std::min(warps, 32), 1);
   int threads = warps * WARP_SIZE;
   dim3 block(threads, 1, 1);
 