Improvements in the quantizer and dequantization kernel (#1061)

This commit is contained in:
Angelos Katharopoulos
2024-05-01 18:19:11 -07:00
committed by GitHub
parent 7f7b9662ea
commit 17f57df797
3 changed files with 25 additions and 27 deletions

View File

@@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
eps = 1e-6
self.assertTrue((2 * errors <= (scales[..., None] + eps)).all())
self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())
# test quantize/dequantize 0s
a = mx.zeros((256, 512))