Make sure 0 is represented in the quantization (#1016)

Angelos Katharopoulos
2024-04-19 19:47:26 -07:00
committed by GitHub
parent ed83908931
commit 84d61d27aa
3 changed files with 13 additions and 3 deletions
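The change makes sure a weight value of exactly 0.0 survives the quantize/dequantize round trip. A standard way to guarantee this in affine quantization is to widen each group's range to include zero and snap the bias to an integer multiple of the scale; whether the kernel does exactly this is an assumption here, so the following is a minimal Python sketch with hypothetical helpers (affine_params, quantize_dequantize), not MLX's implementation:

import mlx.core as mx

def affine_params(w, bits=4):
    # Per-row affine quantization parameters. A sketch of one way to make
    # 0.0 exactly representable; NOT the actual MLX kernel code.
    n_levels = 2**bits - 1
    w_min = mx.minimum(w.min(axis=-1), 0.0)  # widen the range to include 0
    w_max = mx.maximum(w.max(axis=-1), 0.0)
    scale = mx.maximum((w_max - w_min) / n_levels, 1e-8)  # guard all-zero rows
    # Snap the bias to an integer multiple of the scale so that
    # q0 = round((0 - bias) / scale) is an integer level and
    # q0 * scale + bias reconstructs to exactly 0.0.
    bias = mx.round(w_min / scale) * scale
    return scale, bias

def quantize_dequantize(w, bits=4):
    scale, bias = affine_params(w, bits)
    scale, bias = scale[..., None], bias[..., None]
    q = mx.clip(mx.round((w - bias) / scale), 0, 2**bits - 1)
    return q * scale + bias

w = mx.array([[0.0, 0.3, -1.2, 0.7]])
print(quantize_dequantize(w))  # the 0.0 entry round-trips to exactly 0.0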

python/tests/test_nn.py

@@ -1631,12 +1631,19 @@ class TestLayers(mlx_tests.MLXTestCase):
         x = mx.array([2, 6, 9, 3, 0, 3])
         y = emb(x)
         yq = qemb(x)
-        self.assertLess((y - yq).abs().max(), 1e-3)
+        self.assertLess((y - yq).abs().max(), qemb.scales.max())
 
         x = mx.random.uniform(shape=(2, 256))
         y = emb.as_linear(x)
         yq = qemb.as_linear(x)
-        self.assertLess((y - yq).abs().max(), 1e-2)
+
+        def cosine(a, b):
+            ab = (a * b).sum(-1)
+            aa = mx.linalg.norm(a, axis=-1)
+            bb = mx.linalg.norm(b, axis=-1)
+            return ab / aa / bb
+
+        self.assertGreater(cosine(y, yq).min(), 0.99)
 
 
 if __name__ == "__main__":
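Why the tolerances changed: round-to-nearest affine quantization is off by at most half a quantization step per element, so a fixed absolute bound like 1e-3 only holds for one particular weight scale, while qemb.scales.max() bounds the lookup error for any weights. On the as_linear path the per-element errors accumulate over 256 input dimensions, so a scale-invariant cosine-similarity check replaces the fixed 1e-2 bound. A standalone mirror of the lookup assertion, under the assumption that nn.QuantizedEmbedding exposes a from_embedding classmethod analogous to nn.QuantizedLinear.from_linear (names outside the diff are assumptions):

import mlx.core as mx
import mlx.nn as nn

# Assumed API: nn.QuantizedEmbedding.from_embedding, mirroring
# nn.QuantizedLinear.from_linear; the construction is not shown in this diff.
emb = nn.Embedding(64, 256)
qemb = nn.QuantizedEmbedding.from_embedding(emb)

x = mx.array([2, 6, 9, 3, 0, 3])
# Each element errs by at most scale/2 for its group, so the largest
# per-group scale is a safe upper bound on the embedding-lookup error.
assert ((emb(x) - qemb(x)).abs().max() < qemb.scales.max()).item()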