From e4b587819c3c75b0fb453274f193db0717e10946 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 13 Dec 2024 18:36:36 -0800
Subject: [PATCH] Attempt different packing

---
 benchmarks/python/packed_qmv_bench.py |  4 ++--
 mlx/backend/metal/kernels/quantized.h | 20 +++++++-------------
 mlx/backend/metal/quantized.cpp       | 11 +++++++++++
 mlx/ops.cpp                           | 23 +++++++++++++----------
 4 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/benchmarks/python/packed_qmv_bench.py b/benchmarks/python/packed_qmv_bench.py
index 3949bd5e5..aa55fdb77 100644
--- a/benchmarks/python/packed_qmv_bench.py
+++ b/benchmarks/python/packed_qmv_bench.py
@@ -5,11 +5,11 @@ from functools import partial
 import mlx.core as mx
 from time_utils import time_fn
 
-D = 16384
+D = 8192
 group_size = 64
 bits = 3
 dtype = mx.float16
-loops = 10
+loops = 100
 
 
 def qmv_(x, wq, q_type):
diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index e4f821b69..f2b3840bb 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2180,34 +2180,28 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
-  const int in_vec_size_g =
-      in_vec_size * results_per_simdgroup * 2 / group_size;
-  const int scales_row = tid.x * num_simdgroups + simd_gid;
-  const int out_row = scales_row * results_per_simdgroup;
+  const int out_vec_size_g = 2 * out_vec_size;
+  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
+      simd_gid * results_per_simdgroup;
+  const int scales_blocksize = (block_size / group_size) * out_vec_size_g;
 
   ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
-  scales += scales_row * in_vec_size_g +
-      results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
+  scales += (simd_lid / scale_step_per_thread) * out_vec_size_g + 2 * out_row;
   x += tid.y * in_vec_size + simd_lid * values_per_thread;
   y += tid.y * out_vec_size + out_row;
 
   for (int k = 0; k < in_vec_size; k += block_size) {
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
-    U sb[2 * results_per_simdgroup];
-    for (int i = 0; i < 2 * results_per_simdgroup; i++) {
-      sb[i] = scales[i];
-    }
-
     for (int row = 0; row < results_per_simdgroup; row++) {
       auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       result[row] += qdot<U, values_per_thread, bits>(
-          wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
+          wl, x_thread, scales[2 * row + 0], scales[2 * row + 1], sum);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
-    scales += block_size * 2 * results_per_simdgroup / group_size;
     x += block_size;
+    scales += scales_blocksize;
   }
 
   for (int row = 0; row < results_per_simdgroup; row++) {
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index bf60d17d1..c2edbb5e7 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -433,6 +433,16 @@ void affine_packed_qmv(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+void affine_packed_qmm(
+    const std::vector<array>& inputs,
+    array& out,
+    int B,
+    int D,
+    int O,
+    int group_size,
+    int bits,
+    const Stream& s) {}
+
 void affine_packed_qmm_op(
     const std::vector<array>& inputs,
     array& out,
@@ -451,6 +461,7 @@ void affine_packed_qmm_op(
     if (B < 6) {
       affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
     } else {
+      affine_packed_qmm(inputs, out, B, D, O, group_size, bits, s);
     }
   } else {
   }
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 3fc8fe70c..2f7e79434 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -128,9 +128,14 @@ std::pair<int, int> extract_quantized_matmul_dims(
   }
 
   int weight_dims = w.shape(-1) * 32 / bits;
-  int scales_dims = scales.shape(-1) * group_size;
-  if (type == QuantizationType::AffinePacked) {
-    scales_dims /= 8;
+  int scales_dims;
+  switch (type) {
+    case QuantizationType::Affine:
+      scales_dims = scales.shape(-1) * group_size;
+      break;
+    case QuantizationType::AffinePacked:
+      scales_dims = scales.shape(-2) * group_size;
+      break;
   }
 
   if (weight_dims != scales_dims) {
@@ -3782,14 +3787,12 @@ std::tuple<array, array, std::optional<array>> quantize(
 
   // Pack scales and biases
   if (type == QuantizationType::AffinePacked) {
-    scales = unflatten(scales, -2, {-1, 4, 1}, s);
-    biases = unflatten(biases, -2, {-1, 4, 1}, s);
-    scales = concatenate({scales, biases}, -2, s);
-    scales = flatten(scales, -3, -2, s);
-    scales = moveaxis(scales, -2, -1, s);
-    scales = flatten(scales, -2, -1, s);
+    array packed_scales_biases =
+        flatten(stack({scales, biases}, -2, s), -3, -2, s);
+    packed_scales_biases =
+        contiguous(moveaxis(packed_scales_biases, -2, -1, s), false, s);
 
-    return std::make_tuple(wq, scales, std::nullopt);
+    return std::make_tuple(wq, packed_scales_biases, std::nullopt);
   } else {
     return std::make_tuple(wq, scales, biases);
   }
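
Note: the new AffinePacked scale/bias layout produced by the quantize()
hunk above can be sketched in Python. This is an illustrative sketch, not
part of the patch: the helper name pack_scales_biases is hypothetical, and
it assumes scales and biases of shape (O, D // group_size).

    import mlx.core as mx

    def pack_scales_biases(scales, biases):
        # scales, biases: (O, G) with G = D // group_size groups per row.
        packed = mx.stack([scales, biases], axis=-2)  # (O, 2, G)
        # Interleave rows as scale0, bias0, scale1, bias1, ...
        packed = mx.flatten(packed, -3, -2)           # (2 * O, G)
        # Transpose to group-major so each group holds 2 * O contiguous
        # values, matching the kernel offset 2 * out_row and the per-block
        # stride scales_blocksize = (block_size / group_size) * 2 * out_vec_size.
        return mx.moveaxis(packed, -2, -1)            # (G, 2 * O)

With the benchmark's D = 8192 and group_size = 64 this yields a
(128, 2 * O) array: the scale and bias for output row r of group g sit at
packed[g, 2 * r] and packed[g, 2 * r + 1], which is what the updated
affine_packed_qmv_fast_impl indexing reads.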