mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding * fix inadvertant cast * add tol
This commit is contained in:
@@ -353,7 +353,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
for i in range(a.shape[0]):
|
||||
self.assertTrue(
|
||||
mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=0, atol=1e-5)
|
||||
mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=1e-4, atol=1e-5)
|
||||
)
|
||||
|
||||
a = mx.random.uniform(shape=(4, 3, 4))
|
||||
@@ -367,7 +367,9 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
for i in range(a.shape[1]):
|
||||
self.assertTrue(
|
||||
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
|
||||
mx.allclose(
|
||||
a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=1e-4, atol=1e-5
|
||||
)
|
||||
)
|
||||
|
||||
def test_vmap_gather(self):
|
||||
|
Reference in New Issue
Block a user