From 8ec8d44ee6b93c185a124f1f1f5307235bc17618 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 15 Aug 2025 17:36:55 -0700 Subject: [PATCH] add mode parameter for quantization --- docs/src/dev/custom_metal_kernels.rst | 2 +- mlx/ops.cpp | 9 ++- mlx/ops.h | 4 ++ mlx/primitives.cpp | 9 ++- mlx/primitives.h | 10 ++- python/mlx/nn/layers/embedding.py | 4 +- python/mlx/nn/layers/linear.py | 4 +- python/mlx/nn/layers/quantized.py | 53 +++++++++++++--- python/src/ops.cpp | 88 ++++++++++++++++----------- 9 files changed, 127 insertions(+), 56 deletions(-) diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index 1febe960a..4c4ce65ae 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -127,7 +127,7 @@ relying on a copy from ``ensure_row_contiguous``: name="myexp_strided", input_names=["inp"], output_names=["out"], - source=source + source=source, ensure_row_contiguous=False, ) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c8583c72f..4b1df9908 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4029,6 +4029,7 @@ array quantized_matmul( bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { // Check and extract the quantized matrix shape against x auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( @@ -4056,7 +4057,7 @@ array quantized_matmul( std::move(out_shape), dtype, std::make_shared<QuantizedMatmul>( - to_stream(s), group_size, bits, transpose), + to_stream(s), group_size, bits, mode, transpose), std::move(inputs)); } @@ -4064,6 +4065,7 @@ std::tuple<array, array, array> quantize( const array& w, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { return fast::affine_quantize(w, group_size, bits, s); } @@ -4074,6 +4076,7 @@ array dequantize( const array& biases, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { return fast::affine_dequantize(w, scales, biases, group_size, bits, s); } @@ -4088,11 +4091,12 @@ array gather_qmm( bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, bool sorted_indices /* = false */, StreamOrDevice s /* = {} */) { if (!lhs_indices_ && !rhs_indices_) { return quantized_matmul( - x, w, scales, biases, transpose, group_size, bits, s); + x, w, scales, biases, transpose, group_size, bits, mode, s); } auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( @@ -4132,6 +4136,7 @@ array gather_qmm( to_stream(s), group_size, bits, + mode, transpose, sorted_indices && !rhs_indices_, sorted_indices && !lhs_indices_), diff --git a/mlx/ops.h b/mlx/ops.h index 596d6d287..37ab9800c 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1326,6 +1326,7 @@ array quantized_matmul( bool transpose = true, int group_size = 64, int bits = 4, + const std::string& mode = "affine", StreamOrDevice s = {}); /** Quantize a matrix along its last axis */ @@ -1333,6 +1334,7 @@ std::tuple<array, array, array> quantize( const array& w, int group_size = 64, int bits = 4, + const std::string& mode = "affine", StreamOrDevice s = {}); /** Dequantize a matrix produced by quantize() */ @@ -1342,6 +1344,7 @@ array dequantize( const array& biases, int group_size = 64, int bits = 4, + const std::string& mode = "affine", StreamOrDevice s = {}); /** Compute matrix products with matrix-level gather.
*/ @@ -1355,6 +1358,7 @@ array gather_qmm( bool transpose = true, int group_size = 64, int bits = 4, + const std::string& mode = "affine", bool sorted_indices = false, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 980a1f7c3..4341235b6 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3234,6 +3234,7 @@ std::vector<array> QuantizedMatmul::vjp( !transpose_, group_size_, bits_, + mode_, stream())); } @@ -3262,6 +3263,7 @@ std::vector<array> QuantizedMatmul::vjp( zeros_like(primals[3], stream()), group_size_, bits_, + mode_, stream()); wq = unflatten(wq, -1, {-1, group_size_}, stream()); vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream())); @@ -3287,13 +3289,14 @@ std::vector<array> QuantizedMatmul::jvp( transpose_, group_size_, bits_, + mode_, stream())}; } bool QuantizedMatmul::is_equivalent(const Primitive& other) const { const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other); return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && - transpose_ == qm_other.transpose_; + mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_; } std::vector<Shape> QuantizedMatmul::output_shapes( @@ -3348,6 +3351,7 @@ std::vector<array> GatherQMM::vjp( !transpose_, group_size_, bits_, + mode_, sorted, stream()); if (sorted && no_broadcast) { @@ -3406,6 +3410,7 @@ std::vector<array> GatherQMM::vjp( zeros_like(biases, stream()), group_size_, bits_, + mode_, stream()), -1, {-1, group_size_}, @@ -3430,7 +3435,7 @@ std::vector<array> GatherQMM::jvp( bool GatherQMM::is_equivalent(const Primitive& other) const { const GatherQMM& qm_other = static_cast<const GatherQMM&>(other); return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && - transpose_ == qm_other.transpose_; + mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_; } std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap( diff --git a/mlx/primitives.h b/mlx/primitives.h index 277e42a0b..15d867653 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1597,10 +1597,12 @@ class QuantizedMatmul : public UnaryPrimitive { Stream stream, int group_size, int bits, + const std::string& mode, bool transpose) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), + mode_(mode), transpose_(transpose) {} void eval_cpu(const std::vector<array>& inputs, array& out) override; @@ -1612,12 +1614,13 @@ class QuantizedMatmul : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; std::vector<Shape> output_shapes(const std::vector<array>& inputs) override; auto state() const { - return std::make_tuple(group_size_, bits_, transpose_); + return std::make_tuple(group_size_, bits_, mode_, transpose_); } private: int group_size_; int bits_; + std::string mode_; bool transpose_; }; @@ -1627,12 +1630,14 @@ class GatherQMM : public UnaryPrimitive { Stream stream, int group_size, int bits, + const std::string& mode, bool transpose, bool left_sorted = false, bool right_sorted = false) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), + mode_(mode), transpose_(transpose), left_sorted_(left_sorted), right_sorted_(right_sorted) {} @@ -1646,12 +1651,13 @@ class GatherQMM : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( - group_size_, bits_, transpose_, left_sorted_, right_sorted_); + group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_); } private: int group_size_; int bits_; + std::string mode_; bool transpose_; bool left_sorted_; bool right_sorted_; }; diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py
index 1e15a59cc..1edf7e3a5 100644 --- a/python/mlx/nn/layers/embedding.py +++ b/python/mlx/nn/layers/embedding.py @@ -39,6 +39,6 @@ class Embedding(Module): """ return x @ self.weight.T - def to_quantized(self, group_size: int = 64, bits: int = 4): + def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"): """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer.""" - return QuantizedEmbedding.from_embedding(self, group_size, bits) + return QuantizedEmbedding.from_embedding(self, group_size, bits, mode) diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 63caa911c..84a4d8327 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -70,9 +70,9 @@ class Linear(Module): x = x @ self["weight"].T return x - def to_quantized(self, group_size: int = 64, bits: int = 4): + def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"): """Return a :obj:`QuantizedLinear` layer that approximates this layer.""" - return QuantizedLinear.from_linear(self, group_size, bits) + return QuantizedLinear.from_linear(self, group_size, bits, mode) class Bilinear(Module): diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 2d6dc0882..5894d7c15 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -12,6 +12,8 @@ def quantize( model: Module, group_size: int = 64, bits: int = 4, + *, + mode: str = "affine", class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None, ): """Quantize the sub-modules of a module according to a predicate. @@ -26,6 +28,8 @@ def quantize( :func:`mlx.core.quantize`). Default: ``64``. bits (int): The number of bits per parameter (see :func:`mlx.core.quantize`). Default: ``4``. + mode (str): The quantization method to use (see + :func:`mlx.core.quantize`). Default: ``"affine"``. class_predicate (Optional[Callable]): A callable which receives the :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a dict of params for `to_quantized` if it should be quantized and @@ -39,7 +43,7 @@ def quantize( if bool_or_params := class_predicate(path, m): if hasattr(m, "to_quantized"): if isinstance(bool_or_params, bool): - return m.to_quantized(group_size=group_size, bits=bits) + return m.to_quantized(group_size=group_size, bits=bits, mode=mode) elif isinstance(bool_or_params, dict): return m.to_quantized(**bool_or_params) else: @@ -72,6 +76,8 @@ class QuantizedEmbedding(Module): weight. See :func:`~mlx.core.quantize`. Default: ``64``. bits (int, optional): The bit width to use for the quantized weight. See :func:`~mlx.core.quantize`. Default: ``4``. + mode (str): The quantization method to use (see + :func:`mlx.core.quantize`). Default: ``"affine"``. 
""" def __init__( @@ -80,17 +86,21 @@ class QuantizedEmbedding(Module): dims: int, group_size: int = 64, bits: int = 4, + mode: str = "affine", ): super().__init__() # Quantization config self.group_size = group_size self.bits = bits + self.mode = mode # Initialize the quantized weight scale = math.sqrt(1 / dims) weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) - self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) + self.weight, self.scales, self.biases = mx.quantize( + weight, group_size, bits, mode=mode + ) self.num_embeddings = num_embeddings self.dims = dims @@ -104,6 +114,7 @@ class QuantizedEmbedding(Module): biases=self["biases"][x], group_size=self.group_size, bits=self.bits, + mode=self.mode, ) def as_linear(self, x): @@ -121,23 +132,31 @@ class QuantizedEmbedding(Module): transpose=True, group_size=self.group_size, bits=self.bits, + mode=self.mode, ) def _extra_repr(self): return ( f"{self.num_embeddings}, {self.dims}, " - f"group_size={self.group_size}, bits={self.bits}" + f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}" ) @classmethod def from_embedding( - cls, embedding_layer: Module, group_size: int = 64, bits: int = 4 + cls, + embedding_layer: Module, + group_size: int = 64, + bits: int = 4, + mode: str = "affine", ): """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" embedding_dims, dims = embedding_layer.weight.shape ql = cls(embedding_dims, dims, group_size, bits) ql.weight, ql.scales, ql.biases = mx.quantize( - embedding_layer.weight, group_size, bits + embedding_layer.weight, + group_size, + bits, + mode=mode, ) return ql @@ -161,6 +180,8 @@ class QuantizedLinear(Module): weight. See :func:`~mlx.core.quantize`. Default: ``64``. bits (int, optional): The bit width to use for the quantized weight. See :func:`~mlx.core.quantize`. Default: ``4``. + mode (str): The quantization method to use (see + :func:`mlx.core.quantize`). Default: ``"affine"``. 
""" def __init__( @@ -170,12 +191,14 @@ class QuantizedLinear(Module): bias: bool = True, group_size: int = 64, bits: int = 4, + mode: str = "affine", ): super().__init__() # Quantization config self.group_size = group_size self.bits = bits + self.mode = mode # Initialize the quantized weight scale = math.sqrt(1 / input_dims) @@ -184,7 +207,9 @@ class QuantizedLinear(Module): high=scale, shape=(output_dims, input_dims), ) - self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) + self.weight, self.scales, self.biases = mx.quantize( + weight, group_size, bits, mode=mode + ) # And bias if needed if bias: @@ -198,7 +223,7 @@ class QuantizedLinear(Module): in_dims *= 32 // self.bits return ( f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " - f"group_size={self.group_size}, bits={self.bits}" + f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}" ) def __call__(self, x): @@ -210,18 +235,28 @@ class QuantizedLinear(Module): transpose=True, group_size=self.group_size, bits=self.bits, + mode=self.mode, ) if "bias" in self: x = x + self["bias"] return x @classmethod - def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): + def from_linear( + cls, + linear_layer: Module, + group_size: int = 64, + bits: int = 4, + mode: str = "affine", + ): """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape ql = cls(input_dims, output_dims, False, group_size, bits) ql.weight, ql.scales, ql.biases = mx.quantize( - linear_layer.weight, group_size, bits + linear_layer.weight, + group_size, + bits, + mode=mode, ) if "bias" in linear_layer: ql.bias = linear_layer.bias diff --git a/python/src/ops.cpp b/python/src/ops.cpp index af64d9dfc..090ec842f 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4153,10 +4153,11 @@ void init_ops(nb::module_& m) { "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform the matrix multiplication with the quantized matrix ``w``. The quantization uses one floating point scale and bias per ``group_size`` of @@ -4175,6 +4176,7 @@ void init_ops(nb::module_& m) { shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + mode (str, optional): The quantization mode. Default: ``"affine"``. Returns: array: The result of the multiplication of ``x`` with ``w``. @@ -4185,10 +4187,11 @@ void init_ops(nb::module_& m) { nb::arg(), "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), + "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), R"pbdoc( Quantize the matrix ``w`` using ``bits`` bits per element. 
@@ -4199,30 +4202,10 @@ void init_ops(nb::module_& m) { .. warning:: - ``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32 + ``quantize`` currently only supports 2D inputs with the second + dimension divisible by ``group_size`` - Formally, for a group of :math:`g` consecutive elements :math:`w_1` to - :math:`w_g` in a row of ``w`` we compute the quantized representation - of each element :math:`\hat{w_i}` as follows - - .. math:: - - \begin{aligned} - \alpha &= \max_i w_i \\ - \beta &= \min_i w_i \\ - s &= \frac{\alpha - \beta}{2^b - 1} \\ - \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right). - \end{aligned} - - After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits - and is packed in an unsigned 32-bit integer from the lower to upper - bits. For instance, for 4-bit quantization we fit 8 elements in an - unsigned 32 bit integer where the 1st element occupies the 4 least - significant bits, the 2nd bits 4-7 etc. - - In order to be able to dequantize the elements of ``w`` we also need to - save :math:`s` and :math:`\beta` which are the returned ``scales`` and - ``biases`` respectively. + The supported quantization modes are described in more detail below. Args: w (array): Matrix to be quantized @@ -4230,6 +4213,7 @@ void init_ops(nb::module_& m) { scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element of ``w`` in the returned quantized matrix. Default: ``4``. + mode (str, optional): The quantization mode. Default: ``"affine"``. Returns: tuple: A tuple containing @@ -4237,6 +4221,31 @@ void init_ops(nb::module_& m) { * w_q (array): The quantized version of ``w`` * scales (array): The scale to multiply each element with, namely :math:`s` * biases (array): The biases to add to each element, namely :math:`\beta` + + Notes: + The currently supported quantization mode is `"affine"`. + Formally, for a group of :math:`g` consecutive elements :math:`w_1` to + :math:`w_g` in a row of ``w`` we compute the quantized representation + of each element :math:`\hat{w_i}` as follows + + .. math:: + + \begin{aligned} + \alpha &= \max_i w_i \\ + \beta &= \min_i w_i \\ + s &= \frac{\alpha - \beta}{2^b - 1} \\ + \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right). + \end{aligned} + + After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits + and is packed in an unsigned 32-bit integer from the lower to upper + bits. For instance, for 4-bit quantization we fit 8 elements in an + unsigned 32 bit integer where the 1st element occupies the 4 least + significant bits, the 2nd bits 4-7 etc. + + In order to be able to dequantize the elements of ``w`` we also need to + save :math:`s` and :math:`\beta` which are the returned ``scales`` and + ``biases`` respectively. )pbdoc"); m.def( "dequantize", @@ -4246,21 +4255,15 @@ void init_ops(nb::module_& m) { "biases"_a, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - Dequantize the matrix ``w`` using the provided ``scales`` and - ``biases`` and the ``group_size`` and ``bits`` configuration. + Dequantize the matrix ``w`` using quantization parameters. 
- Formally, given the notation in :func:`quantize`, we compute - :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and - :math:`\beta` as follows - - .. math:: - - w_i = s \hat{w_i} + \beta + The supported quantization modes are described in more detail below. Args: w (array): Matrix to be quantized @@ -4270,9 +4273,20 @@ void init_ops(nb::module_& m) { scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + mode (str, optional): The quantization mode. Default: ``"affine"``. Returns: array: The dequantized version of ``w`` + + Notes: + The currently supported quantization mode is `"affine"`. + Formally, given the notation in :func:`quantize`, we compute + :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and + :math:`\beta` as follows + + .. math:: + + w_i = s \hat{w_i} + \beta )pbdoc"); m.def( "gather_qmm", @@ -4286,11 +4300,12 @@ void init_ops(nb::module_& m) { "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4316,6 +4331,7 @@ void init_ops(nb::module_& m) { shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + mode (str, optional): The quantization mode. Default: ``"affine"``. sorted_indices (bool, optional): May allow a faster implementation if the passed indices are sorted. Default: ``False``.
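A small usage sketch of the new ``mode`` keyword at the op level, for reviewers. This assumes the patch above is applied; the shapes and values are illustrative only::

    import mlx.core as mx

    # A 2D weight matrix whose last dimension is divisible by group_size.
    w = mx.random.normal(shape=(256, 512))

    # Affine quantization: one scale and one bias per group of 64 elements.
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")

    # Reconstruct an approximation of w from the packed representation.
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4, mode="affine")

    # Matrix multiply directly against the packed weights.
    x = mx.random.normal(shape=(2, 512))
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=64, bits=4, mode="affine"
    )

Since ``"affine"`` is both the default and the only mode this patch supports, passing ``mode`` explicitly is optional; the keyword presumably leaves room for additional quantization methods without further signature changes.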
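The same keyword threads through the ``mlx.nn`` layers changed above. A sketch at the module level, again assuming this patch is applied; the toy model is made up::

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))

    # Swap Linear/Embedding sub-modules for their quantized counterparts;
    # mode is forwarded to each layer's to_quantized().
    nn.quantize(model, group_size=64, bits=4, mode="affine")

    # Or build a quantized layer directly from an existing one.
    qlinear = nn.QuantizedLinear.from_linear(
        nn.Linear(512, 256), group_size=64, bits=4, mode="affine"
    )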