diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 744105812..c62b84206 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -449,7 +449,8 @@ void row_reduce_general_dispatch( } // Case 2: Contiguous reduce without non-row reductions - if (plan.type == ContiguousReduce && args.reduce_ndim == 0) { + if (plan.type == ContiguousReduce && args.reduce_ndim == 0 && + in.size() / args.row_size >= 32) { return row_reduce_simple(in, out, op_name, args, compute_encoder, d, s); }