mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-25 12:48:14 +08:00 
			
		
		
		
	Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32 * missing cpu dispatch * remove print * Fix qvm for group_size 32 --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		| @@ -119,6 +119,12 @@ void _qmm_dispatch_typed( | ||||
|   switch (bits) { | ||||
|     case 2: { | ||||
|       switch (group_size) { | ||||
|         case 32: | ||||
|           if (transposed_w) { | ||||
|             return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K); | ||||
|           } else { | ||||
|             return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K); | ||||
|           } | ||||
|         case 64: | ||||
|           if (transposed_w) { | ||||
|             return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K); | ||||
| @@ -135,6 +141,12 @@ void _qmm_dispatch_typed( | ||||
|     } | ||||
|     case 4: { | ||||
|       switch (group_size) { | ||||
|         case 32: | ||||
|           if (transposed_w) { | ||||
|             return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K); | ||||
|           } else { | ||||
|             return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K); | ||||
|           } | ||||
|         case 64: | ||||
|           if (transposed_w) { | ||||
|             return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K); | ||||
| @@ -151,6 +163,12 @@ void _qmm_dispatch_typed( | ||||
|     } | ||||
|     case 8: { | ||||
|       switch (group_size) { | ||||
|         case 32: | ||||
|           if (transposed_w) { | ||||
|             return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K); | ||||
|           } else { | ||||
|             return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K); | ||||
|           } | ||||
|         case 64: | ||||
|           if (transposed_w) { | ||||
|             return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K); | ||||
|   | ||||
| @@ -142,10 +142,11 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|   // Adjust positions | ||||
|   const int out_vec_size_w = out_vec_size / el_per_int; | ||||
|   const int out_vec_size_g = out_vec_size / group_size; | ||||
|   int out_col = (tid.y * BN + simd_gid) * el_per_int; | ||||
|   int out_col_start = tid.y * (BN * el_per_int); | ||||
|   int out_col = out_col_start + simd_gid * el_per_int; | ||||
|   w += out_col / el_per_int; | ||||
|   scales += out_col / group_size; | ||||
|   biases += out_col / group_size; | ||||
|   scales += out_col_start / group_size; | ||||
|   biases += out_col_start / group_size; | ||||
|   x += tid.z * in_vec_size; | ||||
|   y += tid.z * out_vec_size + out_col; | ||||
|  | ||||
| @@ -155,26 +156,22 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|  | ||||
|   // Loop over in_vec in blocks of colgroup | ||||
|   for (int i=0; i<in_vec_size; i+=BM) { | ||||
|     int offset = simd_lid + i; | ||||
|     bool thread_in_bounds = offset < in_vec_size; | ||||
|     int offset_lid = simd_lid + i; | ||||
|     int offset_gid = simd_gid + i; | ||||
|     bool thread_in_bounds = offset_lid < in_vec_size; | ||||
|     bool group_in_bounds = offset_gid < in_vec_size; | ||||
|  | ||||
|     // Load the vec to shared memory | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     if (simd_gid == 0) { | ||||
|       x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0; | ||||
|       x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0; | ||||
|     } | ||||
|  | ||||
|     // Load the scales and biases to shared memory | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     if (simd_gid == 0) { | ||||
|       #pragma clang loop unroll(full) | ||||
|       for (int j=0; j<groups_per_block; j++) { | ||||
|         scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j]; | ||||
|       } | ||||
|       #pragma clang loop unroll(full) | ||||
|       for (int j=0; j<groups_per_block; j++) { | ||||
|         biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j]; | ||||
|       } | ||||
|     if (simd_lid < groups_per_block && group_in_bounds) { | ||||
|       scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid]; | ||||
|       biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid]; | ||||
|     } | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
| @@ -184,7 +181,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|     bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size]; | ||||
|  | ||||
|     // Load the matrix elements | ||||
|     w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0; | ||||
|     w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0; | ||||
|  | ||||
|     // Do all the work. | ||||
|     #pragma clang loop unroll(full) | ||||
| @@ -543,6 +540,9 @@ instantiate_qmv_types(128, 8) | ||||
| instantiate_qmv_types( 64, 2) | ||||
| instantiate_qmv_types( 64, 4) | ||||
| instantiate_qmv_types( 64, 8) | ||||
| instantiate_qmv_types( 32, 2) | ||||
| instantiate_qmv_types( 32, 4) | ||||
| instantiate_qmv_types( 32, 8) | ||||
|  | ||||
| #define instantiate_qvm(name, itype, group_size, bits) \ | ||||
|   template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \ | ||||
| @@ -570,6 +570,9 @@ instantiate_qvm_types(128, 8) | ||||
| instantiate_qvm_types( 64, 2) | ||||
| instantiate_qvm_types( 64, 4) | ||||
| instantiate_qvm_types( 64, 8) | ||||
| instantiate_qvm_types( 32, 2) | ||||
| instantiate_qvm_types( 32, 4) | ||||
| instantiate_qvm_types( 32, 8) | ||||
|  | ||||
| #define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \ | ||||
|   template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \ | ||||
| @@ -601,6 +604,9 @@ instantiate_qmm_t_types(128, 8) | ||||
| instantiate_qmm_t_types( 64, 2) | ||||
| instantiate_qmm_t_types( 64, 4) | ||||
| instantiate_qmm_t_types( 64, 8) | ||||
| instantiate_qmm_t_types( 32, 2) | ||||
| instantiate_qmm_t_types( 32, 4) | ||||
| instantiate_qmm_t_types( 32, 8) | ||||
|  | ||||
| #define instantiate_qmm_n(name, itype, group_size, bits) \ | ||||
|   template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \ | ||||
| @@ -629,3 +635,6 @@ instantiate_qmm_n_types(128, 8) | ||||
| instantiate_qmm_n_types( 64, 2) | ||||
| instantiate_qmm_n_types( 64, 4) | ||||
| instantiate_qmm_n_types( 64, 8) | ||||
| instantiate_qmm_n_types( 32, 2) | ||||
| instantiate_qmm_n_types( 32, 4) | ||||
| instantiate_qmm_n_types( 32, 8) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun