From d7ed62450200dc0378cc6dc8ad9580453867c6f6 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Sat, 14 Dec 2024 15:08:24 -0800
Subject: [PATCH] Vectorized reads

---
 mlx/backend/metal/kernels/quantized.h | 89 ++++++++++++++++++---------
 mlx/ops.cpp                           |  5 +-
 2 files changed, 63 insertions(+), 31 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index 149f49e35..88700f645 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2150,10 +2150,41 @@ template
   }
 }
 
+template <typename U, int bits>
+inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+
+  vec<U, 4> accum = 0;
+
+  if (bits == 4) {
+    for (int i = 0; i < 4; i++) {
+      auto ws = as_type<vec<uint16_t, 2>>(w[i]);
+      for (int j = 0; j < 2; j++) {
+        accum[i] +=
+            (x[4 * j + 0] * (ws[j] & 0x000f) + x[4 * j + 1] * (ws[j] & 0x00f0) +
+             x[4 * j + 2] * (ws[j] & 0x0f00) + x[4 * j + 3] * (ws[j] & 0xf000));
+      }
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < 4; i++) {
+      auto ws = as_type<vec<uint8_t, 4>>(w[i]);
+      for (int j = 0; j < 4; j++) {
+        accum[i] += x[j] * ws[j];
+      }
+    }
+  }
+
+  return accum;
+}
+
 template <typename T, int group_size, int bits>
 METAL_FUNC void affine_packed_qmv_fast_impl(
-    const device uint32_t* w,
-    const device T* scales,
+    const device vec<uint32_t, 4>* w,
+    const device vec<T, 4>* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -2162,7 +2193,7 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int packs_per_thread = 1;
+  constexpr int packs_per_thread = 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
@@ -2171,48 +2202,50 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;
 
-  const device uint8_t* ws = (const device uint8_t*)w;
-
   typedef float U;
   thread U x_thread[values_per_thread];
-  thread U result[results_per_simdgroup] = {0};
+  vec<U, 4> result = 0;
 
   // Adjust positions
-  const int in_vec_size_w =
-      in_vec_size * results_per_simdgroup * bytes_per_pack / pack_factor;
-  const int in_vec_size_g =
-      in_vec_size * results_per_simdgroup * 2 / group_size;
+  const int in_vec_size_w = in_vec_size / pack_factor;
+  const int in_vec_size_g = in_vec_size * 2 / group_size;
   const int w_row = tid.x * num_simdgroups + simd_gid;
   const int out_row = w_row * results_per_simdgroup;
 
-  ws += w_row * in_vec_size_w +
-      simd_lid * results_per_simdgroup * packs_per_thread * bytes_per_pack;
-  scales += w_row * in_vec_size_g +
-      results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
+  w += w_row * in_vec_size_w + simd_lid * packs_per_thread;
+  scales += w_row * in_vec_size_g + 2 * (simd_lid / scale_step_per_thread);
   x += tid.y * in_vec_size + simd_lid * values_per_thread;
   y += tid.y * out_vec_size + out_row;
 
   for (int k = 0; k < in_vec_size; k += block_size) {
+    // Load the input vector
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
-    for (int row = 0; row < results_per_simdgroup; row++) {
-      result[row] += qdot<U, values_per_thread, bits>(
-          ws + row * bytes_per_pack,
-          x_thread,
-          scales[2 * row + 0],
-          scales[2 * row + 1],
-          sum);
+    // Load the scales and biases
+    vec<T, 4> s = scales[0];
+    vec<T, 4> b = scales[1];
+
+    // Load the weights and perform the partial dot product
+    vec<U, 4> accum = 0;
+    for (int pack = 0; pack < packs_per_thread; pack++) {
+      accum +=
+          partial_qdot_vec<U, bits>(x_thread + pack * pack_factor, w[pack]);
     }
 
-    ws += results_per_simdgroup * block_size * bytes_per_pack / pack_factor;
-    scales += block_size * 2 * results_per_simdgroup / group_size;
+    // Finalize the dot product and accumulate it
+    for (int i = 0; i < 4; i++) {
+      result[i] += static_cast<U>(s[i]) * accum[i] + static_cast<U>(b[i]) * sum;
+    }
+
+    w += block_size / pack_factor;
+    scales += 2 * block_size / group_size;
     x += block_size;
   }
 
-  for (int row = 0; row < results_per_simdgroup; row++) {
-    result[row] = simd_sum(result[row]);
-    if (simd_lid == 0) {
+  result = simd_sum(result);
+  if (simd_lid == 0) {
+    for (int row = 0; row < results_per_simdgroup; row++) {
       y[row] = static_cast<T>(result[row]);
     }
   }
@@ -2220,8 +2253,8 @@ template
 [[kernel]] void affine_packed_qmv_fast(
-    const device uint32_t* w [[buffer(0)]],
-    const device T* scales [[buffer(1)]],
+    const device vec<uint32_t, 4>* w [[buffer(0)]],
+    const device vec<T, 4>* scales [[buffer(1)]],
     const device T* x [[buffer(2)]],
     device T* y [[buffer(3)]],
     const constant int& in_vec_size [[buffer(5)]],
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 3577c33de..7f6d296d1 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3789,10 +3789,9 @@ std::tuple<array, array, array> quantize(
     case QuantizationType::AffinePacked: {
      auto [wq, scales, biases] =
           fast::affine_quantize(w, group_size, bits, s);
-      scales = unflatten(scales, -2, {-1, 4, 1}, s);
-      biases = unflatten(biases, -2, {-1, 4, 1}, s);
+      scales = unflatten(scales, -2, {-1, 4}, s);
+      biases = unflatten(biases, -2, {-1, 4}, s);
       scales = concatenate({scales, biases}, -2, s);
-      scales = flatten(scales, -3, -2, s);
       scales = moveaxis(scales, -2, -1, s);
       scales = flatten(scales, -2, -1, s);
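
Note (not part of the patch): the ops.cpp change packs the scales and biases so that, for each block of 4 output rows and each quantization group, the kernel sees the interleaved layout [s0 s1 s2 s3 b0 b1 b2 b3], which it reads as scales[0] and scales[1]. The sketch below is a host-side, scalar reference of the per-row math the kernel accumulates (result[i] += s[i] * accum[i] + b[i] * sum), using unpacked codes and ignoring how the codes are packed into uint32 words; the file name, function name, and variable names are illustrative only and do not exist in MLX.

// qmv_reference.cpp - scalar sketch for one block of 4 output rows and one
// quantization group. Affine dequantization gives w = s * q + b, so
//   y[r] = sum_k x[k] * (s[r] * q[r][k] + b[r])
//        = s[r] * sum_k x[k] * q[r][k] + b[r] * sum_k x[k].
#include <cstdio>
#include <vector>

void qmv_reference_group(
    const std::vector<std::vector<int>>& q, // 4 rows of unpacked codes
    const float* sb, // interleaved [s0 s1 s2 s3 b0 b1 b2 b3] for this group
    const std::vector<float>& x, // one quantization group of inputs
    float* y) {
  // Shared term: sum over the inputs, multiplied by each row's bias.
  float sum = 0.0f;
  for (float v : x) {
    sum += v;
  }
  for (int r = 0; r < 4; r++) {
    // Partial dot product of the inputs with the row's quantized codes.
    float accum = 0.0f;
    for (int k = 0; k < (int)x.size(); k++) {
      accum += x[k] * q[r][k];
    }
    y[r] = sb[r] * accum + sb[4 + r] * sum; // scale * dot + bias * sum(x)
  }
}

int main() {
  // 4 rows x 4 inputs of 4-bit codes, one scale/bias pair per row.
  std::vector<std::vector<int>> q = {
      {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 0}};
  float sb[8] = {0.5f, 0.25f, 1.0f, 2.0f, -1.0f, 0.0f, 0.5f, -0.5f};
  std::vector<float> x = {1.0f, -1.0f, 2.0f, 0.5f};
  float y[4];
  qmv_reference_group(q, sb, x, y);
  for (int r = 0; r < 4; r++) {
    std::printf("y[%d] = %f\n", r, y[r]);
  }
  return 0;
}

As the patch title suggests, storing the weights as vec<uint32_t, 4> and the scale/bias pairs as vec<T, 4> lets each thread fetch the data for all four rows of a block with single vectorized loads instead of four separate scalar reads.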