diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 4718149bd..7a567e8b9 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -158,6 +158,17 @@ std::ostream& operator<<(std::ostream& os, QuantizationType type) { return os << quantization_type; } +QuantizationType from_string(const std::string& s) { + if (s == "affine") { + return QuantizationType::Affine; + } + if (s == "affine-packed") { + return QuantizationType::AffinePacked; + } + + throw std::invalid_argument("Cannot map '" + s + "' to a quantization type"); +} + namespace { inline size_t diff --git a/mlx/utils.h b/mlx/utils.h index 05d0540a4..84e921596 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -111,6 +111,7 @@ enum class QuantizationType { }; std::ostream& operator<<(std::ostream& os, QuantizationType type); +QuantizationType from_string(const std::string& s); namespace env { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index abfbbbc7c..6572f9201 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4018,7 +4018,26 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "quantized_matmul", - &mx::quantized_matmul, + [](mx::array x, + mx::array w, + mx::array scales, + std::optional biases, + bool transpose, + int group_size, + int bits, + const std::string& type, + mx::StreamOrDevice s) { + return mx::quantized_matmul( + std::move(x), + std::move(w), + std::move(scales), + std::move(biases), + transpose, + group_size, + bits, + mx::from_string(type), + s); + }, nb::arg(), nb::arg(), "scales"_a, @@ -4026,10 +4045,11 @@ void init_ops(nb::module_& m) { "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "type"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform the matrix multiplication with the quantized matrix ``w``. The quantization uses one floating point scale and bias per ``group_size`` of @@ -4040,7 +4060,8 @@ void init_ops(nb::module_& m) { x (array): Input array w (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` + biases (array, optional): The biases to use per ``group_size`` + elements of ``w`` depending on the quantization type transpose (bool, optional): Defines whether to multiply with the transposed ``w`` or not, namely whether we are performing ``x @ w.T`` or ``x @ w``. Default: ``True``. @@ -4048,20 +4069,29 @@ void init_ops(nb::module_& m) { shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + type (str, optional): The type of quantization used for the matrix. + It can be 'affine' or 'affine-packed'. Returns: array: The result of the multiplication of ``x`` with ``w``. )pbdoc"); m.def( "quantize", - &mx::quantize, + [](const mx::array& w, + int group_size, + int bits, + const std::string& type, + mx::StreamOrDevice s) { + return mx::quantize(w, group_size, bits, mx::from_string(type), s); + }, nb::arg(), "group_size"_a = 64, "bits"_a = 4, + "type"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), + "def quantize(w: array, /, group_size: int = 64, bits : int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"), R"pbdoc( Quantize the matrix ``w`` using ``bits`` bits per element. @@ -4103,13 +4133,17 @@ void init_ops(nb::module_& m) { scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element of ``w`` in the returned quantized matrix. Default: ``4``. + type (str, optional): The type of quantization used for the matrix. + It can be 'affine' or 'affine-packed'. Returns: tuple: A tuple containing * w_q (array): The quantized version of ``w`` * scales (array): The scale to multiply each element with, namely :math:`s` - * biases (array): The biases to add to each element, namely :math:`\beta` + * biases (array, optional): The biases to add to each element, namely + * :math:`\beta`. Depending on the quantization type this return value + may be None. )pbdoc"); m.def( "dequantize", @@ -4119,10 +4153,11 @@ void init_ops(nb::module_& m) { "biases"_a, "group_size"_a = 64, "bits"_a = 4, + "type"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Dequantize the matrix ``w`` using the provided ``scales`` and ``biases`` and the ``group_size`` and ``bits`` configuration. @@ -4143,6 +4178,8 @@ void init_ops(nb::module_& m) { scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + type (str, optional): The type of quantization used for the matrix. + It can be 'affine' or 'affine-packed'. Returns: array: The dequantized version of ``w`` @@ -4159,10 +4196,11 @@ void init_ops(nb::module_& m) { "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "type"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array], lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4188,6 +4226,8 @@ void init_ops(nb::module_& m) { shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + type (str, optional): The type of quantization used for the matrix. + It can be 'affine' or 'affine-packed'. Returns: array: The result of the multiplication of ``x`` with ``w``