No reshapes in quantized embedding (#1682)

* no reshapes in quantized embedding

* fix inadvertant cast

* add tol
This commit is contained in:
Awni Hannun
2024-12-09 18:57:38 -08:00
committed by GitHub
parent 87d7a2520e
commit 29a620cab2
6 changed files with 26 additions and 12 deletions

View File

@@ -348,6 +348,10 @@ class TestRandom(mlx_tests.MLXTestCase):
x = mx.random.permutation(16384)
self.assertFalse(mx.array_equal(sorted_x, x))
# Preserves shape / doesn't cast input to int
x = mx.random.permutation(mx.array([[1]]))
self.assertEqual(x.shape, (1, 1))
if __name__ == "__main__":
unittest.main()