Make sure 0 is represented in the quantization (#1016)

Angelos Katharopoulos 2024-04-19 19:47:26 -07:00 committed by GitHub
parent ed83908931
commit 84d61d27aa
3 changed files with 13 additions and 3 deletions

mlx/ops.cpp

@@ -3268,6 +3268,9 @@ std::tuple<array, array, array> quantize(
   array scales = squeeze(delta, -1, s);
   array biases = squeeze(w_min, -1, s);
 
+  // making sure that 0 is represented exactly in the resulting quantization
+  biases = multiply(round(divide(biases, scales, s), s), scales, s);
+
   // Quantize and pack w
   packed_w =
       astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
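
The added rounding snaps each bias to an integer multiple of its scale. Since dequantization reconstructs w_hat = scales * q + biases, this guarantees that some integer code q maps back to exactly 0 whenever 0 falls inside a group's value range. A minimal round-trip sketch of that property using mlx's Python API (the shapes, group size, and bit width are illustrative, not taken from this commit):

import mlx.core as mx

# Weights that contain exact zeros which should survive quantization.
w = mx.random.normal(shape=(64, 64))
w = mx.where(mx.abs(w) < 0.1, mx.zeros_like(w), w)

w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)

# With biases rounded to a multiple of scales, the code q = -biases / scales
# is an integer, so positions that were exactly 0 dequantize back to exactly 0.
err_at_zeros = mx.where(w == 0, mx.abs(w_hat), mx.zeros_like(w_hat))
print(err_at_zeros.max())  # expected to print 0.0 after this change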

python/tests/test_nn.py

@@ -1631,12 +1631,19 @@ class TestLayers(mlx_tests.MLXTestCase):
         x = mx.array([2, 6, 9, 3, 0, 3])
         y = emb(x)
         yq = qemb(x)
-        self.assertLess((y - yq).abs().max(), 1e-3)
+        self.assertLess((y - yq).abs().max(), qemb.scales.max())
 
         x = mx.random.uniform(shape=(2, 256))
         y = emb.as_linear(x)
         yq = qemb.as_linear(x)
-        self.assertLess((y - yq).abs().max(), 1e-2)
+
+        def cosine(a, b):
+            ab = (a * b).sum(-1)
+            aa = mx.linalg.norm(a, axis=-1)
+            bb = mx.linalg.norm(b, axis=-1)
+            return ab / aa / bb
+
+        self.assertGreater(cosine(y, yq).min(), 0.99)
 
 
 if __name__ == "__main__":
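
The linear-projection check moves from a fixed 1e-2 tolerance to a cosine-similarity bound because the per-weight error, now up to one full scale, accumulates over the 256-wide contraction in as_linear: the absolute output error grows with that dimension while the output direction barely changes. A small standalone illustration of why the cosine criterion is the more stable one (pure mlx.core; the noise magnitude is made up to stand in for accumulated quantization error):

import mlx.core as mx

def cosine(a, b):
    # Same helper as in the test: cosine similarity along the last axis.
    ab = (a * b).sum(-1)
    aa = mx.linalg.norm(a, axis=-1)
    bb = mx.linalg.norm(b, axis=-1)
    return ab / aa / bb

y = mx.random.normal(shape=(2, 1024))
noise = 0.01 * mx.random.normal(shape=(2, 1024))  # stand-in for matmul error
yq = y + noise

print((y - yq).abs().max())  # easily exceeds a fixed bound like 1e-2
print(cosine(y, yq).min())   # stays very close to 1.0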

python/tests/test_quantized.py

@@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 w_hat = mx.dequantize(w_q, scales, biases, gs, b)
                 errors = (w - w_hat).abs().reshape(*scales.shape, -1)
                 eps = 1e-6
-                self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
+                self.assertTrue((errors <= (scales[..., None] + eps)).all())
 
     def test_qmm(self):
         key = mx.random.key(0)
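
The tolerance here doubles for the same reason: rounding the codes contributes at most scales / 2 of error per element, and snapping the bias to the nearest multiple of the scale can shift the whole group by up to another scales / 2, so the worst-case per-element reconstruction error is now one full scale. A sketch of that bound outside the test harness (mlx Python API; the shape, group size, and bit width are illustrative):

import mlx.core as mx

w = mx.random.normal(shape=(128, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)

# Per-group, per-element error should stay within one scale:
# scales / 2 from rounding the codes plus scales / 2 from rounding the bias.
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
print((errors <= scales[..., None] + 1e-6).all())  # expected: True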