From da83f899bbf3279db3efe2f410342f3a467a1aab Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 20 May 2024 09:20:44 -0700
Subject: [PATCH] Improve qvm speed (#1140)

---
 mlx/backend/metal/kernels/quantized.metal | 86 +++++++++++++++--------
 mlx/backend/metal/quantized.cpp           | 12 ++--
 2 files changed, 62 insertions(+), 36 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 61fe3c303..5bc612aae 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -601,14 +601,18 @@ METAL_FUNC void qvm_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int num_simdgroups = 8;
+  constexpr int num_simdgroups = 2;
   constexpr int pack_factor = 32 / bits;
+  constexpr int tn = 32 / pack_factor;
   constexpr int blocksize = SIMD_SIZE;

   typedef float U;
+  typedef struct {
+    uint32_t wi[tn];
+  } vec_w;

-  thread uint32_t w_local;
-  thread U result[pack_factor] = {0};
+  thread vec_w w_local;
+  thread U result[tn * pack_factor] = {0};
   thread U scale = 1;
   thread U bias = 0;
   thread U x_local = 0;
@@ -616,11 +620,12 @@ METAL_FUNC void qvm_impl(
   // Adjust positions
   const int out_vec_size_w = out_vec_size / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
-  w += out_col / pack_factor;
-  scales += out_col / group_size;
-  biases += out_col / group_size;
-  x += tid.y * in_vec_size;
+  int out_col =
+      tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
+  w += out_col / pack_factor + simd_lid * out_vec_size_w;
+  scales += out_col / group_size + simd_lid * out_vec_size_g;
+  biases += out_col / group_size + simd_lid * out_vec_size_g;
+  x += tid.y * in_vec_size + simd_lid;
   y += tid.y * out_vec_size + out_col;

   if (out_col >= out_vec_size) {
@@ -628,40 +633,61 @@
   }

   // Loop over in_vec in blocks of blocksize
-  int i = 0;
-  for (; i + blocksize <= in_vec_size; i += blocksize) {
-    x_local = x[i + simd_lid];
-    scale = scales[(i + simd_lid) * out_vec_size_g];
-    bias = biases[(i + simd_lid) * out_vec_size_g];
-    w_local = w[(i + simd_lid) * out_vec_size_w];
+  int remaining = in_vec_size % blocksize;
+  if (remaining == 0) {
+    for (int i = 0; i < in_vec_size; i += blocksize) {
+      x_local = *x;
+      scale = *scales;
+      bias = *biases;
+      w_local = *((device vec_w*)w);

-    qouter<U, pack_factor, bits>(
+      qouter<U, tn * pack_factor, bits>(
+          (thread uint8_t*)&w_local, x_local, scale, bias, result);
+
+      x += blocksize;
+      scales += blocksize * out_vec_size_g;
+      biases += blocksize * out_vec_size_g;
+      w += blocksize * out_vec_size_w;
+    }
+  } else {
+    for (int i = blocksize; i < in_vec_size; i += blocksize) {
+      x_local = *x;
+      scale = *scales;
+      bias = *biases;
+      w_local = *((device vec_w*)w);
+
+      qouter<U, tn * pack_factor, bits>(
+          (thread uint8_t*)&w_local, x_local, scale, bias, result);
+
+      x += blocksize;
+      scales += blocksize * out_vec_size_g;
+      biases += blocksize * out_vec_size_g;
+      w += blocksize * out_vec_size_w;
+    }
+    if (static_cast<int>(simd_lid) < remaining) {
+      x_local = *x;
+      scale = *scales;
+      bias = *biases;
+      w_local = *((device vec_w*)w);
+    } else {
+      x_local = 0;
+      scale = 0;
+      bias = 0;
+    }
+    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }
-  if (static_cast<int>(i + simd_lid) < in_vec_size) {
-    x_local = x[i + simd_lid];
-    scale = scales[(i + simd_lid) * out_vec_size_g];
-    bias = biases[(i + simd_lid) * out_vec_size_g];
-    w_local = w[(i + simd_lid) * out_vec_size_w];
-  } else {
-    x_local = 0;
-    scale = 0;
-    bias = 0;
-    w_local = 0;
-  }
-  qouter<U, pack_factor, bits>(
-      (thread uint8_t*)&w_local, x_local, scale, bias, result);

   // Accumulate in the simdgroup
 #pragma clang loop unroll(full)
-  for (int k = 0; k < pack_factor; k++) {
+  for (int k = 0; k < tn * pack_factor; k++) {
     result[k] = simd_sum(result[k]);
   }

   // Store the result
   if (simd_lid == 0) {
 #pragma clang loop unroll(full)
-    for (int k = 0; k < pack_factor; k++) {
+    for (int k = 0; k < tn * pack_factor; k++) {
       y[k] = static_cast<T>(result[k]);
     }
   }
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 980509f47..609d7bfac 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -137,10 +137,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);

-    int bo = 8;
+    int bo = 64;
     int bd = 32;
-    MTL::Size group_dims = MTL::Size(bd, bo, 1);
-    MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
+    MTL::Size group_dims = MTL::Size(bd, 2, 1);
+    MTL::Size grid_dims = MTL::Size(O / bo, B, 1);

     compute_encoder.set_input_array(x, 0);
     compute_encoder.set_input_array(w, 1);
@@ -393,10 +393,10 @@ void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);

-    int bo = 8;
+    int bo = 64;
     int bd = 32;
-    MTL::Size group_dims = MTL::Size(bd, bo, 1);
-    MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
+    MTL::Size group_dims = MTL::Size(bd, 2, 1);
+    MTL::Size grid_dims = MTL::Size(O / bo, B, N);

     compute_encoder.set_input_array(x, 0);
     compute_encoder.set_input_array(w, 1);
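Note (illustrative, not part of the patch above): each simdgroup in the reworked qvm kernel now produces pack_factor * tn = 32 output columns and two simdgroups share a threadgroup, which is why bo moves from 8 to 64, group_dims becomes (bd, 2, 1), and the grid x dimension is O / bo (the ceil-division is dropped, presumably because the supported output sizes are multiples of 64). The standalone C++ sketch below only replays that arithmetic with the kernel's constant names; it does not call any MLX or Metal API.

// Hypothetical sanity check of the launch geometry implied by the patch.
// Assumes SIMD_SIZE = 32 and the bit widths the quantized kernels cover (2, 4, 8).
#include <cassert>
#include <cstdio>

int main() {
  const int bd = 32;             // SIMD_SIZE, threads per simdgroup
  const int num_simdgroups = 2;  // simdgroups per threadgroup (kernel constant)
  const int widths[] = {2, 4, 8};
  for (int bits : widths) {
    int pack_factor = 32 / bits;                  // values packed in one uint32_t
    int tn = 32 / pack_factor;                    // uint32_t words each thread loads
    int cols_per_simdgroup = pack_factor * tn;    // output columns per simdgroup, always 32
    int bo = num_simdgroups * cols_per_simdgroup; // output columns per threadgroup, always 64
    assert(cols_per_simdgroup == 32 && bo == 64);
    std::printf(
        "bits=%d pack_factor=%d tn=%d group_dims=(%d,%d,1) cols/threadgroup=%d\n",
        bits, pack_factor, tn, bd, num_simdgroups, bo);
  }
  return 0;
}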