From d75a5092341ccedb25432eda3b1d509e4aab7cf3 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 17 Dec 2024 10:08:47 -0800
Subject: [PATCH] Add 3bit packed quants

---
 mlx/backend/metal/kernels/quantized.h | 86 +++++++++++++++++++++++++--
 mlx/backend/metal/quantized.cpp       |  5 +-
 mlx/ops.cpp                           | 39 ++++++------
 python/mlx/nn/layers/quantized.py     | 11 +++-
 4 files changed, 112 insertions(+), 29 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index c88f20923..8328ab16a 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2295,6 +2295,68 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   }
 }
 
+template <typename T, int group_size, int bits, int results_per_simdgroup>
+METAL_FUNC void affine_packed_byte_qmv_fast_impl(
+    const device uint8_t* w,
+    const device vec<T, 4>* scales,
+    const device T* x,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int packs_per_thread = 2;
+  constexpr int num_simdgroups = 2;
+  constexpr int pack_factor = (bits == 3) ? 8 : 4;
+  ;
+  constexpr int bytes_per_pack = 3;
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  vec<U, results_per_simdgroup> result = 0;
+
+  // Adjust positions
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_g = in_vec_size / group_size;
+  const int scales_row = tid.x * num_simdgroups + simd_gid;
+  const int out_row = scales_row * results_per_simdgroup;
+
+  w += out_row * in_vec_size_w + simd_lid * (packs_per_thread * bytes_per_pack);
+  scales += scales_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+  x += tid.y * in_vec_size + simd_lid * values_per_thread;
+  y += tid.y * out_vec_size + out_row;
+
+  for (int k = 0; k < in_vec_size; k += block_size) {
+    // Load the input vector
+    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+    // Load the scales and biases
+    vec<T, 4> sb = scales[0];
+
+    // Load the weights and perform the partial dot product
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      result[row] += qdot<U, values_per_thread, bits>(
+          w + row * in_vec_size_w, x_thread, sb[row], sb[2 + row], sum);
+    }
+
+    w += block_size * bytes_per_pack / pack_factor;
+    scales += block_size / group_size;
+    x += block_size;
+  }
+
+  for (int row = 0; row < results_per_simdgroup; row++) {
+    result[row] = simd_sum(result[row]);
+    if (simd_lid == 0) {
+      y[row] = static_cast<T>(result[row]);
+    }
+  }
+}
+
 template
 [[kernel]] void affine_packed_qmv_fast(
     const device vec* w [[buffer(0)]],
@@ -2306,8 +2368,21 @@ template
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  affine_packed_qmv_fast_impl(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+  if (bits & (bits - 1)) {
+    affine_packed_byte_qmv_fast_impl<T, group_size, bits, 2>(
+        (const device uint8_t*)w,
+        scales,
+        x,
+        y,
+        in_vec_size,
+        out_vec_size,
+        tid,
+        simd_gid,
+        simd_lid);
+  } else {
+    affine_packed_qmv_fast_impl(
+        w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+  }
 }
 
 template <
@@ -2617,6 +2692,9 @@ template <
         s_strides,
         tid);
   }
-  affine_packed_qmm_t_impl(
-      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+  if (bits & (bits - 1)) {
+  } else {
+    affine_packed_qmm_t_impl(
+        w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+  }
 }
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 925bf8560..29bad1baf 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -405,10 +405,11 @@ void affine_packed_qmv(
   auto w = ensure_row_contiguous_last_dims(inputs[1]);
   auto scales = ensure_row_contiguous_last_dims(inputs[2]);
 
+  const bool pow2_bits = (bits & (bits - 1)) == 0;
   const int n_simdgroups = 2;
-  const int n_outs_per_simdgroup = 4;
+  const int results_per_simdgroup = (pow2_bits) ? 4 : 2;
   MTL::Size group_dims(32, n_simdgroups, 1);
-  MTL::Size grid_dims(O / n_simdgroups / n_outs_per_simdgroup, B, 1);
+  MTL::Size grid_dims(O / n_simdgroups / results_per_simdgroup, B, 1);
 
   std::string name;
   name.reserve(64);
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 14a7cdb86..db409536b 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -99,15 +99,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
             << "biases but biases were provided";
         throw std::invalid_argument(msg.str());
       }
-      if (bits & (bits - 1)) {
-        std::ostringstream msg;
-        msg << "[" << tag << "] Quantization type '" << quantization_type
-            << "' does not support " << bits << " bits.";
-        throw std::invalid_argument(msg.str());
-      }
       break;
   }
 
+  bool pow2_bits = (bits & (bits - 1)) == 0;
+
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -136,8 +132,12 @@ std::pair<int, int> extract_quantized_matmul_dims(
   int weight_dims = w.shape(-1) * 32 / bits;
   int scales_dims = scales.shape(-1) * group_size;
   if (quantization_type == QuantizationType::AffinePacked) {
-    scales_dims /= 8;
-    weight_dims /= 4;
+    if (pow2_bits) {
+      scales_dims /= 8;
+      weight_dims /= 4;
+    } else {
+      scales_dims /= 4;
+    }
   }
 
   if (weight_dims != scales_dims) {
@@ -155,7 +155,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
 
   // Calculate the expanded w's dims
   int weight_dims_other = w.shape(-2);
-  if (quantization_type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked && pow2_bits) {
     weight_dims_other *= 4;
   }
   int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
@@ -3793,23 +3793,22 @@ std::tuple<array, array, std::optional<array>> quantize(
     case QuantizationType::Affine:
       return fast::affine_quantize(w, group_size, bits, s);
     case QuantizationType::AffinePacked: {
-      if (bits & (bits - 1)) {
-        std::ostringstream msg;
-        msg << "[quantize] Quantization type '" << quantization_type
-            << "' does not support " << bits << " bits.";
-        throw std::invalid_argument(msg.str());
-      }
       auto [wq, scales, biases] =
           fast::affine_quantize(w, group_size, bits, s);
-      scales = unflatten(scales, -2, {-1, 4}, s);
-      biases = unflatten(biases, -2, {-1, 4}, s);
+      int pow2_bits = (bits & (bits - 1)) == 0;
+      int row_packing = (pow2_bits) ? 4 : 2;
+
+      scales = unflatten(scales, -2, {-1, row_packing}, s);
+      biases = unflatten(biases, -2, {-1, row_packing}, s);
       scales = concatenate({scales, biases}, -2, s);
       scales = moveaxis(scales, -2, -1, s);
       scales = flatten(scales, -2, -1, s);
 
-      wq = unflatten(wq, -2, {-1, 4}, s);
-      wq = moveaxis(wq, -2, -1, s);
-      wq = flatten(wq, -2, -1, s);
+      if (pow2_bits) {
+        wq = unflatten(wq, -2, {-1, row_packing}, s);
+        wq = moveaxis(wq, -2, -1, s);
+        wq = flatten(wq, -2, -1, s);
+      }
 
       return std::make_tuple(wq, scales, std::nullopt);
     }
diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index 76d30b1de..ee5e28a4e 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -181,12 +181,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        quantization_type: str = "affine",
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.quantization_type = quantization_type
 
         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -195,7 +197,9 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, self.biases = mx.quantize(
+            weight, group_size, bits, quantization_type=quantization_type
+        )
 
         # And bias if needed
         if bias:
@@ -223,10 +227,11 @@ class QuantizedLinear(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases", None),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            quantization_type=self.quantization_type,
         )
         if "bias" in self:
             x = x + self["bias"]
@@ -242,7 +247,7 @@ class QuantizedLinear(Module):
     ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
+        ql = cls(input_dims, output_dims, False, group_size, bits, quantization_type)
         ql.weight, ql.scales, ql.biases = mx.quantize(
             linear_layer.weight, group_size, bits, quantization_type=quantization_type
         )
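
Note: below is a rough sketch, in plain Python and independent of the MLX sources above, of the arithmetic the 3-bit path relies on. With bits == 3 the kernel uses pack_factor = 8 and bytes_per_pack = 3, so eight 3-bit values occupy one 24-bit pack and a weight row of in_vec_size values takes in_vec_size * 3 / 8 bytes, while quantize() stores scales and biases for row_packing = 2 rows as [s_row0, s_row1, b_row0, b_row1] per group, which the kernel reads back as sb[row] and sb[2 + row]. The bit order inside a pack shown here is an assumption for illustration only; the authoritative layout is whatever qdot implements for 3 bits.

# Illustrative sketch only -- not the MLX implementation.

def pack3(values):
    """Pack eight 3-bit values (0..7) into one 3-byte (24-bit) pack."""
    assert len(values) == 8 and all(0 <= v < 8 for v in values)
    word = 0
    for i, v in enumerate(values):
        word |= v << (3 * i)  # assumed little-endian bit order within the pack
    return bytes((word & 0xFF, (word >> 8) & 0xFF, (word >> 16) & 0xFF))

def unpack3(packed):
    """Inverse of pack3: one 3-byte pack -> eight 3-bit values."""
    word = packed[0] | (packed[1] << 8) | (packed[2] << 16)
    return [(word >> (3 * i)) & 0x7 for i in range(8)]

def pack_scales(scales, biases):
    """Interleave scales/biases for a pair of rows (row_packing = 2):
    per quantization group the layout is [s_row0, s_row1, b_row0, b_row1]."""
    out = []
    for g in range(len(scales[0])):
        out += [scales[0][g], scales[1][g], biases[0][g], biases[1][g]]
    return out

if __name__ == "__main__":
    vals = [3, 7, 0, 5, 1, 6, 2, 4]
    assert unpack3(pack3(vals)) == vals
    # Bytes per weight row: in_vec_size * bytes_per_pack / pack_factor
    assert 4096 * 3 // 8 == 1536
    assert pack_scales([[1.0, 2.0], [3.0, 4.0]], [[0.1, 0.2], [0.3, 0.4]]) == [
        1.0, 3.0, 0.1, 0.3, 2.0, 4.0, 0.2, 0.4,
    ]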