mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-25 12:48:14 +08:00 
			
		
		
		
	Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices
* Fix qmm
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							4cc70290f7
						
					
				
				
					commit
					40c108766b
				
			| @@ -39,11 +39,12 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|  | ||||
|   static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE"); | ||||
|  | ||||
|   (void)lid; | ||||
|  | ||||
|   constexpr int bitmask = (1 << bits) - 1; | ||||
|   constexpr int el_per_thread = 32 / bits; | ||||
|   constexpr int colgroup = BN * el_per_thread; | ||||
|   constexpr int groups_per_block = colgroup / group_size; | ||||
|   constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE; | ||||
|  | ||||
|   typedef typename AccT<T>::acc_t U; | ||||
|   threadgroup U scales_block[BM * groups_per_block]; | ||||
| @@ -66,12 +67,19 @@ template <typename T, const int BM, const int BN, const int group_size, const in | ||||
|   x += tid.z * in_vec_size; | ||||
|   y += tid.z * out_vec_size; | ||||
|  | ||||
|   if (out_row >= out_vec_size) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   // Loop over in_vec in blocks of colgroup | ||||
|   for (int i=0; i<in_vec_size; i+=colgroup) { | ||||
|     // Load the vec to shared memory | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     if (simd_gid < simdgroups_fetching_vec) { | ||||
|       x_block[lid] = x[lid + i]; | ||||
|     if (simd_gid == 0) { | ||||
|       #pragma clang loop unroll(full) | ||||
|       for (int j=0; j<el_per_thread; j++) { | ||||
|         x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j]; | ||||
|       } | ||||
|     } | ||||
|     if (simd_lid == 0) { | ||||
|       #pragma clang loop unroll(full) | ||||
| @@ -250,7 +258,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_ | ||||
|   using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>; | ||||
|   using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>; | ||||
|  | ||||
|  | ||||
|   threadgroup T scales_block[BN * groups_per_block]; | ||||
|   threadgroup T biases_block[BN * groups_per_block]; | ||||
|   threadgroup T Xs[BM * BK]; | ||||
| @@ -313,7 +320,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_ | ||||
|           const device uint32_t * w_local = w + offset_row * K_w + offset_col; | ||||
|           threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int; | ||||
|  | ||||
|           if (y_col + offset_col < N) { | ||||
|           if (y_row + offset_row < N) { | ||||
|             uint32_t wi = *w_local; | ||||
|             T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)]; | ||||
|             T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)]; | ||||
| @@ -428,8 +435,9 @@ template <typename T, const int BM, const int BK, const int BN, const int group_ | ||||
|   for (int k=0; k<K; k += BK) { | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     // Load the x tile | ||||
|     if (num_els < BM) { | ||||
|         loader_x.load_safe(short2(BK, num_els)); | ||||
|     short num_k = min(BK, K - k); | ||||
|     if (num_els < BM || num_k < BK) { | ||||
|         loader_x.load_safe(short2(num_k, num_els)); | ||||
|     } else { | ||||
|         loader_x.load_unsafe(); | ||||
|     } | ||||
| @@ -457,7 +465,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_ | ||||
|  | ||||
|     // Load the w tile | ||||
|     { | ||||
|       if (k + BK >= K) { | ||||
|       if (num_k < BK) { | ||||
|         for (int wo=0; wo<w_els_per_thread; wo++) { | ||||
|           int offset = lid * w_els_per_thread + wo; | ||||
|           int offset_row = offset / (BN / el_per_int); | ||||
|   | ||||
| @@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|       int bo = std::min(32, O); | ||||
|       int bd = 32; | ||||
|       MTL::Size group_dims = MTL::Size(bd, bo, 1); | ||||
|       MTL::Size grid_dims = MTL::Size(1, O / bo, B); | ||||
|       MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B); | ||||
|  | ||||
|       set_array_buffer(compute_encoder, w, 0); | ||||
|       set_array_buffer(compute_encoder, scales, 1); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user