Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)
	Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32
* missing cpu dispatch
* remove print
* Fix qvm for group_size 32

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
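For orientation, a minimal sketch of round-tripping a weight matrix through the newly supported group size of 32, mirroring the mx.quantize / mx.dequantize calls in the updated tests below (group size and bits are passed positionally, as the tests do):

import mlx.core as mx

# Quantize with groups of 32 values per scale/bias pair, 4 bits per weight.
w = mx.random.normal(shape=(128, 512))
w_q, scales, biases = mx.quantize(w, 32, 4)

# Dequantize and check the worst-case rounding error in each group:
# every reconstructed value should lie within half a quantization step.
w_hat = mx.dequantize(w_q, scales, biases, 32, 4)
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
print((errors <= scales[..., None] / 2 + 1e-6).all())  # expected: True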
@@ -140,7 +140,6 @@ class TestLosses(mlx_tests.MLXTestCase):
                 probs, targets, with_logits=False, reduction="none"
             )
             expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])
-            print(losses_none, expected_none)
             self.assertTrue(mx.allclose(losses_none, expected_none))
 
             # Test with reduction 'mean'

@@ -10,18 +10,19 @@ import mlx_tests
 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
         w = mx.random.normal(shape=(128, 512))
-        for b in [2, 4, 8]:
-            w_q, scales, biases = mx.quantize(w, 64, b)
-            w_hat = mx.dequantize(w_q, scales, biases, 64, b)
-            errors = (w - w_hat).abs().reshape(*scales.shape, -1)
-            eps = 1e-6
-            self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
+        for gs in [32, 64, 128]:
+            for b in [2, 4, 8]:
+                w_q, scales, biases = mx.quantize(w, gs, b)
+                w_hat = mx.dequantize(w_q, scales, biases, gs, b)
+                errors = (w - w_hat).abs().reshape(*scales.shape, -1)
+                eps = 1e-6
+                self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
 
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [8, 32, 33, 64],  # M
             [512, 1024],  # N
@@ -75,7 +76,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
@@ -97,7 +98,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
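The qmm/qvm tests above exercise the quantized matmul path with the new group size. As a hedged sketch (the mx.quantized_matmul keyword names below are assumptions; the call itself is not visible in this diff), the smaller group size would be used like this:

import mlx.core as mx

# Hypothetical use of group_size=32 in a quantized matmul; the transpose,
# group_size, and bits keywords are assumed rather than taken from this diff.
x = mx.random.normal(shape=(8, 512))
w = mx.random.normal(shape=(1024, 512))
w_q, scales, biases = mx.quantize(w, 32, 4)
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True, group_size=32, bits=4)
print(y.shape)  # (8, 1024)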
Awni Hannun