From 9ba81e3da4c350c9308f2a3992464e15c15ff608 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 2 Apr 2025 20:05:54 -0700
Subject: [PATCH] tune quant dispatch (#2031)

---
Notes (free-form commentary; not part of the commit message or the diff):

  What this patch changes
  * kernels/quantized.h, qmv_quad_impl: `out_row` is now derived from `tid.y`
    instead of `tid.x`, and the `x` / `y` pointer offsets now use `tid.x`
    instead of `tid.y`.
  * quantized.cpp, qmm_op: the `qmv_quad` grid becomes
    `MTL::Size(B, (O + bo - 1) / bo, N)` (first two dimensions swapped,
    mirroring the tid.x / tid.y swap in the kernel).
  * quantized.cpp, qmm_op: the fixed `B < 6` threshold in the three
    transpose-path branches (`qmv_quad`, `qmv_fast`, `qmv`) is replaced by
    `B < qmv_batch_limit`, where the limit comes from the new
    `get_qmv_batch_limit(D, O)` lambda.
  * `get_qmv_batch_limit` picks the limit from three inputs: the last
    character of the architecture string ('d' vs. anything else), the two
    characters preceding it ("13" / "14" vs. anything else), and whether D
    and O are both <= 2048, both <= 4096, or larger.

  NOTE(review): `arch.back()` and `arch.substr(arch.size() - 3, 2)` assume the
    architecture string holds at least three characters. Presumably
    `get_architecture()` returns a std::string, in which case a shorter
    value would make `arch.size() - 3` wrap around -- confirm the guarantee.
  NOTE(review): the `case 'd'` arms return the same values (32 / 18 / 12) in
    both generation branches; only the `default` arms differ
    (14 / 10 / 6 vs. 18 / 12 / 10).

 mlx/backend/metal/kernels/quantized.h |  6 +--
 mlx/backend/metal/quantized.cpp       | 54 +++++++++++++++++++++++++--
 2 files changed, 53 insertions(+), 7 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index 3af3c971f..af9d7860e 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -586,13 +586,13 @@ METAL_FUNC void qmv_quad_impl(
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
-  const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
+  const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;

   w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
   scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
   biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
-  x += tid.y * in_vec_size + quad_lid * values_per_thread;
-  y += tid.y * out_vec_size + out_row;
+  x += tid.x * in_vec_size + quad_lid * values_per_thread;
+  y += tid.x * out_vec_size + out_row;

   U sum = load_vector(x, x_thread);

diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index cc32797eb..8d1d176c4 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -298,23 +298,69 @@ void qmm_op(
   bool aligned = false;
   bool quad = false;

+  auto get_qmv_batch_limit = [s](int D, int O) {
+    auto arch = metal::device(s.device).get_architecture();
+    auto arch_size = arch.back();
+    auto arch_gen = arch.substr(arch.size() - 3, 2);
+    if (arch_gen == "13" || arch_gen == "14") {
+      switch (arch_size) {
+        case 'd':
+          if (D <= 2048 && O <= 2048) {
+            return 32;
+          } else if (D <= 4096 && O <= 4096) {
+            return 18;
+          } else {
+            return 12;
+          }
+        default:
+          if (D <= 2048 && O <= 2048) {
+            return 14;
+          } else if (D <= 4096 && O <= 4096) {
+            return 10;
+          } else {
+            return 6;
+          }
+      }
+    } else {
+      switch (arch_size) {
+        case 'd':
+          if (D <= 2048 && O <= 2048) {
+            return 32;
+          } else if (D <= 4096 && O <= 4096) {
+            return 18;
+          } else {
+            return 12;
+          }
+        default:
+          if (D <= 2048 && O <= 2048) {
+            return 18;
+          } else if (D <= 4096 && O <= 4096) {
+            return 12;
+          } else {
+            return 10;
+          }
+      }
+    }
+  };
+
   if (transpose) {
-    if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) {
+    auto qmv_batch_limit = get_qmv_batch_limit(D, O);
+    if (B < qmv_batch_limit && (D == 128 || D == 64) && is_power_of_2(bits)) {
       name += "qmv_quad";
       constexpr int quads_per_simd = 8;
       constexpr int results_per_quadgroup = 8;
       int bo = quads_per_simd * results_per_quadgroup;
       int simdgroup_size = 32;
       group_dims = MTL::Size(simdgroup_size, 1, 1);
-      grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
+      grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
       quad = true;
-    } else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
+    } else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) {
       name += "qmv_fast";
       int bo = 8;
       int bd = 32;
       group_dims = MTL::Size(bd, 2, 1);
       grid_dims = MTL::Size(B, O / bo, N);
-    } else if (B < 6) {
+    } else if (B < qmv_batch_limit) {
       name += "qmv";
       int bo = 8;
       int bd = 32;