Fix quantization of all 0s (#1028)

This commit is contained in:
Angelos Katharopoulos
2024-04-24 00:40:42 -07:00
committed by GitHub
parent d0dbfe0b97
commit ec8578d41a
2 changed files with 12 additions and 1 deletions

View File

@@ -18,6 +18,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
eps = 1e-6
self.assertTrue((errors <= (scales[..., None] + eps)).all())
# test quantize/dequantize 0s
a = mx.zeros((256, 512))
for gs in [32, 64, 128]:
for b in [2, 4, 8]:
w_q, scales, biases = mx.quantize(a, gs, b)
a_hat = mx.dequantize(w_q, scales, biases, gs, b)
self.assertTrue(mx.all(a_hat == 0))
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)