mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Make sure 0 is represented in the quantization (#1016)
This commit is contained in:
parent
ed83908931
commit
84d61d27aa
@ -3268,6 +3268,9 @@ std::tuple<array, array, array> quantize(
|
|||||||
array scales = squeeze(delta, -1, s);
|
array scales = squeeze(delta, -1, s);
|
||||||
array biases = squeeze(w_min, -1, s);
|
array biases = squeeze(w_min, -1, s);
|
||||||
|
|
||||||
|
// making sure that 0 is represented exactly in the resulting quantization
|
||||||
|
biases = multiply(round(divide(biases, scales, s), s), scales, s);
|
||||||
|
|
||||||
// Quantize and pack w
|
// Quantize and pack w
|
||||||
packed_w =
|
packed_w =
|
||||||
astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
|
astype(round(divide(subtract(packed_w, w_min, s), delta, s), s), uint32);
|
||||||
|
@ -1631,12 +1631,19 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
x = mx.array([2, 6, 9, 3, 0, 3])
|
x = mx.array([2, 6, 9, 3, 0, 3])
|
||||||
y = emb(x)
|
y = emb(x)
|
||||||
yq = qemb(x)
|
yq = qemb(x)
|
||||||
self.assertLess((y - yq).abs().max(), 1e-3)
|
self.assertLess((y - yq).abs().max(), qemb.scales.max())
|
||||||
|
|
||||||
x = mx.random.uniform(shape=(2, 256))
|
x = mx.random.uniform(shape=(2, 256))
|
||||||
y = emb.as_linear(x)
|
y = emb.as_linear(x)
|
||||||
yq = qemb.as_linear(x)
|
yq = qemb.as_linear(x)
|
||||||
self.assertLess((y - yq).abs().max(), 1e-2)
|
|
||||||
|
def cosine(a, b):
|
||||||
|
ab = (a * b).sum(-1)
|
||||||
|
aa = mx.linalg.norm(a, axis=-1)
|
||||||
|
bb = mx.linalg.norm(b, axis=-1)
|
||||||
|
return ab / aa / bb
|
||||||
|
|
||||||
|
self.assertGreater(cosine(y, yq).min(), 0.99)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
|
w_hat = mx.dequantize(w_q, scales, biases, gs, b)
|
||||||
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
|
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all())
|
self.assertTrue((errors <= (scales[..., None] + eps)).all())
|
||||||
|
|
||||||
def test_qmm(self):
|
def test_qmm(self):
|
||||||
key = mx.random.key(0)
|
key = mx.random.key(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user