mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	Fix qmm_t for unaligned cases (#923)
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							46caf0bef0
						
					
				
				
					commit
					5f9ba3019f
				
			@@ -520,6 +520,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
 | 
			
		||||
  const int K_g = K / group_size;
 | 
			
		||||
  const int y_row = tid.y * BM;
 | 
			
		||||
  const int y_col = tid.x * BN;
 | 
			
		||||
 | 
			
		||||
  x += y_row * K;
 | 
			
		||||
  w += y_col * K_w;
 | 
			
		||||
  scales += y_col * K_g;
 | 
			
		||||
@@ -572,7 +573,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
 | 
			
		||||
          const device uint32_t * w_local = w + offset_row * K_w + offset_col;
 | 
			
		||||
          threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
 | 
			
		||||
 | 
			
		||||
          if (y_row + offset_row < N) {
 | 
			
		||||
          // y_col corresponds to the row of the weight matrix and added to
 | 
			
		||||
          // offset_row it should be less than the total number of rows
 | 
			
		||||
          // otherwise skip.
 | 
			
		||||
          if (y_col + offset_row < N) {
 | 
			
		||||
            uint32_t wi = *w_local;
 | 
			
		||||
            T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
 | 
			
		||||
            T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user