diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index eb325a987..63940b2d4 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -338,7 +338,7 @@ void col_reduce( // Small col reduce with a single or contiguous reduction axis if (args.non_col_reductions == 1 && args.reduction_size <= 32 && - args.reduction_stride % 4 == 0) { + args.reduction_stride % (16 / in.itemsize()) == 0) { col_reduce_small( encoder, in, out, reduce_type, axes, plan, std::move(args)); return;