Quantize with groups of 32 (#511)

* Allow quantize with a group size of 32

* Add missing CPU dispatch

* Remove stray print

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Author: Awni Hannun
Date:   2024-01-21 06:19:05 -08:00
Committed by: GitHub
Parent: 92c22c1ea3
Commit: 7a34e46677

6 changed files with 66 additions and 27 deletions
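
For context, a minimal sketch of what the change enables, assuming the mx.quantize / mx.dequantize signatures exercised in the tests below (the shape of w is illustrative):

import mlx.core as mx

w = mx.random.normal(shape=(128, 512))

# group_size=32 is the newly supported value; 64 and 128 worked before.
w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=32, bits=4)

# Round-to-nearest quantization keeps each element within half a step of
# its group's scale, which is exactly the bound the test below asserts.
max_err = (w - w_hat).abs().max()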

@@ -10,18 +10,19 @@ import mlx_tests
 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
         w = mx.random.normal(shape=(128, 512))
-        for b in [2, 4, 8]:
-            w_q, scales, biases = mx.quantize(w, 64, b)
-            w_hat = mx.dequantize(w_q, scales, biases, 64, b)
-            errors = (w - w_hat).abs().reshape(*scales.shape, -1)
-            eps = 1e-6
-            self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
+        for gs in [32, 64, 128]:
+            for b in [2, 4, 8]:
+                w_q, scales, biases = mx.quantize(w, gs, b)
+                w_hat = mx.dequantize(w_q, scales, biases, gs, b)
+                errors = (w - w_hat).abs().reshape(*scales.shape, -1)
+                eps = 1e-6
+                self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
 
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [8, 32, 33, 64],  # M
             [512, 1024],  # N
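
The qmm test drives the quantized matmul kernel over the widened grid; a hedged sketch of the call it exercises at the new group size (shapes are illustrative, and transpose=True follows the convention used elsewhere in this file of storing w as N x K):

import mlx.core as mx

x = mx.random.normal(shape=(8, 512))    # M x K activations
w = mx.random.normal(shape=(256, 512))  # N x K weights to be quantized
w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)

# With transpose=True this computes x @ w.T using the packed weights.
y = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=32, bits=4
)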
@@ -75,7 +76,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
@@ -97,7 +98,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         tests = product(
-            [128, 64],  # group_size
+            [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
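
The last two hunks widen the remaining quantized matmul tests the same way. A sketch of the vector-matrix (qvm) case that the commit message calls out as fixed for group_size 32, under the assumption that transpose=False with a single-row input routes to the qvm kernel:

import mlx.core as mx

x = mx.random.normal(shape=(1, 512))     # a single row vector
w = mx.random.normal(shape=(512, 1024))  # K x N, used untransposed
w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)

# transpose=False computes x @ w; with one row of x this is the
# vector-matrix (qvm) path that previously broke at group_size 32.
y = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=False, group_size=32, bits=4
)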