mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Add the 3bit packed qmm_t
This commit is contained in:
		@@ -2504,12 +2504,149 @@ struct AffinePackedQuantizedBlockLoader {
 | 
			
		||||
        group_step_cnt++;
 | 
			
		||||
        if (group_step_cnt == group_steps) {
 | 
			
		||||
          group_step_cnt = 0;
 | 
			
		||||
          scales += 8;
 | 
			
		||||
          biases += 8;
 | 
			
		||||
          scales += (2 * row_pack_factor);
 | 
			
		||||
          biases += (2 * row_pack_factor);
 | 
			
		||||
        }
 | 
			
		||||
      } else {
 | 
			
		||||
        scales += 8;
 | 
			
		||||
        biases += 8;
 | 
			
		||||
        scales += (2 * row_pack_factor);
 | 
			
		||||
        biases += (2 * row_pack_factor);
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
      scales += group_stride;
 | 
			
		||||
      biases += group_stride;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <
 | 
			
		||||
    typename T,
 | 
			
		||||
    short BROWS,
 | 
			
		||||
    short BCOLS,
 | 
			
		||||
    short dst_ld,
 | 
			
		||||
    short reduction_dim,
 | 
			
		||||
    short tgp_size,
 | 
			
		||||
    short group_size,
 | 
			
		||||
    short bits>
 | 
			
		||||
struct AffineScalesPackedQuantizedBlockLoader {
 | 
			
		||||
  static_assert(
 | 
			
		||||
      BCOLS <= group_size,
 | 
			
		||||
      "The group size should be larger than the columns");
 | 
			
		||||
  static_assert(
 | 
			
		||||
      group_size % BCOLS == 0,
 | 
			
		||||
      "The group size should be divisible by the columns");
 | 
			
		||||
  static_assert(
 | 
			
		||||
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
 | 
			
		||||
      "Template undefined for bits not in {2, 3, 4, 6, 8}");
 | 
			
		||||
 | 
			
		||||
  MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
 | 
			
		||||
  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
 | 
			
		||||
  MLX_MTL_CONST short row_pack_factor = 2;
 | 
			
		||||
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
 | 
			
		||||
  MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
 | 
			
		||||
  MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor;
 | 
			
		||||
  MLX_MTL_CONST short n_reads =
 | 
			
		||||
      (TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size;
 | 
			
		||||
  MLX_MTL_CONST short group_steps = group_size / BCOLS;
 | 
			
		||||
 | 
			
		||||
  const int src_ld;
 | 
			
		||||
  const int tile_stride;
 | 
			
		||||
  short group_step_cnt;
 | 
			
		||||
  const int group_stride;
 | 
			
		||||
 | 
			
		||||
  const short thread_idx;
 | 
			
		||||
  const short bi;
 | 
			
		||||
  const short bj;
 | 
			
		||||
  const short bii;
 | 
			
		||||
 | 
			
		||||
  const device uint8_t* src;
 | 
			
		||||
  const device T* scales;
 | 
			
		||||
  const device T* biases;
 | 
			
		||||
  threadgroup T* dst;
 | 
			
		||||
 | 
			
		||||
  AffineScalesPackedQuantizedBlockLoader(
 | 
			
		||||
      const device uint32_t* src_,
 | 
			
		||||
      const device T* scales_,
 | 
			
		||||
      const int src_ld_,
 | 
			
		||||
      threadgroup T* dst_,
 | 
			
		||||
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
 | 
			
		||||
      ushort simd_lane_id [[thread_index_in_simdgroup]])
 | 
			
		||||
      : src_ld(src_ld_),
 | 
			
		||||
        tile_stride(
 | 
			
		||||
            reduction_dim ? BCOLS_PACKED * bytes_per_pack
 | 
			
		||||
                          : BROWS * src_ld * bytes_per_pack / pack_factor),
 | 
			
		||||
        group_step_cnt(0),
 | 
			
		||||
        group_stride(BROWS_PACKED * 2 * src_ld / group_size),
 | 
			
		||||
        thread_idx(simd_group_id * 32 + simd_lane_id),
 | 
			
		||||
        bi(n_reads * thread_idx / BCOLS_PACKED),
 | 
			
		||||
        bj((n_reads * thread_idx) % BCOLS_PACKED),
 | 
			
		||||
        bii(bi / row_pack_factor),
 | 
			
		||||
        src(((const device uint8_t*)src_) +
 | 
			
		||||
            bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack),
 | 
			
		||||
        scales(
 | 
			
		||||
            scales_ + bii * 2 * src_ld * row_pack_factor / group_size +
 | 
			
		||||
            bi % row_pack_factor),
 | 
			
		||||
        biases(scales + row_pack_factor),
 | 
			
		||||
        dst(dst_ + bi * dst_ld + bj * pack_factor) {}
 | 
			
		||||
 | 
			
		||||
  void load_unsafe() const {
 | 
			
		||||
    if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) {
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    T scale = *scales;
 | 
			
		||||
    T bias = *biases;
 | 
			
		||||
    for (int i = 0; i < n_reads; i++) {
 | 
			
		||||
      dequantize<T, pack_factor, bits>(
 | 
			
		||||
          (const device uint8_t*)(src + bytes_per_pack * i),
 | 
			
		||||
          scale,
 | 
			
		||||
          bias,
 | 
			
		||||
          dst + i * pack_factor);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void load_safe(short2 src_tile_dim) const {
 | 
			
		||||
    if (TOTAL_READS < tgp_size && bi >= BROWS) {
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (reduction_dim == 1 && bii >= src_tile_dim.y) {
 | 
			
		||||
      for (int i = 0; i < n_reads * pack_factor; i++) {
 | 
			
		||||
        dst[i] = T(0);
 | 
			
		||||
      }
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (reduction_dim == 0 && bii >= src_tile_dim.x) {
 | 
			
		||||
      for (int i = 0; i < n_reads * pack_factor; i++) {
 | 
			
		||||
        dst[i] = T(0);
 | 
			
		||||
      }
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (int i = 0; i < n_reads; i++) {
 | 
			
		||||
      T scale = scales[i];
 | 
			
		||||
      T bias = biases[i];
 | 
			
		||||
      dequantize<T, pack_factor, bits>(
 | 
			
		||||
          (const device uint8_t*)(src + bytes_per_pack * i * src_ld),
 | 
			
		||||
          scale,
 | 
			
		||||
          bias,
 | 
			
		||||
          dst + i * dst_ld);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void next() {
 | 
			
		||||
    src += tile_stride;
 | 
			
		||||
    if (reduction_dim == 1) {
 | 
			
		||||
      if (group_steps > 1) {
 | 
			
		||||
        group_step_cnt++;
 | 
			
		||||
        if (group_step_cnt == group_steps) {
 | 
			
		||||
          group_step_cnt = 0;
 | 
			
		||||
          scales += (2 * row_pack_factor);
 | 
			
		||||
          biases += (2 * row_pack_factor);
 | 
			
		||||
        }
 | 
			
		||||
      } else {
 | 
			
		||||
        scales += (2 * row_pack_factor);
 | 
			
		||||
        biases += (2 * row_pack_factor);
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
      scales += group_stride;
 | 
			
		||||
@@ -2545,10 +2682,11 @@ METAL_FUNC void affine_packed_qmm_t_impl(
 | 
			
		||||
 | 
			
		||||
  (void)lid;
 | 
			
		||||
 | 
			
		||||
  constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
 | 
			
		||||
  constexpr int WM = 2;
 | 
			
		||||
  constexpr int WN = 2;
 | 
			
		||||
  constexpr int pack_factor = 32 / bits;
 | 
			
		||||
  constexpr int row_pack_factor = 4;
 | 
			
		||||
  constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2;
 | 
			
		||||
  constexpr int BK_padded = (BK + 16 / sizeof(T));
 | 
			
		||||
 | 
			
		||||
  // Instantiate the appropriate BlockMMA and Loader
 | 
			
		||||
@@ -2556,7 +2694,7 @@ METAL_FUNC void affine_packed_qmm_t_impl(
 | 
			
		||||
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
 | 
			
		||||
  using loader_x_t =
 | 
			
		||||
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
 | 
			
		||||
  using loader_w_t = AffinePackedQuantizedBlockLoader<
 | 
			
		||||
  using loader_fully_packed_t = AffinePackedQuantizedBlockLoader<
 | 
			
		||||
      T,
 | 
			
		||||
      BN,
 | 
			
		||||
      BK,
 | 
			
		||||
@@ -2565,16 +2703,30 @@ METAL_FUNC void affine_packed_qmm_t_impl(
 | 
			
		||||
      WM * WN * SIMD_SIZE,
 | 
			
		||||
      group_size,
 | 
			
		||||
      bits>;
 | 
			
		||||
  using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader<
 | 
			
		||||
      T,
 | 
			
		||||
      BN,
 | 
			
		||||
      BK,
 | 
			
		||||
      BK_padded,
 | 
			
		||||
      1,
 | 
			
		||||
      WM * WN * SIMD_SIZE,
 | 
			
		||||
      group_size,
 | 
			
		||||
      bits>;
 | 
			
		||||
  using loader_w_t = typename ConditionalType<
 | 
			
		||||
      power_of_2_bits,
 | 
			
		||||
      loader_fully_packed_t,
 | 
			
		||||
      loader_scales_packed_t>::type;
 | 
			
		||||
 | 
			
		||||
  // Set the block
 | 
			
		||||
  const int K_w = K * row_pack_factor / pack_factor;
 | 
			
		||||
  const int K_w =
 | 
			
		||||
      (power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32;
 | 
			
		||||
  const int K_g = K * 2 * row_pack_factor / group_size;
 | 
			
		||||
  const int y_row = tid.y * BM;
 | 
			
		||||
  const int y_col = tid.x * BN;
 | 
			
		||||
  const int packed_y_col = tid.x * (BN / row_pack_factor);
 | 
			
		||||
 | 
			
		||||
  x += y_row * K;
 | 
			
		||||
  w += packed_y_col * K_w;
 | 
			
		||||
  w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w;
 | 
			
		||||
  scales += packed_y_col * K_g;
 | 
			
		||||
  y += y_row * N + y_col;
 | 
			
		||||
 | 
			
		||||
@@ -2692,9 +2844,6 @@ template <
 | 
			
		||||
        s_strides,
 | 
			
		||||
        tid);
 | 
			
		||||
  }
 | 
			
		||||
  if (bits & (bits - 1)) {
 | 
			
		||||
  } else {
 | 
			
		||||
  affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
 | 
			
		||||
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
 | 
			
		||||
}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user