diff --git a/mlx/random.h b/mlx/random.h new file mode 100644 index 000000000..25dc9f924 --- /dev/null +++ b/mlx/random.h @@ -0,0 +1,191 @@ +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/stream.h" + +namespace mlx::core::random { + +class KeySequence { + public: + explicit KeySequence(uint64_t seed); + + void seed(uint64_t seed); + array next(); + + // static defualt + static KeySequence& default_() { + static KeySequence ks(0); + return ks; + } + + private: + array key_; +}; + +/** Get a PRNG key from a seed. */ +array key(uint64_t seed); + +/** Seed the default PRNG key. */ +void seed(uint64_t seed); + +/** Generate an array with type uint32 filled with random bits. */ +array bits( + const std::vector& shape, + int width, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array bits( + const std::vector& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return bits(shape, 4, key, s); +} + +/** Split the rng key into a pair of keys. */ +std::pair split(const array& key, StreamOrDevice s = {}); + +/** Split the rng key into `num` keys. */ +array split(const array& key, int num, StreamOrDevice s = {}); + +/** Generate uniform random numbers between low and high. */ +array uniform( + const array& low, + const array& high, + const std::vector& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +template +array uniform( + T low, + U high, + const std::vector& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return uniform(array(low), array(high), shape, dtype, key, to_stream(s)); +} + +/** Generate uniform random numbers between 0 and 1. */ +array uniform( + const std::vector& shape, + Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array uniform( + const std::vector& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return uniform(shape, float32, key); +} + +/** Generate samples from the standard normal distribution. */ +array normal( + const std::vector& shape, + Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array normal( + const std::vector& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, float32, key, s); +} + +/** Generate integer samples uniformly at random */ +array randint( + const array& low, + const array& high, + const std::vector& shape, + Dtype dtype = int32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +template +array randint( + T low, + U high, + const std::vector& shape, + Dtype dtype = int32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return randint(array(low), array(high), shape, dtype, key, to_stream(s)); +}; + +/** Generate binary variables with probability to be true equal to p */ +array bernoulli( + const array& p, + const std::vector& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +array bernoulli( + const array& p, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +template +array bernoulli( + T p, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return bernoulli(array(p), key, s); +}; + +template +array bernoulli( + T p, + const std::vector& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return bernoulli(array(p), shape, key, s); +}; + +array bernoulli( + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array truncated_normal( + const array& lower, + const array& upper, + const std::vector& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array truncated_normal( + const array& lower, + const array& upper, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array gumbel( + const std::vector& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array categorical( + const array& logits, + int axis, + const std::vector& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array categorical( + const array& logits_, + int axis, + int num_samples, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array categorical( + const array& logits, + int axis = -1, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +} // namespace mlx::core::random