From 410ccdbed5410915b590302f7ea7c2be1c1e330c Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 16 Dec 2024 13:31:34 -0800
Subject: [PATCH] Change the argument name to quantization_type

---
 benchmarks/python/packed_qmv_bench.py | 18 +++++---
 mlx/ops.cpp                           | 61 +++++++++++++++++----------
 mlx/ops.h                             |  8 ++--
 python/src/ops.cpp                    | 41 ++++++++++--------
 4 files changed, 78 insertions(+), 50 deletions(-)

diff --git a/benchmarks/python/packed_qmv_bench.py b/benchmarks/python/packed_qmv_bench.py
index 3fcc82fc5..f6c6a4724 100644
--- a/benchmarks/python/packed_qmv_bench.py
+++ b/benchmarks/python/packed_qmv_bench.py
@@ -20,14 +20,14 @@ def qmv_(x, wq1, wq2, q_type):
         *wq1,
         group_size=group_size,
         bits=bits,
-        type=q_type,
+        quantization_type=q_type,
     )
     x = mx.quantized_matmul(
         x,
         *wq2,
         group_size=group_size,
         bits=bits,
-        type=q_type,
+        quantization_type=q_type,
     )
     return x
 
@@ -44,9 +44,9 @@ def time_qmv():
     mx.random.seed(3)
     x = mx.random.normal(shape=(1, D)).astype(dtype)
     w1 = mx.random.normal(shape=(M, D)).astype(dtype)
-    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine")
+    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
     w2 = mx.random.normal(shape=(D, M)).astype(dtype)
-    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine")
+    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
     mx.eval(x, wq1, wq2)
     time_fn(affine_qmv, x, wq1, wq2)
 
@@ -55,15 +55,19 @@ def time_packed_qmv():
     mx.random.seed(3)
     x = mx.random.normal(shape=(1, D)).astype(dtype)
     w1 = mx.random.normal(shape=(M, D)).astype(dtype)
-    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine-packed")
+    wq1 = mx.quantize(
+        w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
+    )
     w2 = mx.random.normal(shape=(D, M)).astype(dtype)
-    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine-packed")
+    wq2 = mx.quantize(
+        w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
+    )
     mx.eval(x, wq1, wq2)
     time_fn(affine_packed_qmv, x, wq1, wq2)
 
 
 if __name__ == "__main__":
-    for b in [2, 3, 4, 6, 8]:
+    for b in [2, 4, 8]:
         bits = b
         print(f"Bits {bits}:")
         time_qmv()
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 7b41ee90a..14a7cdb86 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -79,29 +79,29 @@ std::pair<int, int> extract_quantized_matmul_dims(
     bool transpose,
     int group_size,
     int bits,
-    QuantizationType type) {
+    QuantizationType quantization_type) {
   // Check if we have biases as expected
-  switch (type) {
+  switch (quantization_type) {
     case QuantizationType::Affine:
       if (!biases.has_value()) {
         std::ostringstream msg;
         msg << "[" << tag
             << "] The biases argument is required for quantization "
-            << "type '" << type << "'";
+            << "type '" << quantization_type << "'";
         throw std::invalid_argument(msg.str());
       }
       break;
     case QuantizationType::AffinePacked:
       if (biases.has_value()) {
         std::ostringstream msg;
-        msg << "[" << tag << "] Quantization type '" << type
+        msg << "[" << tag << "] Quantization type '" << quantization_type
             << "' does not use "
             << "biases but biases were provided";
         throw std::invalid_argument(msg.str());
       }
       if (bits & (bits - 1)) {
         std::ostringstream msg;
-        msg << "[" << tag << "] Quantization type '" << type
+        msg << "[" << tag << "] Quantization type '" << quantization_type
             << "' does not support " << bits << " bits.";
         throw std::invalid_argument(msg.str());
       }
@@ -135,7 +135,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
 
   int weight_dims = w.shape(-1) * 32 / bits;
   int scales_dims = scales.shape(-1) * group_size;
-  if (type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked) {
     scales_dims /= 8;
     weight_dims /= 4;
   }
@@ -147,7 +147,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
         << "w.shape() == " << w.shape()
         << " and scales.shape() == " << scales.shape()
         << " with group_size=" << group_size << ", bits=" << bits
-        << " and type='" << type << "'";
+        << " and type='" << quantization_type << "'";
     throw std::invalid_argument(msg.str());
   }
 
@@ -155,7 +155,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
 
   // Calculate the expanded w's dims
   int weight_dims_other = w.shape(-2);
-  if (type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked) {
     weight_dims_other *= 4;
   }
   int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
@@ -3708,7 +3708,7 @@ array quantized_matmul(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
@@ -3720,7 +3720,7 @@ array quantized_matmul(
       transpose,
       group_size,
       bits,
-      type);
+      quantization_type);
 
   // QuantizedMatmul handles w.ndim == 2 case.
   if (x.ndim() > 2 && w.ndim() > 2) {
@@ -3779,7 +3779,7 @@ array quantized_matmul(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), type, group_size, bits, transpose),
+          to_stream(s), quantization_type, group_size, bits, transpose),
       std::move(inputs));
 }
 
@@ -3787,16 +3787,16 @@ std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
-  switch (type) {
+  switch (quantization_type) {
     case QuantizationType::Affine:
       return fast::affine_quantize(w, group_size, bits, s);
     case QuantizationType::AffinePacked: {
       if (bits & (bits - 1)) {
         std::ostringstream msg;
-        msg << "[quantize] Quantization type '" << type << "' does not support "
-            << bits << " bits.";
+        msg << "[quantize] Quantization type '" << quantization_type
+            << "' does not support " << bits << " bits.";
         throw std::invalid_argument(msg.str());
       }
       auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
@@ -3822,7 +3822,7 @@ array dequantize(
     const std::optional<array>& biases,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   return fast::affine_dequantize(
       w, scales, biases.value(), group_size, bits, s);
@@ -3838,21 +3838,38 @@ array gather_qmm(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
     return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, type, s);
+        x,
+        w,
+        scales,
+        biases,
+        transpose,
+        group_size,
+        bits,
+        quantization_type,
+        s);
   }
 
-  if (type != QuantizationType::Affine) {
+  if (quantization_type != QuantizationType::Affine) {
     std::ostringstream msg;
-    msg << "[gather_qmm] Only quantization type '" << type << "' supported";
+    msg << "[gather_qmm] Only quantization type '" << quantization_type
+        << "' supported";
     throw std::invalid_argument(msg.str());
   }
 
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "gather_qmm", x, w, scales, biases, transpose, group_size, bits, type);
+      "gather_qmm",
+      x,
+      w,
+      scales,
+      biases,
+      transpose,
+      group_size,
+      bits,
+      quantization_type);
 
   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3874,7 +3891,7 @@ array gather_qmm(
       std::move(out_shape),
       out_type,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), type, group_size, bits, transpose),
+          to_stream(s), quantization_type, group_size, bits, transpose),
       {astype(x, out_type, s),
        w,
        astype(scales, out_type, s),
diff --git a/mlx/ops.h b/mlx/ops.h
index b20482ca7..4a011c05a 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1286,7 +1286,7 @@ array quantized_matmul(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Quantize a matrix along its last axis */
@@ -1294,7 +1294,7 @@ std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Dequantize a matrix produced by quantize() */
@@ -1304,7 +1304,7 @@ array dequantize(
     const std::optional<array>& biases,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Compute matrix products with matrix-level gather. */
@@ -1318,7 +1318,7 @@ array gather_qmm(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Returns a contraction of a and b over multiple dimensions. */
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 4332b207b..a2f2831f2 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4025,7 +4025,7 @@ void init_ops(nb::module_& m) {
          bool transpose,
          int group_size,
          int bits,
-         const std::string& type,
+         const std::string& quantization_type,
          mx::StreamOrDevice s) {
         return mx::quantized_matmul(
             std::move(x),
@@ -4035,7 +4035,7 @@ void init_ops(nb::module_& m) {
             transpose,
             group_size,
             bits,
-            mx::from_string(type),
+            mx::from_string(quantization_type),
             s);
       },
       nb::arg(),
@@ -4045,11 +4045,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -4069,7 +4069,7 @@ void init_ops(nb::module_& m) {
               shares a scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
               ``w``. Default: ``4``.
-            type (str, optional): The type of quantization used for the matrix.
+            quantization_type (str, optional): The type of quantization used for the matrix.
               It can be 'affine' or 'affine-packed'.
 
         Returns:
@@ -4080,18 +4080,19 @@ void init_ops(nb::module_& m) {
       [](const mx::array& w,
          int group_size,
          int bits,
-         const std::string& type,
+         const std::string& quantization_type,
          mx::StreamOrDevice s) {
-        return mx::quantize(w, group_size, bits, mx::from_string(type), s);
+        return mx::quantize(
+            w, group_size, bits, mx::from_string(quantization_type), s);
       },
       nb::arg(),
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
+          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.
 
@@ -4133,7 +4134,7 @@ void init_ops(nb::module_& m) {
               scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element of
               ``w`` in the returned quantized matrix. Default: ``4``.
-            type (str, optional): The type of quantization used for the matrix.
+            quantization_type (str, optional): The type of quantization used for the matrix.
               It can be 'affine' or 'affine-packed'.
 
         Returns:
@@ -4152,21 +4153,27 @@ void init_ops(nb::module_& m) {
          const std::optional<mx::array>& biases,
          int group_size,
          int bits,
-         const std::string& type,
+         const std::string& quantization_type,
          mx::StreamOrDevice s) {
         return mx::dequantize(
-            wq, scales, biases, group_size, bits, mx::from_string(type), s);
+            wq,
+            scales,
+            biases,
+            group_size,
+            bits,
+            mx::from_string(quantization_type),
+            s);
       },
       nb::arg(),
       "scales"_a,
       "biases"_a,
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using the provided ``scales`` and
         ``biases`` and the ``group_size`` and ``bits`` configuration.
@@ -4187,7 +4194,7 @@ void init_ops(nb::module_& m) {
               scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
               ``w``. Default: ``4``.
-            type (str, optional): The type of quantization used for the matrix.
+            quantization_type (str, optional): The type of quantization used for the matrix.
               It can be 'affine' or 'affine-packed'.
 
         Returns:
@@ -4205,7 +4212,7 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
@@ -4235,7 +4242,7 @@ void init_ops(nb::module_& m) {
               shares a scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
               ``w``. Default: ``4``.
-            type (str, optional): The type of quantization used for the matrix.
+            quantization_type (str, optional): The type of quantization used for the matrix.
               It can be 'affine' or 'affine-packed'.
 
         Returns: