diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 5fcc882a5..049f92fdb 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -13,7 +13,8 @@ class TestQuantized(mlx_tests.MLXTestCase): w_q, scales, biases = mx.quantize(w, 64, b) w_hat = mx.dequantize(w_q, scales, biases, 64, b) errors = (w - w_hat).abs().reshape(*scales.shape, -1) - self.assertTrue((errors <= scales[..., None] / 2).all()) + eps = 1e-6 + self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all()) def test_qmm(self): key = mx.random.key(0)