From bdd68bd893e35ba95775d6742f90cc8f25654c58 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 12 Dec 2024 01:30:38 -0800
Subject: [PATCH] Add a quantization type in the ops

---
 mlx/ops.cpp        | 136 +++++++++++++++++++++++++++++++++------------
 mlx/ops.h          |  18 +++---
 mlx/primitives.cpp |   4 +-
 mlx/primitives.h   |  13 ++++-
 mlx/utils.cpp      |  13 +++++
 mlx/utils.h        |  12 ++++
 6 files changed, 152 insertions(+), 44 deletions(-)

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index a0a259580..5d01981a9 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -75,10 +75,33 @@ std::pair<int, int> extract_quantized_matmul_dims(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     bool transpose,
     int group_size,
-    int bits) {
+    int bits,
+    QuantizationType type) {
+  // Check if we have biases as expected
+  switch (type) {
+    case QuantizationType::Affine:
+      if (!biases.has_value()) {
+        std::ostringstream msg;
+        msg << "[" << tag
+            << "] The biases argument is required for quantization "
+            << "type '" << type << "'";
+        throw std::invalid_argument(msg.str());
+      }
+      break;
+    case QuantizationType::AffinePacked:
+      if (biases.has_value()) {
+        std::ostringstream msg;
+        msg << "[" << tag << "] Quantization type '" << type
+            << "' does not use "
+            << "biases but biases were provided";
+        throw std::invalid_argument(msg.str());
+      }
+      break;
+  }
+
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -86,11 +109,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
     throw std::invalid_argument(msg.str());
   }

-  if (scales.shape() != biases.shape()) {
+  if (biases.has_value() && scales.shape() != biases.value().shape()) {
     std::ostringstream msg;
     msg << "[" << tag << "] Scales and biases should have the same shape. "
         << "Received scales with shape " << scales.shape()
-        << " and biases with " << biases.shape();
+        << " and biases with " << biases.value().shape();
     throw std::invalid_argument(msg.str());
   }

@@ -99,25 +122,33 @@ std::pair<int, int> extract_quantized_matmul_dims(
     std::ostringstream msg;
     msg << "[" << tag
         << "] Weight, scales and biases should have the same batch shape. "
-        << "Received weight with shape " << w.shape() << ", scales with "
-        << scales.shape() << " and biases with " << biases.shape();
+        << "Received weight with shape " << w.shape()
+        << " and scales/biases with " << scales.shape();
     throw std::invalid_argument(msg.str());
   }

-  if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
+  int weight_dims = w.shape(-1) * 32 / bits;
+  int scales_dims = scales.shape(-1) * group_size;
+  if (type == QuantizationType::AffinePacked) {
+    scales_dims /= 8;
+  }
+
+  if (weight_dims != scales_dims) {
     std::ostringstream msg;
     msg << "[" << tag << "] The shapes of the weight and scales are "
-        << "incompatible based on bits and group_size. w.shape() == "
-        << w.shape() << " and scales.shape() == " << scales.shape()
-        << " with group_size=" << group_size << " and bits=" << bits;
+        << "incompatible based on bits, group_size and quantization type. "
+        << "w.shape() == " << w.shape()
+        << " and scales.shape() == " << scales.shape()
+        << " with group_size=" << group_size << ", bits=" << bits
+        << " and type='" << type << "'";
     throw std::invalid_argument(msg.str());
   }

   int x_inner_dims = x.shape(-1);

   // Calculate the expanded w's dims
-  int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);
-  int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;
+  int w_inner_dims = (transpose) ? weight_dims : w.shape(-2);
+  int w_outer_dims = (transpose) ? w.shape(-2) : weight_dims;

   if (w_inner_dims != x_inner_dims) {
     std::ostringstream msg;
@@ -3662,14 +3693,23 @@ array quantized_matmul(
     array x,
     array w,
     array scales,
-    array biases,
+    std::optional<array> biases,
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
+      "quantized_matmul",
+      x,
+      w,
+      scales,
+      biases,
+      transpose,
+      group_size,
+      bits,
+      type);

   // QuantizedMatmul handles w.ndim == 2 case.
   if (x.ndim() > 2 && w.ndim() > 2) {
@@ -3690,37 +3730,53 @@ array quantized_matmul(
     *(inner_shape.end() - 1) = scales.shape(-1);
     scales = broadcast_to(scales, inner_shape, s);

-    *(inner_shape.end() - 1) = biases.shape(-1);
-    biases = broadcast_to(biases, inner_shape, s);
+    if (biases.has_value()) {
+      *(inner_shape.end() - 1) = biases.value().shape(-1);
+      biases = broadcast_to(biases.value(), inner_shape, s);
+    }
   }

-  auto dtype = result_type(x, scales, biases);
+  auto dtype = result_type(x, scales);
+  if (biases.has_value()) {
+    dtype = promote_types(dtype, biases.value().dtype());
+  }
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
     msg << "[quantized_matmul] Only real floating types are supported but "
         << "the passed types where x.dtype() == " << x.dtype()
-        << ", scales.dtype() == " << scales.dtype()
-        << " and biases.dtype() == " << biases.dtype();
+        << ", scales.dtype() == " << scales.dtype();
+    if (biases.has_value()) {
+      msg << " and biases.dtype() == " << biases.value().dtype();
+    }
     throw std::invalid_argument(msg.str());
   }

+  // Prepare the inputs vector
+  std::vector<array> inputs;
+  inputs.reserve(4);
+  inputs.push_back(astype(x, dtype, s));
+  inputs.push_back(w);
+  inputs.push_back(astype(scales, dtype, s));
+  if (biases.has_value()) {
+    inputs.push_back(astype(biases.value(), dtype, s));
+  }
+
   auto out_shape = x.shape();
   out_shape.back() = w_outer_dims;
+
   return array(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), group_size, bits, transpose),
-      {astype(x, dtype, s),
-       w,
-       astype(scales, dtype, s),
-       astype(biases, dtype, s)});
+          to_stream(s), type, group_size, bits, transpose),
+      std::move(inputs));
 }

-std::tuple<array, array, array> quantize(
+std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   return fast::affine_quantize(w, group_size, bits, s);
 }
@@ -3728,31 +3784,40 @@ std::tuple<array, array, array> quantize(
 array dequantize(
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
-  return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
+  return fast::affine_dequantize(
+      w, scales, biases.value(), group_size, bits, s);
 }

 array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
-    std::optional<array> lhs_indices_ /* = std::nullopt */,
-    std::optional<array> rhs_indices_ /* = std::nullopt */,
+    const std::optional<array>& biases,
+    const std::optional<array>& lhs_indices_ /* = std::nullopt */,
+    const std::optional<array>& rhs_indices_ /* = std::nullopt */,
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
     return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, s);
+        x, w, scales, biases, transpose, group_size, bits, type, s);
+  }
+
+  if (type != QuantizationType::Affine) {
+    std::ostringstream msg;
+    msg << "[gather_qmm] Only quantization type '" << type << "' supported";
+    throw std::invalid_argument(msg.str());
   }

   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
+      "gather_qmm", x, w, scales, biases, transpose, group_size, bits, type);

   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3768,16 +3833,17 @@ array gather_qmm(
   out_shape.push_back(w_outer_dims);

   // and output type
-  auto out_type = result_type(x, scales, biases);
+  auto out_type = result_type(x, scales, biases.value());

   return array(
       std::move(out_shape),
       out_type,
-      std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
+      std::make_shared<GatherQMM>(
+          to_stream(s), type, group_size, bits, transpose),
       {astype(x, out_type, s),
        w,
        astype(scales, out_type, s),
-       astype(biases, out_type, s),
+       astype(biases.value(), out_type, s),
        lhs_indices,
        rhs_indices});
 }
diff --git a/mlx/ops.h b/mlx/ops.h
index 7576774b5..b20482ca7 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1277,31 +1277,34 @@ array conv_transpose3d(
     int groups = 1,
     StreamOrDevice s = {});

-/** Quantized matmul multiplies x with a quantized matrix w*/
+/** Quantized matmul multiplies x with a quantized matrix w */
 array quantized_matmul(
     array x,
     array w,
     array scales,
-    array biases,
+    std::optional<array> biases,
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Quantize a matrix along its last axis */
-std::tuple<array, array, array> quantize(
+std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Dequantize a matrix produced by quantize() */
 array dequantize(
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Compute matrix products with matrix-level gather. */
@@ -1309,12 +1312,13 @@ array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
-    std::optional<array> lhs_indices = std::nullopt,
-    std::optional<array> rhs_indices = std::nullopt,
+    const std::optional<array>& biases,
+    const std::optional<array>& lhs_indices = std::nullopt,
+    const std::optional<array>& rhs_indices = std::nullopt,
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Returns a contraction of a and b over multiple dimensions. */
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index fa7b384f5..29fd49827 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2777,10 +2777,11 @@ std::vector<array> QuantizedMatmul::vjp(
           cotangents[0],
           primals[1],
           primals[2],
-          primals[3],
+          (primals.size() > 3) ? std::optional(primals[3]) : std::nullopt,
           !transpose_,
           group_size_,
           bits_,
+          type_,
           stream()));
     }

@@ -2855,6 +2856,7 @@ std::vector<array> GatherQMM::vjp(
                       !transpose_,
                       group_size_,
                       bits_,
+                      type_,
                       stream()),
                   -3,
                   stream()),
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 55a87cf18..13a2dc80d 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -8,6 +8,7 @@
 #include "mlx/device.h"
 #include "mlx/io/load.h"
 #include "mlx/stream.h"
+#include "mlx/utils.h"

 #define DEFINE_VMAP()                                                  \
   virtual std::pair<std::vector<array>, std::vector<int>> vmap(        \
@@ -1568,10 +1569,12 @@ class QuantizedMatmul : public UnaryPrimitive {
  public:
   explicit QuantizedMatmul(
       Stream stream,
+      QuantizationType type,
       int group_size,
       int bits,
       bool transpose)
       : UnaryPrimitive(stream),
+        type_(type),
         group_size_(group_size),
         bits_(bits),
         transpose_(transpose) {}
@@ -1586,6 +1589,7 @@ class QuantizedMatmul : public UnaryPrimitive {
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;

  private:
+  QuantizationType type_;
   int group_size_;
   int bits_;
   bool transpose_;
@@ -1595,8 +1599,14 @@ class GatherQMM : public UnaryPrimitive {
  public:
-  explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
+  explicit GatherQMM(
+      Stream stream,
+      QuantizationType type,
+      int group_size,
+      int bits,
+      bool transpose)
       : UnaryPrimitive(stream),
+        type_(type),
         group_size_(group_size),
         bits_(bits),
         transpose_(transpose) {}
@@ -1610,6 +1620,7 @@ class GatherQMM : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;

  private:
+  QuantizationType type_;
   int group_size_;
   int bits_;
   bool transpose_;
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 6d05ad5f8..4718149bd 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -145,6 +145,19 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
   return os;
 }

+std::ostream& operator<<(std::ostream& os, QuantizationType type) {
+  std::string_view quantization_type;
+  switch (type) {
+    case QuantizationType::Affine:
+      quantization_type = "affine";
+      break;
+    case QuantizationType::AffinePacked:
+      quantization_type = "affine-packed";
+      break;
+  }
+  return os << quantization_type;
+}
+
 namespace {

 inline size_t
diff --git a/mlx/utils.h b/mlx/utils.h
index 04f59feaa..05d0540a4 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -100,6 +100,18 @@ inline int next_power_of_2(int n) {
   return pow(2, std::ceil(std::log2(n)));
 }

+/** Enumerate the different quantization types */
+enum class QuantizationType {
+  // Traditional affine quantization with separate scales and biases
+  Affine,
+
+  // The same quantization as affine but with the scales and biases packed in a
+  // single array and contiguously every 4 rows
+  AffinePacked,
+};
+
+std::ostream& operator<<(std::ostream& os, QuantizationType type);
+
 namespace env {

 int get_var(const char* name, int default_value);
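For context, the following is a minimal usage sketch of the API surface this patch introduces; it is not part of the commit. The tensor shapes, the mx namespace alias, random::normal and the final print are illustrative assumptions, while the quantize, quantized_matmul and dequantize signatures and the QuantizationType enum are taken from the diff above.

// Usage sketch only -- not part of the patch. Shapes, the namespace alias and
// random::normal are illustrative; the signatures and QuantizationType come
// from the diff above.
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // A 512x512 weight matrix and a small batch of activations.
  auto w = mx::random::normal({512, 512});
  auto x = mx::random::normal({8, 512});

  // Default affine quantization: quantize() now returns the biases as a
  // std::optional<array>, which is forwarded as-is to quantized_matmul().
  auto [wq, scales, biases] =
      mx::quantize(w, 64, 4, mx::QuantizationType::Affine);
  auto y = mx::quantized_matmul(
      x, wq, scales, biases, /* transpose */ true, 64, 4,
      mx::QuantizationType::Affine);
  mx::eval(y);

  // Round-trip back to floating point; the Affine type still requires biases.
  auto w_hat = mx::dequantize(
      wq, scales, biases, 64, 4, mx::QuantizationType::Affine);
  mx::eval(w_hat);

  // The new enum streams as a readable name via the operator<< added in
  // mlx/utils.cpp ("affine" / "affine-packed").
  std::cout << mx::QuantizationType::AffinePacked << std::endl;
  return 0;
}

Note that with QuantizationType::AffinePacked the checks in extract_quantized_matmul_dims expect the scales and biases to be packed into a single array (the scales dimension check is divided by 8) and reject a separate biases argument, and gather_qmm currently accepts only the Affine type.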