Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-25 20:58:13 +08:00)
	Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
Author: Angelos Katharopoulos
Committed by: GitHub
Parent: 57fe918cf8
Commit: b3916cbf2b
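For context, the renamed constants describe how the kernels in this diff unpack quantized weights: each 32-bit word holds 32 / bits elements, and every group_size consecutive elements share one scale and one bias. A minimal standalone C++ sketch of that per-word dequantization (an editorial illustration, not MLX API; the function name is made up):

#include <cstdint>
#include <vector>

// Dequantize one packed 32-bit word the way the kernels below do.
// With bits = 4, a uint32_t carries 32 / bits = 8 quantized elements,
// and a group of group_size elements shares a single (scale, bias) pair.
std::vector<float> dequantize_word(uint32_t wi, float scale, float bias, int bits) {
  const uint32_t bitmask = (1u << bits) - 1;  // e.g. 0xF for 4-bit
  const int pack_factor = 32 / bits;          // elements per packed word
  std::vector<float> out(pack_factor);
  for (int p = 0; p < pack_factor; p++) {
    out[p] = scale * static_cast<float>(wi & bitmask) + bias;
    wi >>= bits;
  }
  return out;
}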
				
@@ -19,12 +19,12 @@ void _qmm_t_4_64(
     int M,
     int N,
     int K) {
-  constexpr int width = 4;
-  constexpr int groups = 64;
-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int pack_factor = 32 / width;
-  constexpr int packs_in_group = groups / pack_factor;
-  const int Kg = K / groups;
+  constexpr int bits = 4;
+  constexpr int group_size = 64;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int packs_in_group = group_size / pack_factor;
+  const int Kg = K / group_size;
   const int Kw = K / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -35,7 +35,7 @@ void _qmm_t_4_64(
     for (int n = 0; n < N; n++) {
       const simd_float16* x_local = (simd_float16*)x;
       simd_float16 sum = 0;
-      for (int k = 0; k < K; k += groups) {
+      for (int k = 0; k < K; k += group_size) {
         float scale = *scales_local++;
         float bias = *biases_local++;

@@ -46,7 +46,7 @@ void _qmm_t_4_64(
             uint32_t wii = *w_local++;
             for (int p = 0; p < 8; p++) {
               wi[e * 8 + p] = wii & bitmask;
-              wii >>= width;
+              wii >>= bits;
             }
           }
           simd_float16 wf = simd_float(wi);
@@ -85,7 +85,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error("x, scales and biases should be row contiguous.");
   }

-  if (x.dtype() == float32 && width_ == 4 && groups_ == 64) {
+  if (x.dtype() == float32 && bits_ == 4 && group_size_ == 64) {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
     int K = x.shape(-1);
     int M = x.size() / K;
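As a worked example of the sizing constants above (a sketch with an assumed K = 4096, which is not taken from the commit):

constexpr int K = 4096;                                   // assumed row length, illustration only
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int pack_factor = 32 / bits;                    // 8 elements per uint32_t
constexpr int packs_in_group = group_size / pack_factor;  // 16 packed words per group
constexpr int Kw = K / pack_factor;                       // 512 packed words per weight row
constexpr int Kg = K / group_size;                        // 64 scale/bias pairs per weight row
static_assert(Kw == 512 && Kg == 64, "layout arithmetic");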
@@ -8,7 +8,7 @@ namespace mlx::core {

 namespace {

-template <typename T, int width, int groups>
+template <typename T, int bits, int group_size>
 void _qmm_t(
     T* result,
     const T* x,
@@ -18,10 +18,10 @@ void _qmm_t(
     int M,
     int N,
     int K) {
-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int pack_factor = 32 / width;
-  constexpr int packs_in_group = groups / pack_factor;
-  const int Kg = K / groups;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int packs_in_group = group_size / pack_factor;
+  const int Kg = K / group_size;
   const int Kw = K / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -32,7 +32,7 @@ void _qmm_t(
     for (int n = 0; n < N; n++) {
       const T* x_local = x;
       T sum = 0;
-      for (int k = 0; k < K; k += groups) {
+      for (int k = 0; k < K; k += group_size) {
         T scale = *scales_local++;
         T bias = *biases_local++;

@@ -42,7 +42,7 @@ void _qmm_t(
 #pragma clang loop unroll(full)
           for (int p = 0; p < pack_factor; p++) {
             sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
-            wi >>= width;
+            wi >>= bits;
           }
         }
       }
@@ -64,11 +64,11 @@ void _qmm_t_dispatch_typed(
     int M,
     int N,
     int K,
-    int width,
-    int groups) {
-  switch (width) {
+    int group_size,
+    int bits) {
+  switch (bits) {
     case 2: {
-      switch (groups) {
+      switch (group_size) {
         case 64:
           return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
         case 128:
@@ -76,7 +76,7 @@ void _qmm_t_dispatch_typed(
       }
     }
     case 4: {
-      switch (groups) {
+      switch (group_size) {
         case 64:
           return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
         case 128:
@@ -84,7 +84,7 @@ void _qmm_t_dispatch_typed(
       }
     }
     case 8: {
-      switch (groups) {
+      switch (group_size) {
         case 64:
           return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
         case 128:
@@ -93,9 +93,10 @@ void _qmm_t_dispatch_typed(
     }
   }
   std::ostringstream msg;
-  msg << "Quantization type not supported. Provided bit width=" << width
-      << " and groups=" << groups << ". The supported options are width in "
-      << "{2, 4, 8} and groups in {64, 128}.";
+  msg << "Quantization type not supported. Provided bits=" << bits
+      << " and group_size=" << group_size
+      << ". The supported options are bits in "
+      << "{2, 4, 8} and group_size in {64, 128}.";
   throw std::invalid_argument(msg.str());
 }

@@ -105,8 +106,8 @@ void _qmm_t_dispatch(
     const array& w,
     const array& scales,
     const array& biases,
-    int width,
-    int groups) {
+    int bits,
+    int group_size) {
   int K = x.shape(-1);
   int M = x.size() / K;
   int N = w.shape(1);
@@ -122,8 +123,8 @@ void _qmm_t_dispatch(
           M,
           N,
           K,
-          width,
-          groups);
+          bits,
+          group_size);
       break;
     case float16:
       _qmm_t_dispatch_typed<float16_t>(
@@ -135,8 +136,8 @@ void _qmm_t_dispatch(
           M,
           N,
           K,
-          width,
-          groups);
+          bits,
+          group_size);
       break;
     case bfloat16:
       _qmm_t_dispatch_typed<bfloat16_t>(
@@ -148,8 +149,8 @@ void _qmm_t_dispatch(
           M,
           N,
           K,
-          width,
-          groups);
+          bits,
+          group_size);
       break;
     default:
       throw std::invalid_argument(
@@ -177,7 +178,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
   }

   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  _qmm_t_dispatch(out, x, w, scales, biases, width_, groups_);
+  _qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
 }

 } // namespace mlx::core
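The nested switch above only instantiates bits in {2, 4, 8} with group_size in {64, 128}; any other combination falls through to the invalid_argument error. A hypothetical standalone check (not part of MLX) that mirrors that guard:

#include <sstream>
#include <stdexcept>

// Hypothetical argument validation mirroring the dispatch above.
void check_quantization_args(int group_size, int bits) {
  const bool bits_ok = (bits == 2 || bits == 4 || bits == 8);
  const bool group_ok = (group_size == 64 || group_size == 128);
  if (!bits_ok || !group_ok) {
    std::ostringstream msg;
    msg << "Quantization type not supported. Provided bits=" << bits
        << " and group_size=" << group_size
        << ". The supported options are bits in {2, 4, 8} and group_size in {64, 128}.";
    throw std::invalid_argument(msg.str());
  }
}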
@@ -14,7 +14,7 @@ using namespace metal;

 MLX_MTL_CONST int SIMD_SIZE = 32;

-template <typename T, const int BM, const int BN, const int groups, const int width>
+template <typename T, const int BM, const int BN, const int group_size, const int bits>
 [[kernel]] void qmv(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -30,10 +30,10 @@ template <typename T, const int BM, const int BN, const int groups, const int wi

   static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");

-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int el_per_thread = 32 / width;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int el_per_thread = 32 / bits;
   constexpr int colgroup = BN * el_per_thread;
-  constexpr int groups_per_block = colgroup / groups;
+  constexpr int groups_per_block = colgroup / group_size;
   constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;

   threadgroup T scales_block[BM * groups_per_block];
@@ -48,7 +48,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi

   // Adjust positions
   const int in_vec_size_w = in_vec_size / el_per_thread;
-  const int in_vec_size_g = in_vec_size / groups;
+  const int in_vec_size_g = in_vec_size / group_size;
   int out_row = tid.y * BM + simd_gid;
   w += out_row * in_vec_size_w;
   scales += out_row * in_vec_size_g;
@@ -66,11 +66,11 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
     if (simd_lid == 0) {
       #pragma clang loop unroll(full)
       for (int j=0; j<groups_per_block; j++) {
-        scales_block[simd_gid * groups_per_block + j] = scales[i / groups + j];
+        scales_block[simd_gid * groups_per_block + j] = scales[i / group_size + j];
       }
       #pragma clang loop unroll(full)
       for (int j=0; j<groups_per_block; j++) {
-        biases_block[simd_gid * groups_per_block + j] = biases[i / groups + j];
+        biases_block[simd_gid * groups_per_block + j] = biases[i / group_size + j];
       }
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -80,8 +80,8 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
     for (int j=0; j<el_per_thread; j++) {
       x_thread[j] = x_block[simd_lid*el_per_thread + j];
     }
-    scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
-    bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
+    scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
+    bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];

     // Load the matrix elements
     w_local = w[i / el_per_thread + simd_lid];
@@ -90,7 +90,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
     #pragma clang loop unroll(full)
     for (int k=0; k<el_per_thread; k++) {
       result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
-      w_local >>= width;
+      w_local >>= bits;
     }
   }

@@ -104,7 +104,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
 }


-template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
+template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
 [[kernel]] void qmm_t(
     const device T* x [[buffer(0)]],
     const device uint32_t* w [[buffer(1)]],
@@ -126,10 +126,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups

   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int el_per_int = 32 / width;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int el_per_int = 32 / bits;
   constexpr int ints_per_block = BK / el_per_int;
-  constexpr int groups_per_block = (BK / groups > 0) ? (BK / groups) : 1;
+  constexpr int groups_per_block = (BK / group_size > 0) ? (BK / group_size) : 1;
   constexpr int groups_per_simd = BN / (WM * WN);
   constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);

@@ -145,7 +145,7 @@ template <typename T, const int BM, const int BK, const int BN, const int groups

   // Set the block
   const int K_w = K / el_per_int;
-  const int K_g = K / groups;
+  const int K_g = K / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
   x += y_row * K;
@@ -172,8 +172,8 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
     if (simd_lid == 0) {
       threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
       threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
-      const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / groups;
-      const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / groups;
+      const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / group_size;
+      const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / group_size;
       #pragma clang loop unroll(full)
       for (int gs=0; gs<groups_per_simd; gs++) {
         #pragma clang loop unroll(full)
@@ -199,13 +199,13 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
         threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;

         uint32_t wi = *w_local;
-        T scale = scales_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
-        T bias = biases_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
+        T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+        T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];

         #pragma clang loop unroll(full)
         for (int t=0; t<el_per_int; t++) {
           Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-          wi >>= width;
+          wi >>= bits;
         }
       }
     }
@@ -231,9 +231,9 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
 }


-#define instantiate_qmv(name, itype, groups, width) \
-  template [[host_name("qmv_n_" #name "_groups_" #groups "_width_" #width)]] \
-  [[kernel]] void qmv<itype, 32, 32, groups, width>( \
+#define instantiate_qmv(name, itype, group_size, bits) \
+  template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \
+  [[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
     const device uint32_t* w [[buffer(0)]], \
     const device itype* scales [[buffer(1)]], \
     const device itype* biases [[buffer(2)]], \
@@ -246,10 +246,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
     uint simd_gid [[simdgroup_index_in_threadgroup]], \
     uint simd_lid [[thread_index_in_simdgroup]]);

-#define instantiate_qmv_types(groups, width) \
-  instantiate_qmv(float32, float, groups, width) \
-  instantiate_qmv(float16, half, groups, width) \
-  instantiate_qmv(bfloat16, bfloat16_t, groups, width)
+#define instantiate_qmv_types(group_size, bits) \
+  instantiate_qmv(float32, float, group_size, bits) \
+  instantiate_qmv(float16, half, group_size, bits) \
+  instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)

 instantiate_qmv_types(128, 2)
 instantiate_qmv_types(128, 4)
@@ -258,9 +258,9 @@ instantiate_qmv_types( 64, 2)
 instantiate_qmv_types( 64, 4)
 instantiate_qmv_types( 64, 8)

-#define instantiate_qmm_t(name, itype, groups, width) \
-  template [[host_name("qmm_t_" #name "_groups_" #groups "_width_" #width)]] \
-  [[kernel]] void qmm_t<itype, 32, 64, 32, groups, width>( \
+#define instantiate_qmm_t(name, itype, group_size, bits) \
+  template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
+  [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
       const device itype* x [[buffer(0)]], \
       const device uint32_t* w [[buffer(1)]], \
       const device itype* scales [[buffer(2)]], \
@@ -274,10 +274,10 @@ instantiate_qmv_types( 64, 8)
       uint simd_gid [[simdgroup_index_in_threadgroup]], \
       uint simd_lid [[thread_index_in_simdgroup]]);

-#define instantiate_qmm_t_types(groups, width) \
-  instantiate_qmm_t(float32, float, groups, width) \
-  instantiate_qmm_t(float16, half, groups, width) \
-  instantiate_qmm_t(bfloat16, bfloat16_t, groups, width)
+#define instantiate_qmm_t_types(group_size, bits) \
+  instantiate_qmm_t(float32, float, group_size, bits) \
+  instantiate_qmm_t(float16, half, group_size, bits) \
+  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)

 instantiate_qmm_t_types(128, 2)
 instantiate_qmm_t_types(128, 4)
@@ -58,7 +58,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (B == 1) {
     std::ostringstream kname;
     kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
-          << "_groups_" << groups_ << "_width_" << width_;
+          << "_gs_" << group_size_ << "_b_" << bits_;

     // Encode and dispatch kernel
     auto compute_encoder = d.get_command_encoder(s.index);
@@ -87,7 +87,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   else {
     std::ostringstream kname;
     kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
-          << "_groups_" << groups_ << "_width_" << width_;
+          << "_gs_" << group_size_ << "_b_" << bits_;

     // Encode and dispatch kernel
     auto compute_encoder = d.get_command_encoder(s.index);
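The runtime kernel-name strings built in eval_gpu above must match the host_name strings generated by the instantiation macros in the Metal file, which this change shortens from "_groups_"/"_width_" to "_gs_"/"_b_". A small sketch of the resulting naming scheme (an assumed helper, not MLX code; the "n_"/"t_" infix follows the w_transposed selection in eval_gpu):

#include <sstream>
#include <string>

// For float32 with group_size = 64 and bits = 4, the qmv path yields
// "qmv_n_float32_gs_64_b_4", matching the macro-generated host_name.
std::string qmv_kernel_name(const std::string& type_name, int group_size, int bits) {
  std::ostringstream kname;
  kname << "qmv_n_" << type_name << "_gs_" << group_size << "_b_" << bits;
  return kname.str();
}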