mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:29:35 +08:00
No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding * fix inadvertant cast * add tol
This commit is contained in:
@@ -348,6 +348,10 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
x = mx.random.permutation(16384)
|
||||
self.assertFalse(mx.array_equal(sorted_x, x))
|
||||
|
||||
# Preserves shape / doesn't cast input to int
|
||||
x = mx.random.permutation(mx.array([[1]]))
|
||||
self.assertEqual(x.shape, (1, 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user