Add mx.random.permutation (#1471)

* random permutation

* comment
This commit is contained in:
Awni Hannun 2024-10-08 19:42:19 -07:00 committed by GitHub
parent 1fa0d20a30
commit e1c9600da3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 85 additions and 0 deletions

View File

@ -45,3 +45,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
truncated_normal truncated_normal
uniform uniform
laplace laplace
permutation

View File

@ -458,4 +458,19 @@ array laplace(
return samples; return samples;
} }
array permutation(
const array& x,
int axis /* = 0 */,
const std::optional<array>& key /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
return take(x, permutation(x.shape(axis), key, s), axis, s);
}
array permutation(
int x,
const std::optional<array>& key /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
return argsort(bits({x}, key, s), s);
}
} // namespace mlx::core::random } // namespace mlx::core::random

View File

@ -254,4 +254,17 @@ inline array laplace(
return laplace(shape, float32, 0.0, 1.0, key, s); return laplace(shape, float32, 0.0, 1.0, key, s);
} }
/* Randomly permute the elements of x along the given axis. */
array permutation(
const array& x,
int axis = 0,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
/* A random permutation of `arange(x)` */
array permutation(
int x,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::random } // namespace mlx::core::random

View File

@ -454,6 +454,39 @@ void init_random(nb::module_& parent_module) {
Returns: Returns:
array: The output array of random values. array: The output array of random values.
)pbdoc"); )pbdoc");
m.def(
"permuation",
[](const std::variant<int, array>& x,
int axis,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (auto pv = std::get_if<int>(&x); pv) {
return permutation(*pv, key, s);
} else {
return permutation(std::get<array>(x), axis, key, s);
}
},
"shape"_a = std::vector<int>{},
"axis"_a = 0,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def permutation(x: Union[int, array], axis: int = 0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate a random permutation or permute the entries of an array.
Args:
x (int or array, optional): If an integer is provided a random
permtuation of ``mx.arange(x)`` is returned. Otherwise the entries
of ``x`` along the given axis are randomly permuted.
axis (int, optional): The axis to permute along. Default: ``0``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array:
The generated random permutation or randomly permuted input array.
)pbdoc");
// Register static Python object cleanup before the interpreter exits // Register static Python object cleanup before the interpreter exits
auto atexit = nb::module_::import_("atexit"); auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); })); atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));

View File

@ -325,6 +325,29 @@ class TestRandom(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.random.categorical(logits, shape=[10, 5], num_samples=5) mx.random.categorical(logits, shape=[10, 5], num_samples=5)
def test_permutation(self):
x = sorted(mx.random.permutation(4).tolist())
self.assertEqual([0, 1, 2, 3], x)
x = mx.array([0, 1, 2, 3])
x = sorted(mx.random.permutation(x).tolist())
self.assertEqual([0, 1, 2, 3], x)
x = mx.array([0, 1, 2, 3])
x = sorted(mx.random.permutation(x).tolist())
# 2-D
x = mx.arange(16).reshape(4, 4)
out = mx.sort(mx.random.permutation(x, axis=0), axis=0)
self.assertTrue(mx.array_equal(x, out))
out = mx.sort(mx.random.permutation(x, axis=1), axis=1)
self.assertTrue(mx.array_equal(x, out))
# Basically 0 probability this should fail.
sorted_x = mx.arange(16384)
x = mx.random.permutation(16384)
self.assertFalse(mx.array_equal(sorted_x, x))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()