mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Support for quantized matmul with w and w^T (#349)
* Add the metal qvm implementation * Add qmm_n * Add gradient with respect to the input for quantized_matmul
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							d7ac050f4b
						
					
				
				
					commit
					e7f5059fe4
				
			@@ -3072,12 +3072,13 @@ void init_ops(py::module_& m) {
 | 
			
		||||
      py::pos_only(),
 | 
			
		||||
      "scales"_a,
 | 
			
		||||
      "biases"_a,
 | 
			
		||||
      "transpose"_a = true,
 | 
			
		||||
      "group_size"_a = 64,
 | 
			
		||||
      "bits"_a = 4,
 | 
			
		||||
      py::kw_only(),
 | 
			
		||||
      "stream"_a = none,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        quantized_matmul(x: array, w: array, scales: array, biases: array, /, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
 | 
			
		||||
        quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
 | 
			
		||||
 | 
			
		||||
        Perform the matrix multiplication with the quantized matrix ``w``. The
 | 
			
		||||
        quantization uses one floating point scale and bias per ``group_size`` of
 | 
			
		||||
@@ -3089,10 +3090,13 @@ void init_ops(py::module_& m) {
 | 
			
		||||
          w (array): Quantized matrix packed in unsigned integers
 | 
			
		||||
          scales (array): The scales to use per ``group_size`` elements of ``w``
 | 
			
		||||
          biases (array): The biases to use per ``group_size`` elements of ``w``
 | 
			
		||||
          transpose (bool, optional): Defines whether to multiply with the
 | 
			
		||||
            transposed ``w`` or not, namely whether we are performing
 | 
			
		||||
            ``x @ w.T`` or ``x @ w``. (default: ``True``)
 | 
			
		||||
          group_size (int, optional): The size of the group in ``w`` that
 | 
			
		||||
            shares a scale and bias. (default: 64)
 | 
			
		||||
            shares a scale and bias. (default: ``64``)
 | 
			
		||||
          bits (int, optional): The number of bits occupied by each element in
 | 
			
		||||
            ``w``. (default: 4)
 | 
			
		||||
            ``w``. (default: ``4``)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
          result (array): The result of the multiplication of ``x`` with ``w``.
 | 
			
		||||
@@ -3146,9 +3150,9 @@ void init_ops(py::module_& m) {
 | 
			
		||||
        Args:
 | 
			
		||||
          w (array): Matrix to be quantized
 | 
			
		||||
          group_size (int, optional): The size of the group in ``w`` that shares a
 | 
			
		||||
            scale and bias. (default: 64)
 | 
			
		||||
            scale and bias. (default: ``64``)
 | 
			
		||||
          bits (int, optional): The number of bits occupied by each element of
 | 
			
		||||
            ``w`` in the returned quantized matrix. (default: 4)
 | 
			
		||||
            ``w`` in the returned quantized matrix. (default: ``4``)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
          (tuple): A tuple containing
 | 
			
		||||
@@ -3187,9 +3191,9 @@ void init_ops(py::module_& m) {
 | 
			
		||||
          scales (array): The scales to use per ``group_size`` elements of ``w``
 | 
			
		||||
          biases (array): The biases to use per ``group_size`` elements of ``w``
 | 
			
		||||
          group_size (int, optional): The size of the group in ``w`` that shares a
 | 
			
		||||
            scale and bias. (default: 64)
 | 
			
		||||
            scale and bias. (default: ``64``)
 | 
			
		||||
          bits (int, optional): The number of bits occupied by each element in
 | 
			
		||||
            ``w``. (default: 4)
 | 
			
		||||
            ``w``. (default: ``4``)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
          result (array): The dequantized version of ``w``
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user