mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Add the 3bit packed qmm_t
This commit is contained in:
		| @@ -2504,12 +2504,149 @@ struct AffinePackedQuantizedBlockLoader { | ||||
|         group_step_cnt++; | ||||
|         if (group_step_cnt == group_steps) { | ||||
|           group_step_cnt = 0; | ||||
|           scales += 8; | ||||
|           biases += 8; | ||||
|           scales += (2 * row_pack_factor); | ||||
|           biases += (2 * row_pack_factor); | ||||
|         } | ||||
|       } else { | ||||
|         scales += 8; | ||||
|         biases += 8; | ||||
|         scales += (2 * row_pack_factor); | ||||
|         biases += (2 * row_pack_factor); | ||||
|       } | ||||
|     } else { | ||||
|       scales += group_stride; | ||||
|       biases += group_stride; | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template < | ||||
|     typename T, | ||||
|     short BROWS, | ||||
|     short BCOLS, | ||||
|     short dst_ld, | ||||
|     short reduction_dim, | ||||
|     short tgp_size, | ||||
|     short group_size, | ||||
|     short bits> | ||||
| struct AffineScalesPackedQuantizedBlockLoader { | ||||
|   static_assert( | ||||
|       BCOLS <= group_size, | ||||
|       "The group size should be larger than the columns"); | ||||
|   static_assert( | ||||
|       group_size % BCOLS == 0, | ||||
|       "The group size should be divisible by the columns"); | ||||
|   static_assert( | ||||
|       bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, | ||||
|       "Template undefined for bits not in {2, 3, 4, 6, 8}"); | ||||
|  | ||||
|   MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4; | ||||
|   MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; | ||||
|   MLX_MTL_CONST short row_pack_factor = 2; | ||||
|   MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; | ||||
|   MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor; | ||||
|   MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor; | ||||
|   MLX_MTL_CONST short n_reads = | ||||
|       (TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size; | ||||
|   MLX_MTL_CONST short group_steps = group_size / BCOLS; | ||||
|  | ||||
|   const int src_ld; | ||||
|   const int tile_stride; | ||||
|   short group_step_cnt; | ||||
|   const int group_stride; | ||||
|  | ||||
|   const short thread_idx; | ||||
|   const short bi; | ||||
|   const short bj; | ||||
|   const short bii; | ||||
|  | ||||
|   const device uint8_t* src; | ||||
|   const device T* scales; | ||||
|   const device T* biases; | ||||
|   threadgroup T* dst; | ||||
|  | ||||
// Constructor: computes this thread's read/write positions from its SIMD
// group and lane indices and stores the offset pointers.
//
// Parameters:
//   src_          - quantized weights; passed as uint32_t but reinterpreted as
//                   a byte pointer before any offset is applied.
//   scales_       - base pointer from which BOTH `scales` and `biases` are
//                   derived (there is no separate biases argument).
//   src_ld_       - leading dimension used in the src and scales offsets.
//   dst_          - threadgroup destination for the dequantized values.
//   simd_group_id - SIMD group index within the threadgroup.
//   simd_lane_id  - lane index within the SIMD group.
//
// Initializers are listed in member declaration order; `biases` is computed
// from the already-offset `scales` member, so it relies on `scales` being
// declared before `biases`.
|   AffineScalesPackedQuantizedBlockLoader( | ||||
|       const device uint32_t* src_, | ||||
|       const device T* scales_, | ||||
|       const int src_ld_, | ||||
|       threadgroup T* dst_, | ||||
|       ushort simd_group_id [[simdgroup_index_in_threadgroup]], | ||||
|       ushort simd_lane_id [[thread_index_in_simdgroup]]) | ||||
|       : src_ld(src_ld_), | ||||
//         Byte advance applied to `src` by next(): one packed tile of columns
//         when reduction_dim is non-zero, otherwise BROWS rows' worth of bytes.
|         tile_stride( | ||||
|             reduction_dim ? BCOLS_PACKED * bytes_per_pack | ||||
|                           : BROWS * src_ld * bytes_per_pack / pack_factor), | ||||
|         group_step_cnt(0), | ||||
|         group_stride(BROWS_PACKED * 2 * src_ld / group_size), | ||||
//         Flat thread index. NOTE(review): 32 is hard-coded here; presumably
//         the SIMD width -- confirm it matches the launch configuration.
|         thread_idx(simd_group_id * 32 + simd_lane_id), | ||||
//         bi / bj: row and packed-column of this thread's first pack in the
//         tile; bii: bi divided by row_pack_factor (index of the row group).
|         bi(n_reads * thread_idx / BCOLS_PACKED), | ||||
|         bj((n_reads * thread_idx) % BCOLS_PACKED), | ||||
|         bii(bi / row_pack_factor), | ||||
|         src(((const device uint8_t*)src_) + | ||||
|             bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack), | ||||
//         Scale offset: row group bii selects a block of
//         2 * src_ld * row_pack_factor / group_size elements, then
//         bi % row_pack_factor selects the entry for this row in the group.
//         NOTE(review): presumably scales and biases are interleaved in runs
//         of row_pack_factor each -- confirm against the packing code.
|         scales( | ||||
|             scales_ + bii * 2 * src_ld * row_pack_factor / group_size + | ||||
|             bi % row_pack_factor), | ||||
|         biases(scales + row_pack_factor), | ||||
|         dst(dst_ + bi * dst_ld + bj * pack_factor) {} | ||||
|  | ||||
// Dequantizes this thread's n_reads packs from `src` into `dst` without any
// check against tile dimensions (it takes none).
//
// One scale and one bias are read once and reused for every pack. Pack i is
// read from src + bytes_per_pack * i and written to dst + i * pack_factor.
|   void load_unsafe() const { | ||||
//     Threads whose row index falls outside the tile do nothing.
//     NOTE(review): this guard is additionally gated on bits == 2, while the
//     equivalent guard in load_safe is not; confirm whether that is intended.
|     if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) { | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     T scale = *scales; | ||||
|     T bias = *biases; | ||||
|     for (int i = 0; i < n_reads; i++) { | ||||
|       dequantize<T, pack_factor, bits>( | ||||
|           (const device uint8_t*)(src + bytes_per_pack * i), | ||||
|           scale, | ||||
|           bias, | ||||
|           dst + i * pack_factor); | ||||
|     } | ||||
|   } | ||||
|  | ||||
// Bounds-checked variant of the tile load.
//
// src_tile_dim supplies the limits compared against `bii`: its .y component
// when reduction_dim == 1 and its .x component when reduction_dim == 0. When
// the check fails, the thread zero-fills n_reads * pack_factor elements of
// `dst` instead of reading `src`.
|   void load_safe(short2 src_tile_dim) const { | ||||
//     Threads whose row index falls outside the tile do nothing.
|     if (TOTAL_READS < tgp_size && bi >= BROWS) { | ||||
|       return; | ||||
|     } | ||||
|  | ||||
//     NOTE(review): both checks below compare bii (bi / row_pack_factor), not
//     bi, against the tile dimension; confirm the units of src_tile_dim.
|     if (reduction_dim == 1 && bii >= src_tile_dim.y) { | ||||
|       for (int i = 0; i < n_reads * pack_factor; i++) { | ||||
|         dst[i] = T(0); | ||||
|       } | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     if (reduction_dim == 0 && bii >= src_tile_dim.x) { | ||||
|       for (int i = 0; i < n_reads * pack_factor; i++) { | ||||
|         dst[i] = T(0); | ||||
|       } | ||||
|       return; | ||||
|     } | ||||
|  | ||||
//     In-bounds path: read i uses its own scale/bias (scales[i], biases[i]),
//     advances src by bytes_per_pack * src_ld per read and dst by dst_ld.
//     NOTE(review): load_unsafe instead reuses one scale/bias and advances by
//     bytes_per_pack / pack_factor per read; the two paths only coincide for
//     i == 0 -- confirm which addressing is intended when n_reads > 1.
|     for (int i = 0; i < n_reads; i++) { | ||||
|       T scale = scales[i]; | ||||
|       T bias = biases[i]; | ||||
|       dequantize<T, pack_factor, bits>( | ||||
|           (const device uint8_t*)(src + bytes_per_pack * i * src_ld), | ||||
|           scale, | ||||
|           bias, | ||||
|           dst + i * dst_ld); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void next() { | ||||
|     src += tile_stride; | ||||
|     if (reduction_dim == 1) { | ||||
|       if (group_steps > 1) { | ||||
|         group_step_cnt++; | ||||
|         if (group_step_cnt == group_steps) { | ||||
|           group_step_cnt = 0; | ||||
|           scales += (2 * row_pack_factor); | ||||
|           biases += (2 * row_pack_factor); | ||||
|         } | ||||
|       } else { | ||||
|         scales += (2 * row_pack_factor); | ||||
|         biases += (2 * row_pack_factor); | ||||
|       } | ||||
|     } else { | ||||
|       scales += group_stride; | ||||
| @@ -2545,10 +2682,11 @@ METAL_FUNC void affine_packed_qmm_t_impl( | ||||
|  | ||||
|   (void)lid; | ||||
|  | ||||
|   constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0; | ||||
|   constexpr int WM = 2; | ||||
|   constexpr int WN = 2; | ||||
|   constexpr int pack_factor = 32 / bits; | ||||
|   constexpr int row_pack_factor = 4; | ||||
|   constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2; | ||||
|   constexpr int BK_padded = (BK + 16 / sizeof(T)); | ||||
|  | ||||
|   // Instantiate the appropriate BlockMMA and Loader | ||||
| @@ -2556,7 +2694,7 @@ METAL_FUNC void affine_packed_qmm_t_impl( | ||||
|       BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>; | ||||
|   using loader_x_t = | ||||
|       mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>; | ||||
|   using loader_w_t = AffinePackedQuantizedBlockLoader< | ||||
|   using loader_fully_packed_t = AffinePackedQuantizedBlockLoader< | ||||
|       T, | ||||
|       BN, | ||||
|       BK, | ||||
| @@ -2565,16 +2703,30 @@ METAL_FUNC void affine_packed_qmm_t_impl( | ||||
|       WM * WN * SIMD_SIZE, | ||||
|       group_size, | ||||
|       bits>; | ||||
|   using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader< | ||||
|       T, | ||||
|       BN, | ||||
|       BK, | ||||
|       BK_padded, | ||||
|       1, | ||||
|       WM * WN * SIMD_SIZE, | ||||
|       group_size, | ||||
|       bits>; | ||||
|   using loader_w_t = typename ConditionalType< | ||||
|       power_of_2_bits, | ||||
|       loader_fully_packed_t, | ||||
|       loader_scales_packed_t>::type; | ||||
|  | ||||
|   // Set the block | ||||
|   const int K_w = K * row_pack_factor / pack_factor; | ||||
|   const int K_w = | ||||
|       (power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32; | ||||
|   const int K_g = K * 2 * row_pack_factor / group_size; | ||||
|   const int y_row = tid.y * BM; | ||||
|   const int y_col = tid.x * BN; | ||||
|   const int packed_y_col = tid.x * (BN / row_pack_factor); | ||||
|  | ||||
|   x += y_row * K; | ||||
|   w += packed_y_col * K_w; | ||||
|   w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w; | ||||
|   scales += packed_y_col * K_g; | ||||
|   y += y_row * N + y_col; | ||||
|  | ||||
| @@ -2692,9 +2844,6 @@ template < | ||||
|         s_strides, | ||||
|         tid); | ||||
|   } | ||||
|   if (bits & (bits - 1)) { | ||||
|   } else { | ||||
|     affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>( | ||||
|         w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); | ||||
|   } | ||||
|   affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>( | ||||
|       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos