mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix quantization of all 0s (#1028)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							d0dbfe0b97
						
					
				
				
					commit
					ec8578d41a
				
			| @@ -18,6 +18,14 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|                 eps = 1e-6 | ||||
|                 self.assertTrue((errors <= (scales[..., None] + eps)).all()) | ||||
|  | ||||
|         # test quantize/dequantize 0s | ||||
|         a = mx.zeros((256, 512)) | ||||
|         for gs in [32, 64, 128]: | ||||
|             for b in [2, 4, 8]: | ||||
|                 w_q, scales, biases = mx.quantize(a, gs, b) | ||||
|                 a_hat = mx.dequantize(w_q, scales, biases, gs, b) | ||||
|                 self.assertTrue(mx.all(a_hat == 0)) | ||||
|  | ||||
|     def test_qmm(self): | ||||
|         key = mx.random.key(0) | ||||
|         k1, k2 = mx.random.split(key) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user