diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index ed92831a7..20b445d1f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3268,6 +3268,9 @@ std::tuple<array, array, array> quantize(
   array scales = squeeze(delta, -1, s);
   array biases = squeeze(w_min, -1, s);
 
+  // making sure that 0 is represented exactly in the resulting quantization
+  biases = multiply(round(divide(biases, scales, s), s), scales, s);
+
   // Quantize and pack w
   packed_w = astype(
       round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 06c974652..25540d8d1 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -1631,12 +1631,19 @@ class TestLayers(mlx_tests.MLXTestCase):
         x = mx.array([2, 6, 9, 3, 0, 3])
         y = emb(x)
         yq = qemb(x)
-        self.assertLess((y - yq).abs().max(), 1e-3)
+        self.assertLess((y - yq).abs().max(), qemb.scales.max())
 
         x = mx.random.uniform(shape=(2, 256))
         y = emb.as_linear(x)
         yq = qemb.as_linear(x)
-        self.assertLess((y - yq).abs().max(), 1e-2)
+
+        def cosine(a, b):
+            ab = (a * b).sum(-1)
+            aa = mx.linalg.norm(a, axis=-1)
+            bb = mx.linalg.norm(b, axis=-1)
+            return ab / aa / bb
+
+        self.assertGreater(cosine(y, yq).min(), 0.99)
 
 
 if __name__ == "__main__":
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 57e369aa3..60e036d69 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_hat = mx.dequantize(w_q, scales, biases, gs, b)
         errors = (w - w_hat).abs().reshape(*scales.shape, -1)
         eps = 1e-6
-        self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
+        self.assertTrue((errors <= (scales[..., None] + eps)).all())
 
     def test_qmm(self):
         key = mx.random.key(0)
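
For readers of this diff, here is a minimal Python sketch (illustrative only, not the library implementation) of the affine quantization step the `ops.cpp` hunk changes. Snapping the bias to an integer multiple of the scale guarantees that some integer level `q` satisfies `q * scale + bias == 0`, so an exact zero in the weights survives a quantize/dequantize round trip. Because the levels are still computed against the unsnapped minimum, the snap can shift each reconstructed value by up to half a scale, which is why the tests above relax their error bounds from `scales / 2` to `scales`.

```python
import mlx.core as mx

def quantize_sketch(w, bits=4):
    # Per-row affine quantization mirroring the C++ hunk above; MLX
    # itself works per group of `group_size` elements. `quantize_sketch`
    # is a hypothetical helper, not an MLX function.
    n_levels = 2**bits - 1
    w_min = w.min(axis=-1, keepdims=True)
    w_max = w.max(axis=-1, keepdims=True)
    scales = (w_max - w_min) / n_levels
    # The new step from this diff: snap the bias to an integer multiple
    # of the scale so that 0 == q * scale + bias for some integer q.
    biases = mx.round(w_min / scales) * scales
    # Levels are still taken relative to the unsnapped minimum, so the
    # snap can move each dequantized value by up to scales / 2.
    q = mx.round((w - w_min) / scales)
    return q, scales, biases
```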
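
As a quick sanity check of the property this change establishes, using the public `mx.quantize` / `mx.dequantize` API (this assumes the patch above is applied; the zero-masking below is just one way to build a test input with exact zeros):

```python
import mlx.core as mx

w = mx.random.normal(shape=(4, 64))
w = w * (mx.abs(w) > 0.5)  # zero out small entries: exact zeros in w

w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)

# Positions that were exactly 0 should come back as exactly 0, since
# every bias is now an integer multiple of its scale.
zeros_preserved = mx.where(w == 0, w_hat == 0, True)
assert mx.all(zeros_preserved).item()
```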