mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Add the 3bit packed qmm_t
This commit is contained in:
		@@ -2504,12 +2504,149 @@ struct AffinePackedQuantizedBlockLoader {
 | 
				
			|||||||
        group_step_cnt++;
 | 
					        group_step_cnt++;
 | 
				
			||||||
        if (group_step_cnt == group_steps) {
 | 
					        if (group_step_cnt == group_steps) {
 | 
				
			||||||
          group_step_cnt = 0;
 | 
					          group_step_cnt = 0;
 | 
				
			||||||
          scales += 8;
 | 
					          scales += (2 * row_pack_factor);
 | 
				
			||||||
          biases += 8;
 | 
					          biases += (2 * row_pack_factor);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
        scales += 8;
 | 
					        scales += (2 * row_pack_factor);
 | 
				
			||||||
        biases += 8;
 | 
					        biases += (2 * row_pack_factor);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      scales += group_stride;
 | 
				
			||||||
 | 
					      biases += group_stride;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <
 | 
				
			||||||
 | 
					    typename T,
 | 
				
			||||||
 | 
					    short BROWS,
 | 
				
			||||||
 | 
					    short BCOLS,
 | 
				
			||||||
 | 
					    short dst_ld,
 | 
				
			||||||
 | 
					    short reduction_dim,
 | 
				
			||||||
 | 
					    short tgp_size,
 | 
				
			||||||
 | 
					    short group_size,
 | 
				
			||||||
 | 
					    short bits>
 | 
				
			||||||
 | 
					struct AffineScalesPackedQuantizedBlockLoader {
 | 
				
			||||||
 | 
					  static_assert(
 | 
				
			||||||
 | 
					      BCOLS <= group_size,
 | 
				
			||||||
 | 
					      "The group size should be larger than the columns");
 | 
				
			||||||
 | 
					  static_assert(
 | 
				
			||||||
 | 
					      group_size % BCOLS == 0,
 | 
				
			||||||
 | 
					      "The group size should be divisible by the columns");
 | 
				
			||||||
 | 
					  static_assert(
 | 
				
			||||||
 | 
					      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
 | 
				
			||||||
 | 
					      "Template undefined for bits not in {2, 3, 4, 6, 8}");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
 | 
				
			||||||
 | 
					  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
 | 
				
			||||||
 | 
					  MLX_MTL_CONST short row_pack_factor = 2;
 | 
				
			||||||
 | 
					  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
 | 
				
			||||||
 | 
					  MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
 | 
				
			||||||
 | 
					  MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor;
 | 
				
			||||||
 | 
					  MLX_MTL_CONST short n_reads =
 | 
				
			||||||
 | 
					      (TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size;
 | 
				
			||||||
 | 
					  MLX_MTL_CONST short group_steps = group_size / BCOLS;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const int src_ld;
 | 
				
			||||||
 | 
					  const int tile_stride;
 | 
				
			||||||
 | 
					  short group_step_cnt;
 | 
				
			||||||
 | 
					  const int group_stride;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const short thread_idx;
 | 
				
			||||||
 | 
					  const short bi;
 | 
				
			||||||
 | 
					  const short bj;
 | 
				
			||||||
 | 
					  const short bii;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const device uint8_t* src;
 | 
				
			||||||
 | 
					  const device T* scales;
 | 
				
			||||||
 | 
					  const device T* biases;
 | 
				
			||||||
 | 
					  threadgroup T* dst;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Set up the per-thread view of the quantized tile.
					  //
					  // src_    : packed weights; received as uint32_t* but addressed below as
					  //           raw bytes.
					  // scales_ : scales pointer. The biases pointer is derived from it (there
					  //           is no separate biases argument): it is placed row_pack_factor
					  //           elements after this thread's scales pointer.
					  // src_ld_ : leading dimension of the source. NOTE(review): the offset
					  //           math below divides by pack_factor, so this looks like a count
					  //           of unpacked elements per row -- confirm against the caller.
					  // dst_    : threadgroup destination tile, dst_ld elements per row.
					  AffineScalesPackedQuantizedBlockLoader(
					      const device uint32_t* src_,
					      const device T* scales_,
					      const int src_ld_,
					      threadgroup T* dst_,
					      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
					      ushort simd_lane_id [[thread_index_in_simdgroup]])
					      : src_ld(src_ld_),
					        // Byte distance src advances on every next(): one tile of packed
					        // columns when reduction_dim is non-zero, otherwise BROWS rows.
					        tile_stride(
					            reduction_dim ? BCOLS_PACKED * bytes_per_pack
					                          : BROWS * src_ld * bytes_per_pack / pack_factor),
					        group_step_cnt(0),
					        // Amount scales/biases advance in next() when reduction_dim != 1.
					        group_stride(BROWS_PACKED * 2 * src_ld / group_size),
					        // Linear thread index; assumes 32 lanes per simdgroup.
					        thread_idx(simd_group_id * 32 + simd_lane_id),
					        // Row (bi) and packed column (bj) of the first pack this thread
					        // handles; each thread handles n_reads packs.
					        bi(n_reads * thread_idx / BCOLS_PACKED),
					        bj((n_reads * thread_idx) % BCOLS_PACKED),
					        // Index of the row_pack_factor-sized row bundle that contains bi.
					        bii(bi / row_pack_factor),
					        // Byte pointer to this thread's first pack.
					        src(((const device uint8_t*)src_) +
					            bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack),
					        // Scales pointer: start of the bundle bii, plus the row's position
					        // inside the bundle (bi % row_pack_factor).
					        scales(
					            scales_ + bii * 2 * src_ld * row_pack_factor / group_size +
					            bi % row_pack_factor),
					        biases(scales + row_pack_factor),
					        // First output element for this thread inside the destination tile.
					        dst(dst_ + bi * dst_ld + bj * pack_factor) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void load_unsafe() const {
 | 
				
			||||||
 | 
					    if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) {
 | 
				
			||||||
 | 
					      return;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    T scale = *scales;
 | 
				
			||||||
 | 
					    T bias = *biases;
 | 
				
			||||||
 | 
					    for (int i = 0; i < n_reads; i++) {
 | 
				
			||||||
 | 
					      dequantize<T, pack_factor, bits>(
 | 
				
			||||||
 | 
					          (const device uint8_t*)(src + bytes_per_pack * i),
 | 
				
			||||||
 | 
					          scale,
 | 
				
			||||||
 | 
					          bias,
 | 
				
			||||||
 | 
					          dst + i * pack_factor);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Bounds-checked counterpart of load_unsafe().
					  //
					  // src_tile_dim gives the valid extent of the source tile. A thread whose
					  // row bundle (bii) lies outside that extent writes zeros to its slice of
					  // the destination instead of reading the source; every other thread
					  // dequantizes its n_reads consecutive packs exactly as load_unsafe() does.
					  // NOTE(review): the extent is compared against bii (bi / row_pack_factor),
					  // not bi -- confirm the caller passes the extent in bundle units.
					  void load_safe(short2 src_tile_dim) const {
					    // With fewer packs than threads, surplus threads have nothing to do.
					    if (TOTAL_READS < tgp_size && bi >= BROWS) {
					      return;
					    }

					    const bool outside_tile =
					        (reduction_dim == 1 && bii >= src_tile_dim.y) ||
					        (reduction_dim == 0 && bii >= src_tile_dim.x);
					    if (outside_tile) {
					      // Zero this thread's contiguous slice: n_reads packs of pack_factor
					      // elements each.
					      for (int i = 0; i < n_reads * pack_factor; i++) {
					        dst[i] = T(0);
					      }
					      return;
					    }

					    // Each thread owns n_reads consecutive packs of a single row: packs are
					    // bytes_per_pack apart in src and pack_factor elements apart in dst, and
					    // they all use the scale/bias the constructor selected for that row.
					    // The previous loop indexed scales[i]/biases[i] and stepped src/dst by
					    // whole rows (i * src_ld, i * dst_ld), which for n_reads > 1 read and
					    // wrote outside this thread's slice and disagreed with both the
					    // zero-fill above and load_unsafe().
					    T scale = *scales;
					    T bias = *biases;
					    for (int i = 0; i < n_reads; i++) {
					      dequantize<T, pack_factor, bits>(
					          (const device uint8_t*)(src + bytes_per_pack * i),
					          scale,
					          bias,
					          dst + i * pack_factor);
					    }
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void next() {
 | 
				
			||||||
 | 
					    src += tile_stride;
 | 
				
			||||||
 | 
					    if (reduction_dim == 1) {
 | 
				
			||||||
 | 
					      if (group_steps > 1) {
 | 
				
			||||||
 | 
					        group_step_cnt++;
 | 
				
			||||||
 | 
					        if (group_step_cnt == group_steps) {
 | 
				
			||||||
 | 
					          group_step_cnt = 0;
 | 
				
			||||||
 | 
					          scales += (2 * row_pack_factor);
 | 
				
			||||||
 | 
					          biases += (2 * row_pack_factor);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        scales += (2 * row_pack_factor);
 | 
				
			||||||
 | 
					        biases += (2 * row_pack_factor);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      scales += group_stride;
 | 
					      scales += group_stride;
 | 
				
			||||||
@@ -2545,10 +2682,11 @@ METAL_FUNC void affine_packed_qmm_t_impl(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  (void)lid;
 | 
					  (void)lid;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
 | 
				
			||||||
  constexpr int WM = 2;
 | 
					  constexpr int WM = 2;
 | 
				
			||||||
  constexpr int WN = 2;
 | 
					  constexpr int WN = 2;
 | 
				
			||||||
  constexpr int pack_factor = 32 / bits;
 | 
					  constexpr int pack_factor = 32 / bits;
 | 
				
			||||||
  constexpr int row_pack_factor = 4;
 | 
					  constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2;
 | 
				
			||||||
  constexpr int BK_padded = (BK + 16 / sizeof(T));
 | 
					  constexpr int BK_padded = (BK + 16 / sizeof(T));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Instantiate the appropriate BlockMMA and Loader
 | 
					  // Instantiate the appropriate BlockMMA and Loader
 | 
				
			||||||
@@ -2556,7 +2694,7 @@ METAL_FUNC void affine_packed_qmm_t_impl(
 | 
				
			|||||||
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
 | 
					      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
 | 
				
			||||||
  using loader_x_t =
 | 
					  using loader_x_t =
 | 
				
			||||||
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
 | 
					      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
 | 
				
			||||||
  using loader_w_t = AffinePackedQuantizedBlockLoader<
 | 
					  using loader_fully_packed_t = AffinePackedQuantizedBlockLoader<
 | 
				
			||||||
      T,
 | 
					      T,
 | 
				
			||||||
      BN,
 | 
					      BN,
 | 
				
			||||||
      BK,
 | 
					      BK,
 | 
				
			||||||
@@ -2565,16 +2703,30 @@ METAL_FUNC void affine_packed_qmm_t_impl(
 | 
				
			|||||||
      WM * WN * SIMD_SIZE,
 | 
					      WM * WN * SIMD_SIZE,
 | 
				
			||||||
      group_size,
 | 
					      group_size,
 | 
				
			||||||
      bits>;
 | 
					      bits>;
 | 
				
			||||||
 | 
					  using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader<
 | 
				
			||||||
 | 
					      T,
 | 
				
			||||||
 | 
					      BN,
 | 
				
			||||||
 | 
					      BK,
 | 
				
			||||||
 | 
					      BK_padded,
 | 
				
			||||||
 | 
					      1,
 | 
				
			||||||
 | 
					      WM * WN * SIMD_SIZE,
 | 
				
			||||||
 | 
					      group_size,
 | 
				
			||||||
 | 
					      bits>;
 | 
				
			||||||
 | 
					  using loader_w_t = typename ConditionalType<
 | 
				
			||||||
 | 
					      power_of_2_bits,
 | 
				
			||||||
 | 
					      loader_fully_packed_t,
 | 
				
			||||||
 | 
					      loader_scales_packed_t>::type;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Set the block
 | 
					  // Set the block
 | 
				
			||||||
  const int K_w = K * row_pack_factor / pack_factor;
 | 
					  const int K_w =
 | 
				
			||||||
 | 
					      (power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32;
 | 
				
			||||||
  const int K_g = K * 2 * row_pack_factor / group_size;
 | 
					  const int K_g = K * 2 * row_pack_factor / group_size;
 | 
				
			||||||
  const int y_row = tid.y * BM;
 | 
					  const int y_row = tid.y * BM;
 | 
				
			||||||
  const int y_col = tid.x * BN;
 | 
					  const int y_col = tid.x * BN;
 | 
				
			||||||
  const int packed_y_col = tid.x * (BN / row_pack_factor);
 | 
					  const int packed_y_col = tid.x * (BN / row_pack_factor);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  x += y_row * K;
 | 
					  x += y_row * K;
 | 
				
			||||||
  w += packed_y_col * K_w;
 | 
					  w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w;
 | 
				
			||||||
  scales += packed_y_col * K_g;
 | 
					  scales += packed_y_col * K_g;
 | 
				
			||||||
  y += y_row * N + y_col;
 | 
					  y += y_row * N + y_col;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2692,9 +2844,6 @@ template <
 | 
				
			|||||||
        s_strides,
 | 
					        s_strides,
 | 
				
			||||||
        tid);
 | 
					        tid);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  if (bits & (bits - 1)) {
 | 
					  affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
 | 
				
			||||||
  } else {
 | 
					      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
 | 
				
			||||||
    affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
 | 
					 | 
				
			||||||
        w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user