Bug fix in quantize (#1054)

This commit is contained in:
Angelos Katharopoulos
2024-04-29 20:55:04 -07:00
committed by GitHub
parent 09f1777896
commit 8db7161c94
2 changed files with 7 additions and 9 deletions

View File

@@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
eps = 1e-6
-self.assertTrue((errors <= (scales[..., None] + eps)).all())
+self.assertTrue((2 * errors <= (scales[..., None] + eps)).all())
# test quantize/dequantize 0s
a = mx.zeros((256, 512))