Add mx.random.permutation (#1471)

* random permutation

* comment
This commit is contained in:
Awni Hannun
2024-10-08 19:42:19 -07:00
committed by GitHub
parent 1fa0d20a30
commit e1c9600da3
5 changed files with 85 additions and 0 deletions

View File

@@ -325,6 +325,29 @@ class TestRandom(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
def test_permutation(self):
x = sorted(mx.random.permutation(4).tolist())
self.assertEqual([0, 1, 2, 3], x)
x = mx.array([0, 1, 2, 3])
x = sorted(mx.random.permutation(x).tolist())
self.assertEqual([0, 1, 2, 3], x)
x = mx.array([0, 1, 2, 3])
x = sorted(mx.random.permutation(x).tolist())
# 2-D
x = mx.arange(16).reshape(4, 4)
out = mx.sort(mx.random.permutation(x, axis=0), axis=0)
self.assertTrue(mx.array_equal(x, out))
out = mx.sort(mx.random.permutation(x, axis=1), axis=1)
self.assertTrue(mx.array_equal(x, out))
# Basically 0 probability this should fail.
sorted_x = mx.arange(16384)
x = mx.random.permutation(16384)
self.assertFalse(mx.array_equal(sorted_x, x))
if __name__ == "__main__":
unittest.main()