mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
@@ -454,6 +454,39 @@ void init_random(nb::module_& parent_module) {
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"permuation",
|
||||
[](const std::variant<int, array>& x,
|
||||
int axis,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (auto pv = std::get_if<int>(&x); pv) {
|
||||
return permutation(*pv, key, s);
|
||||
} else {
|
||||
return permutation(std::get<array>(x), axis, key, s);
|
||||
}
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"axis"_a = 0,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def permutation(x: Union[int, array], axis: int = 0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate a random permutation or permute the entries of an array.
|
||||
|
||||
Args:
|
||||
x (int or array, optional): If an integer is provided a random
|
||||
permtuation of ``mx.arange(x)`` is returned. Otherwise the entries
|
||||
of ``x`` along the given axis are randomly permuted.
|
||||
axis (int, optional): The axis to permute along. Default: ``0``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array:
|
||||
The generated random permutation or randomly permuted input array.
|
||||
)pbdoc");
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = nb::module_::import_("atexit");
|
||||
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));
|
||||
|
Reference in New Issue
Block a user