Add mx.random.permutation (#1471)

* random permutation

* comment
This commit is contained in:
Awni Hannun
2024-10-08 19:42:19 -07:00
committed by GitHub
parent 1fa0d20a30
commit e1c9600da3
5 changed files with 85 additions and 0 deletions

View File

@@ -454,6 +454,39 @@ void init_random(nb::module_& parent_module) {
Returns:
array: The output array of random values.
)pbdoc");
m.def(
"permuation",
[](const std::variant<int, array>& x,
int axis,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (auto pv = std::get_if<int>(&x); pv) {
return permutation(*pv, key, s);
} else {
return permutation(std::get<array>(x), axis, key, s);
}
},
"shape"_a = std::vector<int>{},
"axis"_a = 0,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def permutation(x: Union[int, array], axis: int = 0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate a random permutation or permute the entries of an array.
Args:
x (int or array, optional): If an integer is provided a random
permtuation of ``mx.arange(x)`` is returned. Otherwise the entries
of ``x`` along the given axis are randomly permuted.
axis (int, optional): The axis to permute along. Default: ``0``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array:
The generated random permutation or randomly permuted input array.
)pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));