From e1c9600da37aa1a7ac0b5ebaa37a04534b1e27e2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 8 Oct 2024 19:42:19 -0700 Subject: [PATCH] Add `mx.random.permutation` (#1471) * random permutation * comment --- docs/src/python/random.rst | 1 + mlx/random.cpp | 15 +++++++++++++++ mlx/random.h | 13 +++++++++++++ python/src/random.cpp | 33 +++++++++++++++++++++++++++++++++ python/tests/test_random.py | 23 +++++++++++++++++++++++ 5 files changed, 85 insertions(+) diff --git a/docs/src/python/random.rst b/docs/src/python/random.rst index 5d98304bb..248959108 100644 --- a/docs/src/python/random.rst +++ b/docs/src/python/random.rst @@ -45,3 +45,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG. truncated_normal uniform laplace + permutation diff --git a/mlx/random.cpp b/mlx/random.cpp index 8dae8964b..6368bdeed 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -458,4 +458,19 @@ array laplace( return samples; } +array permutation( + const array& x, + int axis /* = 0 */, + const std::optional& key /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + return take(x, permutation(x.shape(axis), key, s), axis, s); +} + +array permutation( + int x, + const std::optional& key /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + return argsort(bits({x}, key, s), s); +} + } // namespace mlx::core::random diff --git a/mlx/random.h b/mlx/random.h index ad030c7e3..d4d827230 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -254,4 +254,17 @@ inline array laplace( return laplace(shape, float32, 0.0, 1.0, key, s); } +/* Randomly permute the elements of x along the given axis. */ +array permutation( + const array& x, + int axis = 0, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +/* A random permutation of `arange(x)` */ +array permutation( + int x, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + } // namespace mlx::core::random diff --git a/python/src/random.cpp b/python/src/random.cpp index 13055b1fc..af95d4e6a 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -454,6 +454,39 @@ void init_random(nb::module_& parent_module) { Returns: array: The output array of random values. )pbdoc"); + m.def( + "permuation", + [](const std::variant& x, + int axis, + const std::optional& key_, + StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); + if (auto pv = std::get_if(&x); pv) { + return permutation(*pv, key, s); + } else { + return permutation(std::get(x), axis, key, s); + } + }, + "shape"_a = std::vector{}, + "axis"_a = 0, + "key"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def permutation(x: Union[int, array], axis: int = 0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Generate a random permutation or permute the entries of an array. + + Args: + x (int or array, optional): If an integer is provided a random + permtuation of ``mx.arange(x)`` is returned. Otherwise the entries + of ``x`` along the given axis are randomly permuted. + axis (int, optional): The axis to permute along. Default: ``0``. + key (array, optional): A PRNG key. Default: ``None``. + + Returns: + array: + The generated random permutation or randomly permuted input array. + )pbdoc"); // Register static Python object cleanup before the interpreter exits auto atexit = nb::module_::import_("atexit"); atexit.attr("register")(nb::cpp_function([]() { default_key().release(); })); diff --git a/python/tests/test_random.py b/python/tests/test_random.py index b6f632491..3491297ed 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -325,6 +325,29 @@ class TestRandom(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.random.categorical(logits, shape=[10, 5], num_samples=5) + def test_permutation(self): + x = sorted(mx.random.permutation(4).tolist()) + self.assertEqual([0, 1, 2, 3], x) + + x = mx.array([0, 1, 2, 3]) + x = sorted(mx.random.permutation(x).tolist()) + self.assertEqual([0, 1, 2, 3], x) + + x = mx.array([0, 1, 2, 3]) + x = sorted(mx.random.permutation(x).tolist()) + + # 2-D + x = mx.arange(16).reshape(4, 4) + out = mx.sort(mx.random.permutation(x, axis=0), axis=0) + self.assertTrue(mx.array_equal(x, out)) + out = mx.sort(mx.random.permutation(x, axis=1), axis=1) + self.assertTrue(mx.array_equal(x, out)) + + # Basically 0 probability this should fail. + sorted_x = mx.arange(16384) + x = mx.random.permutation(16384) + self.assertFalse(mx.array_equal(sorted_x, x)) + if __name__ == "__main__": unittest.main()