No reshapes in quantized embedding (#1682)

* no reshapes in quantized embedding

* fix inadvertant cast

* add tol
This commit is contained in:
Awni Hannun
2024-12-09 18:57:38 -08:00
committed by GitHub
parent 87d7a2520e
commit 29a620cab2
6 changed files with 26 additions and 12 deletions

View File

@@ -353,7 +353,7 @@ class TestVmap(mlx_tests.MLXTestCase):
for i in range(a.shape[0]):
self.assertTrue(
mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=0, atol=1e-5)
mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=1e-4, atol=1e-5)
)
a = mx.random.uniform(shape=(4, 3, 4))
@@ -367,7 +367,9 @@ class TestVmap(mlx_tests.MLXTestCase):
for i in range(a.shape[1]):
self.assertTrue(
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
mx.allclose(
a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=1e-4, atol=1e-5
)
)
def test_vmap_gather(self):