From cb358dbddab2a457288d8bd5419554e936139c29 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 13 Dec 2024 23:23:21 -0800
Subject: [PATCH] Revert "Attempt different packing"

This reverts commit e4b587819c3c75b0fb453274f193db0717e10946.
---
 benchmarks/python/packed_qmv_bench.py |  2 +-
 mlx/backend/metal/kernels/quantized.h | 20 +++++++++++++-------
 mlx/backend/metal/quantized.cpp       | 11 -----------
 mlx/ops.cpp                           | 23 ++++++++++-------------
 4 files changed, 24 insertions(+), 32 deletions(-)

diff --git a/benchmarks/python/packed_qmv_bench.py b/benchmarks/python/packed_qmv_bench.py
index aa55fdb77..0848227a3 100644
--- a/benchmarks/python/packed_qmv_bench.py
+++ b/benchmarks/python/packed_qmv_bench.py
@@ -7,7 +7,7 @@ from time_utils import time_fn
 
 D = 8192
 group_size = 64
-bits = 3
+bits = 4
 dtype = mx.float16
 loops = 100
 
diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index f2b3840bb..e4f821b69 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2180,28 +2180,34 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
-  const int out_vec_size_g = 2 * out_vec_size;
-  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
-      simd_gid * results_per_simdgroup;
-  const int scales_blocksize = (block_size / group_size) * out_vec_size_g;
+  const int in_vec_size_g =
+      in_vec_size * results_per_simdgroup * 2 / group_size;
+  const int scales_row = tid.x * num_simdgroups + simd_gid;
+  const int out_row = scales_row * results_per_simdgroup;
   ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
-  scales += (simd_lid / scale_step_per_thread) * out_vec_size_g + 2 * out_row;
+  scales += scales_row * in_vec_size_g +
+      results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
   x += tid.y * in_vec_size + simd_lid * values_per_thread;
   y += tid.y * out_vec_size + out_row;
 
   for (int k = 0; k < in_vec_size; k += block_size) {
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
+    U sb[2 * results_per_simdgroup];
+    for (int i = 0; i < 2 * results_per_simdgroup; i++) {
+      sb[i] = scales[i];
+    }
+
     for (int row = 0; row < results_per_simdgroup; row++) {
       auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       result[row] += qdot<U, values_per_thread, bits>(
-          wl, x_thread, scales[2 * row + 0], scales[2 * row + 1], sum);
+          wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
+    scales += block_size * 2 * results_per_simdgroup / group_size;
     x += block_size;
-    scales += scales_blocksize;
   }
 
   for (int row = 0; row < results_per_simdgroup; row++) {
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index c2edbb5e7..bf60d17d1 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -433,16 +433,6 @@ void affine_packed_qmv(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
-void affine_packed_qmm(
-    const std::vector<array>& inputs,
-    array& out,
-    int B,
-    int D,
-    int O,
-    int group_size,
-    int bits,
-    const Stream& s) {}
-
 void affine_packed_qmm_op(
     const std::vector<array>& inputs,
     array& out,
@@ -461,7 +451,6 @@ void affine_packed_qmm_op(
     if (B < 6) {
       affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
     } else {
-      affine_packed_qmm(inputs, out, B, D, O, group_size, bits, s);
     }
   } else {
   }
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 2f7e79434..3fc8fe70c 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -128,14 +128,9 @@ std::pair<int, int> extract_quantized_matmul_dims(
   }
 
   int weight_dims = w.shape(-1) * 32 / bits;
-  int scales_dims;
-  switch (type) {
-    case QuantizationType::Affine:
-      scales_dims = scales.shape(-1) * group_size;
-      break;
-    case QuantizationType::AffinePacked:
-      scales_dims = scales.shape(-2) * group_size;
-      break;
+  int scales_dims = scales.shape(-1) * group_size;
+  if (type == QuantizationType::AffinePacked) {
+    scales_dims /= 8;
   }
 
   if (weight_dims != scales_dims) {
@@ -3787,12 +3782,14 @@ std::tuple<array, array, std::optional<array>> quantize(
 
   // Pack scales and biases
   if (type == QuantizationType::AffinePacked) {
-    array packed_scales_biases =
-        flatten(stack({scales, biases}, -2, s), -3, -2, s);
-    packed_scales_biases =
-        contiguous(moveaxis(packed_scales_biases, -2, -1, s), false, s);
+    scales = unflatten(scales, -2, {-1, 4, 1}, s);
+    biases = unflatten(biases, -2, {-1, 4, 1}, s);
+    scales = concatenate({scales, biases}, -2, s);
+    scales = flatten(scales, -3, -2, s);
+    scales = moveaxis(scales, -2, -1, s);
+    scales = flatten(scales, -2, -1, s);
 
-    return std::make_tuple(wq, packed_scales_biases, std::nullopt);
+    return std::make_tuple(wq, scales, std::nullopt);
  } else {
     return std::make_tuple(wq, scales, biases);
   }
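
Note: the quantize() hunk above restores the AffinePacked layout in which the
scales and biases of every 4 consecutive output rows are interleaved per
quantization group. A minimal NumPy sketch of that layout (sizes are
hypothetical, and NumPy reshape/concatenate/moveaxis stand in for the mlx
unflatten/concatenate/flatten/moveaxis calls in the diff):

    import numpy as np

    O, D, group_size = 8, 128, 64  # hypothetical sizes; O must be a multiple of 4
    G = D // group_size            # quantization groups per output row

    scales = np.arange(O * G, dtype=np.float32).reshape(O, G)
    biases = -scales

    # unflatten(-2, {-1, 4, 1}), concatenate(-2), flatten(-3, -2),
    # moveaxis(-2, -1), flatten(-2, -1), as in the restored quantize()
    s = scales.reshape(-1, 4, 1, G)
    b = biases.reshape(-1, 4, 1, G)
    packed = np.concatenate([s, b], axis=-2)  # (O/4, 4, 2, G)
    packed = packed.reshape(-1, 8, G)         # (O/4, 8, G)
    packed = np.moveaxis(packed, -2, -1)      # (O/4, G, 8)
    packed = packed.reshape(O // 4, 8 * G)    # (O/4, 8*G)

    # Row r, group g: the scale sits at column 8*g + 2*(r % 4) and its bias
    # right after it, which is the interleaving the kernel reads back as
    # sb[2 * row + 0] and sb[2 * row + 1].
    r, g = 5, 1
    assert packed[r // 4, 8 * g + 2 * (r % 4)] == scales[r, g]
    assert packed[r // 4, 8 * g + 2 * (r % 4) + 1] == biases[r, g]

    # The shape check restored in extract_quantized_matmul_dims: dividing by
    # 8 recovers D because each packed column holds 8 values (4 rows, each
    # contributing a scale and a bias).
    assert packed.shape[-1] * group_size // 8 == D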
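
Note: the quantized.h hunk walks this layout with per-row offsets instead of
the reverted precomputed block stride, and hoists the 2 * results_per_simdgroup
scale/bias loads into sb[] once per k-block rather than indexing scales twice
per row inside qdot. A short Python sketch of the restored index math; the
constant values below are hypothetical stand-ins for the kernel's template
parameters and thread indices:

    # Hypothetical compile-time parameters of affine_packed_qmv_fast_impl
    results_per_simdgroup = 4   # output rows handled per simdgroup
    num_simdgroups = 2          # simdgroups per threadgroup
    group_size = 64
    in_vec_size = 8192          # D, the reduction dimension

    # Packed scale/bias elements per packed row: 2 * results_per_simdgroup
    # values for each of the in_vec_size / group_size groups.
    in_vec_size_g = in_vec_size * results_per_simdgroup * 2 // group_size

    def scales_offset(tid_x, simd_gid, simd_lid, scale_step_per_thread):
        # Starting element offset into the packed scales for one thread.
        scales_row = tid_x * num_simdgroups + simd_gid
        return (scales_row * in_vec_size_g
                + results_per_simdgroup * 2 * (simd_lid // scale_step_per_thread))

    def scales_advance(block_size):
        # Elements the scales pointer advances per k-block iteration.
        return block_size * 2 * results_per_simdgroup // group_size

    # e.g. the second simdgroup of threadgroup 0 starts one packed row in
    assert scales_offset(0, 1, 0, scale_step_per_thread=8) == in_vec_size_g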