From b2f0ebe9ee41a97dde192037b086943077e5e9f8 Mon Sep 17 00:00:00 2001
From: tianyi
Date: Tue, 22 Jul 2025 12:54:04 +0800
Subject: [PATCH] [Feature] Add no-parallel-m qmm kernel to improve decoding performance

---
 mlx/backend/metal/kernels/quantized.h     | 129 ++++++++++++++++++++++
 mlx/backend/metal/kernels/quantized.metal |   1 +
 mlx/backend/metal/quantized.cpp           |  55 ++++++++-
 3 files changed, 184 insertions(+), 1 deletion(-)

diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index b2b0d8d8f..2f7ceb90e 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -688,6 +688,82 @@ METAL_FUNC void qmv_fast_impl(
   }
 }
 
+template <typename T, int group_size, int bits>
+METAL_FUNC void qmv_no_parallel_m_impl(
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    const device T* x,
+    device T* y,
+    const constant int& m_size,
+    const constant int& k_size,
+    const constant int& n_size,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
+  constexpr int num_simdgroups = 2;
+  constexpr int results_per_simdgroup = 4;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
+  constexpr int max_batch = 10;
+
+  const device uint8_t* ws = (const device uint8_t*)w;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  thread U result[max_batch * results_per_simdgroup] = {0};
+
+  // Adjust positions
+  const int in_vec_size_w = k_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_g = k_size / group_size;
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+      simd_gid * results_per_simdgroup;
+
+  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
+  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+  // x += tid.x * in_vec_size + simd_lid * values_per_thread;
+  // y += tid.x * out_vec_size + out_row;
+
+  for (int k = 0; k < k_size; k += block_size) {
+    // U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
+      const device T* sl = scales + row * in_vec_size_g;
+      const device T* bl = biases + row * in_vec_size_g;
+      U s = sl[0];
+      U b = bl[0];
+
+      for (int col = 0; col < m_size; col++) {
+        auto x_temp = x + col * k_size + simd_lid * values_per_thread + k;
+        U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread);
+        result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+      }
+    }
+
+    ws += block_size * bytes_per_pack / pack_factor;
+    scales += block_size / group_size;
+    biases += block_size / group_size;
+    // x += block_size;
+  }
+
+  for (int row = 0; row < results_per_simdgroup; row++) {
+    for (int col = 0; col < m_size; col++) {
+      result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]);
+      auto y_temp = y + col * n_size + out_row;
+      if (simd_lid == 0) {
+        y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]);
+      }
+    }
+  }
+}
+
 template <typename T, int group_size, int bits>
 METAL_FUNC void qmv_impl(
     const device uint32_t* w,
@@ -1410,6 +1486,59 @@ template <typename T, int group_size, int bits, bool batched>
       simd_lid);
 }
 
+template <typename T, int group_size, int bits, bool batched>
+[[kernel]] void qmv_no_parallel_m(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& m_size [[buffer(5)]],
+    const constant int& k_size [[buffer(6)]],
+    const constant int& n_size [[buffer(7)]],
+    const constant int& x_batch_ndims [[buffer(8)]],
+    const constant int* x_shape [[buffer(9)]],
+    const constant int64_t* x_strides [[buffer(10)]],
+    const constant int& w_batch_ndims [[buffer(11)]],
+    const constant int* w_shape [[buffer(12)]],
+    const constant int64_t* w_strides [[buffer(13)]],
+    const constant int64_t* s_strides [[buffer(14)]],
+    const constant int64_t* b_strides [[buffer(15)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  if (batched) {
+    adjust_matrix_offsets<T>(
+        x,
+        w,
+        scales,
+        biases,
+        y,
+        n_size * m_size,
+        x_batch_ndims,
+        x_shape,
+        x_strides,
+        w_batch_ndims,
+        w_shape,
+        w_strides,
+        s_strides,
+        b_strides,
+        tid);
+  }
+  qmv_no_parallel_m_impl<T, group_size, bits>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      m_size,
+      k_size,
+      n_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
 template <typename T, int group_size, int bits, bool batched>
 [[kernel]] void qmv(
     const device uint32_t* w [[buffer(0)]],
diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 11cd8421b..0faf7fbc3 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -80,6 +80,7 @@
 
 #define instantiate_quantized_all_batched(type, group_size, bits)               \
   instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits)          \
+  instantiate_quantized_batched_wrap(qmv_no_parallel_m, type, group_size, bits) \
   instantiate_quantized_batched_wrap(qmv, type, group_size, bits)               \
   instantiate_quantized_batched_wrap(qvm, type, group_size, bits)               \
   instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 6f5807543..e631d2d7b 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -244,6 +244,59 @@ void qmv(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+void qmv_no_parallel_m(
+    const array& x,
+    const array& w,
+    const array& scales,
+    const array& biases,
+    array& out,
+    int group_size,
+    int bits,
+    int M,
+    int N,
+    int K,
+    metal::Device& d,
+    const Stream& s) {
+  int B = out.size() / M / N;
+
+  int bn = 8; // output rows per threadgroup: num_simdgroups * results_per_simdgroup
+  int bk = 32; // SIMD_SIZE threads per simdgroup
+  MTL::Size group_dims(bk, 2, 1);
+  MTL::Size grid_dims(1, (N + bn - 1) / bn, B);
+
+  std::string kname;
+  kname.reserve(64);
+  std::string type_string = get_type_string(x.dtype());
+  // bool fast = N % bn == 0 && K % 512 == 0;
+  concatenate(
+      kname,
+      "qmv_no_parallel_m_",
+      type_string,
+      "_gs_",
+      group_size,
+      "_b_",
+      bits,
"_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, "qmv_no_parallel_m", type_string, group_size, bits, B > 1); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(M, 5); + compute_encoder.set_bytes(K, 6); + compute_encoder.set_bytes(N, 7); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + void qvm_split_k( const array& x, const array& w, @@ -818,7 +871,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { // Run of the mill qmv if (transpose_) { - qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + qmv_no_parallel_m(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); return; }