mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Make sure 0 is represented in the quantization (#1016)
This commit is contained in:
parent
ed83908931
commit
84d61d27aa
@ -3268,6 +3268,9 @@ std::tuple<array, array, array> quantize(
|
||||
array scales = squeeze(delta, -1, s);
|
||||
array biases = squeeze(w_min, -1, s);
|
||||
|
||||
// making sure that 0 is represented exactly in the resulting quantization
|
||||
biases = multiply(round(divide(biases, scales, s), s), scales, s);
|
||||
|
||||
// Quantize and pack w
|
||||
packed_w =
|
||||
astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
|
||||
|
@ -1631,12 +1631,19 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
x = mx.array([2, 6, 9, 3, 0, 3])
|
||||
y = emb(x)
|
||||
yq = qemb(x)
|
||||
self.assertLess((y - yq).abs().max(), 1e-3)
|
||||
self.assertLess((y - yq).abs().max(), qemb.scales.max())
|
||||
|
||||
x = mx.random.uniform(shape=(2, 256))
|
||||
y = emb.as_linear(x)
|
||||
yq = qemb.as_linear(x)
|
||||
self.assertLess((y - yq).abs().max(), 1e-2)
|
||||
|
||||
def cosine(a, b):
|
||||
ab = (a * b).sum(-1)
|
||||
aa = mx.linalg.norm(a, axis=-1)
|
||||
bb = mx.linalg.norm(b, axis=-1)
|
||||
return ab / aa / bb
|
||||
|
||||
self.assertGreater(cosine(y, yq).min(), 0.99)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
|
||||
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
|
||||
eps = 1e-6
|
||||
self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
|
||||
self.assertTrue((errors <= (scales[..., None] + eps)).all())
|
||||
|
||||
def test_qmm(self):
|
||||
key = mx.random.key(0)
|
||||
|
Loading…
Reference in New Issue
Block a user