mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops * Add quantize functions to the docs and tests * Add a QuantizedLinear module
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							4912ff3ec2
						
					
				
				
					commit
					57fe918cf8
				
			@@ -3035,4 +3035,101 @@ void init_ops(py::module_& m) {
 | 
			
		||||
        Returns:
 | 
			
		||||
          result (array): The result of the multiplication of ``x`` with ``w``.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "quantize",
 | 
			
		||||
      &quantize,
 | 
			
		||||
      "w"_a,
 | 
			
		||||
      py::pos_only(),
 | 
			
		||||
      "groups"_a = 128,
 | 
			
		||||
      "width"_a = 4,
 | 
			
		||||
      py::kw_only(),
 | 
			
		||||
      "stream"_a = none,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]
 | 
			
		||||
 | 
			
		||||
        Quantize the matrix ``w`` using ``width`` bits per element.
 | 
			
		||||
 | 
			
		||||
        Note, every ``groups`` elements in a row of ``w`` are quantized
 | 
			
		||||
        together. Hence, number of columns of ``w`` should be divisible by
 | 
			
		||||
        ``groups``. In particular, the rows of ``w`` are divided into groups of
 | 
			
		||||
        size ``groups`` which are quantized together.
 | 
			
		||||
 | 
			
		||||
        .. warning::
 | 
			
		||||
 | 
			
		||||
          ``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32
 | 
			
		||||
 | 
			
		||||
        Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
 | 
			
		||||
        :math:`w_g` in a row of ``w`` we compute the quantized representation
 | 
			
		||||
        of each element :math:`\hat{w_i}` as follows
 | 
			
		||||
 | 
			
		||||
        .. math::
 | 
			
		||||
 | 
			
		||||
          \begin{aligned}
 | 
			
		||||
            \alpha &= \max_i w_i \\
 | 
			
		||||
            \beta &= \min_i w_i \\
 | 
			
		||||
            s &= \frac{\alpha - \beta}{2^b - 1} \\
 | 
			
		||||
            \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
 | 
			
		||||
          \end{aligned}
 | 
			
		||||
 | 
			
		||||
        After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
 | 
			
		||||
        and is packed in an unsigned 32-bit integer from the lower to upper
 | 
			
		||||
        bits. For instance, for 4-bit quantization we fit 8 elements in an
 | 
			
		||||
        unsigned 32 bit integer where the 1st element occupies the 4 least
 | 
			
		||||
        significant bits, the 2nd bits 4-7 etc.
 | 
			
		||||
 | 
			
		||||
        In order to be able to dequantize the elements of ``w`` we also need to
 | 
			
		||||
        save :math:`s` and :math:`\beta` which are the returned ``scales`` and
 | 
			
		||||
        ``biases`` respectively.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          w (array): Matrix to be quantized
 | 
			
		||||
          groups (int, optional): The size of the group in ``w`` that shares a
 | 
			
		||||
            scale and bias. (default: 128)
 | 
			
		||||
          width (int, optional): The bitwidth of the elements in ``w``.
 | 
			
		||||
            (default: 4)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
          (tuple): A tuple containing
 | 
			
		||||
 | 
			
		||||
            - w_q (array): The quantized version of ``w``
 | 
			
		||||
            - scales (array): The scale to multiply each element with, namely :math:`s`
 | 
			
		||||
            - biases (array): The biases to add to each element, namely :math:`\beta`
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "dequantize",
 | 
			
		||||
      &dequantize,
 | 
			
		||||
      "w"_a,
 | 
			
		||||
      py::pos_only(),
 | 
			
		||||
      "scales"_a,
 | 
			
		||||
      "biases"_a,
 | 
			
		||||
      "groups"_a = 128,
 | 
			
		||||
      "width"_a = 4,
 | 
			
		||||
      py::kw_only(),
 | 
			
		||||
      "stream"_a = none,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
 | 
			
		||||
 | 
			
		||||
        Dequantize the matrix ``w`` using the provided ``scales`` and
 | 
			
		||||
        ``biases`` and the ``groups`` and ``width`` configuration.
 | 
			
		||||
 | 
			
		||||
        Formally, given the notation in :func:`quantize`, we compute
 | 
			
		||||
        :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
 | 
			
		||||
        :math:`\beta` as follows
 | 
			
		||||
 | 
			
		||||
        .. math::
 | 
			
		||||
 | 
			
		||||
          w_i = s \hat{w_i} - \beta
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          w (array): Matrix to be quantized
 | 
			
		||||
          scales (array): The scales to use per ``groups`` elements of ``w``
 | 
			
		||||
          biases (array): The biases to use per ``groups`` elements of ``w``
 | 
			
		||||
          groups (int, optional): The size of the group in ``w`` that shares a
 | 
			
		||||
            scale and bias. (default: 128)
 | 
			
		||||
          width (int, optional): The bitwidth of the elements in ``w``.
 | 
			
		||||
            (default: 4)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
          result (array): The dequantized version of w
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user