diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 5b3ec027b..6f5807543 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -850,14 +850,14 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { int M = x.shape(-2); int N = out.shape(-1); int B = out.size() / M / N; + int E = w.size() / w.shape(-1) / w.shape(-2); int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4; // We are walking x in order and w is also in order so we can batch up the // matmuls and reuse reading x and w. // - // TODO: Tune 16 here a bit better. Maybe also choose it dynamically based - // on B and (w.size() / K / N). - if (M == 1 && B >= 16 && right_sorted_ == true) { + // TODO: Tune 16 and 8 here a bit better. + if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 8) { gather_qmm_rhs( x, w,