mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv * Allow arbitrary first dim on qmm and qvm * Specialized aligned vs unaligned case * Add more checks for valid quantizations
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							f44c132f4a
						
					
				
				
					commit
					c15fe3e61b
				
			| @@ -60,20 +60,48 @@ def matmul(x, y): | ||||
|     mx.eval(ys) | ||||
|  | ||||
|  | ||||
| def _quant_matmul(x, w, s, b, group_size, bits): | ||||
| def _quant_matmul(x, w, s, b, transpose, group_size, bits): | ||||
|     ys = [] | ||||
|     for i in range(10): | ||||
|         ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits)) | ||||
|         ys.append( | ||||
|             mx.quantized_matmul( | ||||
|                 x, w, s, b, transpose=transpose, group_size=group_size, bits=bits | ||||
|             ) | ||||
|         ) | ||||
|     mx.eval(ys) | ||||
|  | ||||
|  | ||||
| quant_matmul = { | ||||
|     "quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2), | ||||
|     "quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4), | ||||
|     "quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8), | ||||
|     "quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2), | ||||
|     "quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4), | ||||
|     "quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8), | ||||
|     "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2), | ||||
|     "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4), | ||||
|     "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8), | ||||
|     "quant_matmul_128_2": partial( | ||||
|         _quant_matmul, transpose=False, group_size=128, bits=2 | ||||
|     ), | ||||
|     "quant_matmul_128_4": partial( | ||||
|         _quant_matmul, transpose=False, group_size=128, bits=4 | ||||
|     ), | ||||
|     "quant_matmul_128_8": partial( | ||||
|         _quant_matmul, transpose=False, group_size=128, bits=8 | ||||
|     ), | ||||
|     "quant_matmul_t_64_2": partial( | ||||
|         _quant_matmul, transpose=True, group_size=64, bits=2 | ||||
|     ), | ||||
|     "quant_matmul_t_64_4": partial( | ||||
|         _quant_matmul, transpose=True, group_size=64, bits=4 | ||||
|     ), | ||||
|     "quant_matmul_t_64_8": partial( | ||||
|         _quant_matmul, transpose=True, group_size=64, bits=8 | ||||
|     ), | ||||
|     "quant_matmul_t_128_2": partial( | ||||
|         _quant_matmul, transpose=True, group_size=128, bits=2 | ||||
|     ), | ||||
|     "quant_matmul_t_128_4": partial( | ||||
|         _quant_matmul, transpose=True, group_size=128, bits=4 | ||||
|     ), | ||||
|     "quant_matmul_t_128_8": partial( | ||||
|         _quant_matmul, transpose=True, group_size=128, bits=8 | ||||
|     ), | ||||
| } | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user