mirror of https://github.com/ml-explore/mlx.git
synced 2025-11-04 10:38:10 +08:00
	Add 3bit packed quants
@@ -2295,6 +2295,68 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   }
 }
 
+template <typename T, int group_size, int bits, int results_per_simdgroup>
+METAL_FUNC void affine_packed_byte_qmv_fast_impl(
+    const device uint8_t* w,
+    const device vec<T, 2 * results_per_simdgroup>* scales,
+    const device T* x,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int packs_per_thread = 2;
+  constexpr int num_simdgroups = 2;
+  constexpr int pack_factor = (bits == 3) ? 8 : 4;
+  constexpr int bytes_per_pack = 3;
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  vec<U, results_per_simdgroup> result = 0;
+
+  // Adjust positions
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_g = in_vec_size / group_size;
+  const int scales_row = tid.x * num_simdgroups + simd_gid;
+  const int out_row = scales_row * results_per_simdgroup;
+
+  w += out_row * in_vec_size_w + simd_lid * (packs_per_thread * bytes_per_pack);
+  scales += scales_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+  x += tid.y * in_vec_size + simd_lid * values_per_thread;
+  y += tid.y * out_vec_size + out_row;
+
+  for (int k = 0; k < in_vec_size; k += block_size) {
+    // Load the input vector
+    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+    // Load the scales and biases
+    vec<T, 2 * results_per_simdgroup> sb = scales[0];
+
+    // Load the weights and perform the partial dot product
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      result[row] += qdot<U, values_per_thread, bits>(
+          w + row * in_vec_size_w, x_thread, sb[row], sb[2 + row], sum);
+    }
+
+    w += block_size * bytes_per_pack / pack_factor;
+    scales += block_size / group_size;
+    x += block_size;
+  }
+
+  for (int row = 0; row < results_per_simdgroup; row++) {
+    result[row] = simd_sum(result[row]);
+    if (simd_lid == 0) {
+      y[row] = static_cast<T>(result[row]);
+    }
+  }
+}
+
 template <typename T, int group_size, int bits>
 [[kernel]] void affine_packed_qmv_fast(
     const device vec<uint32_t, 4>* w [[buffer(0)]],
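Note (not part of the diff): for bits == 3 the new kernel sets pack_factor = 8 and bytes_per_pack = 3, so eight 3-bit values share one 24-bit group and the weights are addressed as plain uint8_t rather than vec<uint32_t, 4>. A minimal Python sketch of that packing, assuming a little-endian bit order within the 24-bit group (the actual order is whatever qdot expects and is not shown in this hunk):

def pack3(values):
    # values: eight ints in [0, 7] -> 3 bytes
    word = 0
    for i, v in enumerate(values):
        word |= (v & 0x7) << (3 * i)
    return bytes([word & 0xFF, (word >> 8) & 0xFF, (word >> 16) & 0xFF])

def unpack3(b):
    # 3 bytes -> eight ints in [0, 7]
    word = b[0] | (b[1] << 8) | (b[2] << 16)
    return [(word >> (3 * i)) & 0x7 for i in range(8)]

assert unpack3(pack3([1, 7, 0, 5, 2, 6, 3, 4])) == [1, 7, 0, 5, 2, 6, 3, 4]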
@@ -2306,8 +2368,21 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  affine_packed_qmv_fast_impl<T, group_size, bits>(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+  if (bits & (bits - 1)) {
+    affine_packed_byte_qmv_fast_impl<T, group_size, bits, 2>(
+        (const device uint8_t*)w,
+        scales,
+        x,
+        y,
+        in_vec_size,
+        out_vec_size,
+        tid,
+        simd_gid,
+        simd_lid);
+  } else {
+    affine_packed_qmv_fast_impl<T, group_size, bits>(
+        w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+  }
 }
 
 template <
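The dispatch relies on the usual power-of-two test: bits & (bits - 1) is zero exactly when bits has a single set bit, so 3-bit weights go to the new byte-oriented kernel with 2 results per simdgroup, while power-of-two widths keep the existing vectorized path. The same check, as a quick Python illustration:

def uses_byte_kernel(bits):
    # non-zero for non-powers-of-two (e.g. 3, 5, 6), zero for 2, 4, 8
    return bits & (bits - 1) != 0

assert [uses_byte_kernel(b) for b in (2, 3, 4, 8)] == [False, True, False, False]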
@@ -2617,6 +2692,9 @@ template <
         s_strides,
         tid);
   }
-  affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
-      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+  if (bits & (bits - 1)) {
+  } else {
+    affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
+        w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+  }
 }
@@ -405,10 +405,11 @@ void affine_packed_qmv(
   auto w = ensure_row_contiguous_last_dims(inputs[1]);
   auto scales = ensure_row_contiguous_last_dims(inputs[2]);
 
+  const bool pow2_bits = (bits & (bits - 1)) == 0;
   const int n_simdgroups = 2;
-  const int n_outs_per_simdgroup = 4;
+  const int results_per_simdgroup = (pow2_bits) ? 4 : 2;
   MTL::Size group_dims(32, n_simdgroups, 1);
-  MTL::Size grid_dims(O / n_simdgroups / n_outs_per_simdgroup, B, 1);
+  MTL::Size grid_dims(O / n_simdgroups / results_per_simdgroup, B, 1);
 
   std::string name;
   name.reserve(64);
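Host-side launch geometry, for reference: a threadgroup holds two simdgroups of 32 threads, and each simdgroup now writes results_per_simdgroup outputs, 4 on the power-of-two path and 2 on the 3-bit path, so the x extent of the grid doubles for 3-bit weights relative to the 4-bit case. A small sketch of the computation (assuming O divides evenly; numbers are illustrative):

def qmv_launch(O, B, bits, n_simdgroups=2):
    results_per_simdgroup = 4 if bits & (bits - 1) == 0 else 2
    group_dims = (32, n_simdgroups, 1)
    grid_dims = (O // n_simdgroups // results_per_simdgroup, B, 1)
    return group_dims, grid_dims

# qmv_launch(4096, 1, 4) -> ((32, 2, 1), (512, 1, 1))
# qmv_launch(4096, 1, 3) -> ((32, 2, 1), (1024, 1, 1))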
							
								
								
									
mlx/ops.cpp (39 changed lines)
@@ -99,15 +99,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
             << "biases but biases were provided";
         throw std::invalid_argument(msg.str());
       }
-      if (bits & (bits - 1)) {
-        std::ostringstream msg;
-        msg << "[" << tag << "] Quantization type '" << quantization_type
-            << "' does not support " << bits << " bits.";
-        throw std::invalid_argument(msg.str());
-      }
       break;
   }
 
+  bool pow2_bits = (bits & (bits - 1)) == 0;
+
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -136,8 +132,12 @@ std::pair<int, int> extract_quantized_matmul_dims(
   int weight_dims = w.shape(-1) * 32 / bits;
   int scales_dims = scales.shape(-1) * group_size;
   if (quantization_type == QuantizationType::AffinePacked) {
-    scales_dims /= 8;
-    weight_dims /= 4;
+    if (pow2_bits) {
+      scales_dims /= 8;
+      weight_dims /= 4;
+    } else {
+      scales_dims /= 4;
+    }
   }
 
   if (weight_dims != scales_dims) {
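Both branches recover the same logical inner dimension for the AffinePacked type: a packed scales row interleaves scales and biases for a group of output rows (4 rows for power-of-two widths, 2 rows for 3-bit), and only the power-of-two layout additionally folds 4 weight rows into one. A quick check of the arithmetic with hypothetical shapes (Python):

def recovered_dims(w_cols, scales_cols, group_size, bits):
    # w_cols / scales_cols are the last-axis sizes of the packed uint32 weights and scales
    weight_dims = w_cols * 32 // bits
    scales_dims = scales_cols * group_size
    if bits & (bits - 1) == 0:
        scales_dims //= 8   # scales + biases for 4 packed rows
        weight_dims //= 4   # 4 weight rows folded into one
    else:
        scales_dims //= 4   # scales + biases for 2 packed rows
    return weight_dims, scales_dims

# K = 4096, group_size = 64:
#   4-bit: recovered_dims(2048, 512, 64, 4) -> (4096, 4096)
#   3-bit: recovered_dims(384, 256, 64, 3)  -> (4096, 4096)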
@@ -155,7 +155,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
 
   // Calculate the expanded w's dims
   int weight_dims_other = w.shape(-2);
-  if (quantization_type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked && pow2_bits) {
     weight_dims_other *= 4;
   }
   int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
@@ -3793,23 +3793,22 @@ std::tuple<array, array, std::optional<array>> quantize(
     case QuantizationType::Affine:
       return fast::affine_quantize(w, group_size, bits, s);
     case QuantizationType::AffinePacked: {
-      if (bits & (bits - 1)) {
-        std::ostringstream msg;
-        msg << "[quantize] Quantization type '" << quantization_type
-            << "' does not support " << bits << " bits.";
-        throw std::invalid_argument(msg.str());
-      }
       auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
 
-      scales = unflatten(scales, -2, {-1, 4}, s);
-      biases = unflatten(biases, -2, {-1, 4}, s);
+      int pow2_bits = (bits & (bits - 1)) == 0;
+      int row_packing = (pow2_bits) ? 4 : 2;
+
+      scales = unflatten(scales, -2, {-1, row_packing}, s);
+      biases = unflatten(biases, -2, {-1, row_packing}, s);
       scales = concatenate({scales, biases}, -2, s);
       scales = moveaxis(scales, -2, -1, s);
       scales = flatten(scales, -2, -1, s);
 
-      wq = unflatten(wq, -2, {-1, 4}, s);
-      wq = moveaxis(wq, -2, -1, s);
-      wq = flatten(wq, -2, -1, s);
+      if (pow2_bits) {
+        wq = unflatten(wq, -2, {-1, row_packing}, s);
+        wq = moveaxis(wq, -2, -1, s);
+        wq = flatten(wq, -2, -1, s);
+      }
+
       return std::make_tuple(wq, scales, std::nullopt);
     }
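To make the resulting layout concrete: for a weight of shape (O, K), fast::affine_quantize returns scales and biases of shape (O, K / group_size) each. The transforms above stack them in blocks of row_packing output rows and interleave the scale/bias groups along the last axis; the quantized weight itself is only row-packed when bits is a power of two. A shape-only sketch in Python (O, K and the example numbers are illustrative):

def affine_packed_shapes(O, K, group_size, bits):
    row_packing = 4 if bits & (bits - 1) == 0 else 2
    groups = K // group_size
    # scales/biases: (O, groups) each -> unflatten rows -> concat -> moveaxis -> flatten
    scales_shape = (O // row_packing, groups * 2 * row_packing)
    wq_cols = K * bits // 32            # uint32 words per original weight row
    if bits & (bits - 1) == 0:          # 2/4/8 bit: 4 rows packed side by side
        wq_shape = (O // row_packing, wq_cols * row_packing)
    else:                               # 3 bit: weights stay one row per output
        wq_shape = (O, wq_cols)
    return wq_shape, scales_shape

# affine_packed_shapes(4096, 4096, 64, 4) -> ((1024, 2048), (1024, 512))
# affine_packed_shapes(4096, 4096, 64, 3) -> ((4096, 384), (2048, 256))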
@@ -181,12 +181,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        quantization_type: str = "affine",
     ):
         super().__init__()
 
         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.quantization_type = quantization_type
 
         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -195,7 +197,9 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, self.biases = mx.quantize(
+            weight, group_size, bits, quantization_type=quantization_type
+        )
 
         # And bias if needed
         if bias:
@@ -223,10 +227,11 @@ class QuantizedLinear(Module):
             x,
             self["weight"],
             scales=self["scales"],
-            biases=self["biases"],
+            biases=self.get("biases", None),
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            quantization_type=self.quantization_type,
         )
         if "bias" in self:
             x = x + self["bias"]
@@ -242,7 +247,7 @@ class QuantizedLinear(Module):
     ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
+        ql = cls(input_dims, output_dims, False, group_size, bits, quantization_type)
         ql.weight, ql.scales, ql.biases = mx.quantize(
            linear_layer.weight, group_size, bits, quantization_type=quantization_type
         )
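Typical usage is unchanged for the default affine type; the new keyword is simply threaded through to mx.quantize and the quantized matmul. A minimal usage sketch (the string identifier for the packed variant is defined elsewhere and not shown in this diff, so it is left as a placeholder):

import mlx.core as mx
import mlx.nn as nn

layer = nn.Linear(4096, 4096, bias=False)

# Default affine quantization, as before:
qlayer = nn.QuantizedLinear.from_linear(layer, group_size=64, bits=4)
y = qlayer(mx.zeros((1, 4096)))

# For the packed variant, pass the matching type name, e.g.:
# qlayer = nn.QuantizedLinear.from_linear(
#     layer, group_size=64, bits=3, quantization_type="<packed type name>"
# )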