Add mx.random.permutation (#1471)

* random permutation

* comment
This commit is contained in:
Awni Hannun
2024-10-08 19:42:19 -07:00
committed by GitHub
parent 1fa0d20a30
commit e1c9600da3
5 changed files with 85 additions and 0 deletions

View File

@@ -254,4 +254,17 @@ inline array laplace(
return laplace(shape, float32, 0.0, 1.0, key, s);
}
/* Randomly permute the elements of x along the given axis. */
array permutation(
const array& x,
int axis = 0,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
/* A random permutation of `arange(x)` */
array permutation(
int x,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::random