mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Make sure 0 is represented in the quantization (#1016)
This commit is contained in:

committed by
GitHub

parent
ed83908931
commit
84d61d27aa
@@ -1631,12 +1631,19 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
x = mx.array([2, 6, 9, 3, 0, 3])
|
||||
y = emb(x)
|
||||
yq = qemb(x)
|
||||
self.assertLess((y - yq).abs().max(), 1e-3)
|
||||
self.assertLess((y - yq).abs().max(), qemb.scales.max())
|
||||
|
||||
x = mx.random.uniform(shape=(2, 256))
|
||||
y = emb.as_linear(x)
|
||||
yq = qemb.as_linear(x)
|
||||
self.assertLess((y - yq).abs().max(), 1e-2)
|
||||
|
||||
def cosine(a, b):
|
||||
ab = (a * b).sum(-1)
|
||||
aa = mx.linalg.norm(a, axis=-1)
|
||||
bb = mx.linalg.norm(b, axis=-1)
|
||||
return ab / aa / bb
|
||||
|
||||
self.assertGreater(cosine(y, yq).min(), 0.99)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user