mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 08:24:39 +08:00
@@ -325,6 +325,29 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
|
||||
|
||||
def test_permutation(self):
|
||||
x = sorted(mx.random.permutation(4).tolist())
|
||||
self.assertEqual([0, 1, 2, 3], x)
|
||||
|
||||
x = mx.array([0, 1, 2, 3])
|
||||
x = sorted(mx.random.permutation(x).tolist())
|
||||
self.assertEqual([0, 1, 2, 3], x)
|
||||
|
||||
x = mx.array([0, 1, 2, 3])
|
||||
x = sorted(mx.random.permutation(x).tolist())
|
||||
|
||||
# 2-D
|
||||
x = mx.arange(16).reshape(4, 4)
|
||||
out = mx.sort(mx.random.permutation(x, axis=0), axis=0)
|
||||
self.assertTrue(mx.array_equal(x, out))
|
||||
out = mx.sort(mx.random.permutation(x, axis=1), axis=1)
|
||||
self.assertTrue(mx.array_equal(x, out))
|
||||
|
||||
# Basically 0 probability this should fail.
|
||||
sorted_x = mx.arange(16384)
|
||||
x = mx.random.permutation(16384)
|
||||
self.assertFalse(mx.array_equal(sorted_x, x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user