// Fast dot-product specialization for 4-bit quantized weights.
//
// `w` points to four uint16 words, each packing four 4-bit weight values
// (16 values total, matching the 16 entries read from `x_thread`).
//
// The nibbles are masked out but deliberately NOT shifted down to their
// numeric value: the shifts are presumably folded into `x_thread`, which the
// caller pre-scales per nibble position (1, 1/16, 1/256, 1/4096) when loading
// — confirm against load_vector. `sum` is the plain sum of the unscaled
// activations, used to apply the quantization bias analytically at the end.
inline float qdot_bit4(
    const device uint16_t* w,
    const thread float* x_thread,
    float scale,
    float bias,
    float sum) {
  float acc = 0;
  for (int pack = 0; pack < 4; pack++) {
    const uint16_t packed = w[pack];
    const thread float* xp = x_thread + 4 * pack;
    // Masked (unshifted) nibbles; same left-to-right accumulation order as a
    // straight four-term sum, so results are bit-identical per iteration.
    acc +=
        (xp[0] * (packed & 0x000f) + xp[1] * (packed & 0x00f0) +
         xp[2] * (packed & 0x0f00) + xp[3] * (packed & 0xf000));
  }
  return scale * acc + sum * bias;
}
sum = load_vector(x_temp, x_thread);
-        result[col * results_per_simdgroup + row] += qdot(wl, x_thread, s, b, sum);
+        if (bits == 4) {
+          result[col * results_per_simdgroup + row] += qdot_bit4(wb, x_thread, s, b, sum);
+        } else {
+          result[col * results_per_simdgroup + row] += qdot(wl, x_thread, s, b, sum);
+        }
       }
     }
     ws += block_size * bytes_per_pack / pack_factor;
     scales += block_size / group_size;
     biases += block_size / group_size;
-    // x += block_size;
+    x += block_size;
   }
 
   for (int row = 0; row < results_per_simdgroup; row++) {
     for (int col = 0; col < m_size; col++) {
       result[col * results_per_simdgroup + row] =
           simd_sum(result[col * results_per_simdgroup + row]);
-      auto y_temp = y + col * n_size + out_row;
+      auto y_temp = y + col * n_size;
       if (simd_lid == 0) {
         y_temp[row] = static_cast<U>(result[col * results_per_simdgroup + row]);
       }