29 static uint64_t get_current_time_seed() {
30 auto now = std::chrono::system_clock::now();
31 return std::chrono::duration_cast<std::chrono::milliseconds>(
32 now.time_since_epoch())
47 const std::optional<array>&
key = std::nullopt,
51 const std::optional<array>&
key = std::nullopt,
68 const std::optional<array>&
key = std::nullopt,
71template <
typename T,
typename U>
77 const std::optional<array>&
key = std::nullopt,
86 const std::optional<array>&
key = std::nullopt,
90 const std::optional<array>&
key = std::nullopt,
101 const std::optional<array>&
key = std::nullopt,
107 const std::optional<array>&
key = std::nullopt,
114 const std::optional<array>&
key = std::nullopt,
116 return normal(shape, dtype, 0.0, 1.0,
key, s);
120 const std::optional<array>&
key = std::nullopt,
131 const std::optional<array>&
key = std::nullopt,
140 const std::optional<array>&
key = std::nullopt,
143template <
typename T,
typename U>
149 const std::optional<array>&
key = std::nullopt,
158 const std::optional<array>&
key = std::nullopt,
162 const std::optional<array>&
key = std::nullopt,
168 const std::optional<array>&
key = std::nullopt,
177 const std::optional<array>&
key = std::nullopt,
183 const std::optional<array>&
key = std::nullopt,
191 const std::optional<array>&
key = std::nullopt,
198 const std::optional<array>&
key = std::nullopt,
204 const std::optional<array>&
key = std::nullopt,
211 const std::optional<array>&
key = std::nullopt,
215 const array& logits_,
218 const std::optional<array>&
key = std::nullopt,
224 const std::optional<array>&
key = std::nullopt,
233 const std::optional<array>&
key = std::nullopt,
239 const std::optional<array>&
key = std::nullopt,
246 const std::optional<array>&
key = std::nullopt,
248 return laplace(shape, dtype, 0.0, 1.0,
key, s);
252 const std::optional<array>&
key = std::nullopt,
261 const std::optional<array>&
key = std::nullopt,
267 const std::optional<array>&
key = std::nullopt,
KeySequence(uint64_t seed)
static KeySequence & default_()
Definition random.h:22
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
array bernoulli(const array &p, const Shape &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate Bernoulli (binary) random variables, each true with probability p.
array categorical(const array &logits, int axis, const Shape &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
std::pair< array, array > split(const array &key, StreamOrDevice s={})
Split the rng key into a pair of keys.
array normal(const Shape &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from a normal distribution with mean loc and standard deviation scale (loc = 0, scale = 1 gives the standard normal).
array gumbel(const Shape &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array laplace(const Shape &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from the Laplace distribution with location loc and scale scale.
array uniform(const array &low, const array &high, const Shape &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate uniform random numbers between low and high.
void seed(uint64_t seed)
Seed the default PRNG key.
array key(uint64_t seed)
Get a PRNG key from a seed.
array bits(const Shape &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
array randint(const array &low, const array &high, const Shape &shape, Dtype dtype=int32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate integer samples uniformly at random.
array permutation(const array &x, int axis=0, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array multivariate_normal(const array &mean, const array &cov, const Shape &shape, Dtype dtype, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from a multivariate normal distribution.
array truncated_normal(const array &lower, const array &upper, const Shape &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Stream to_stream(StreamOrDevice s)
constexpr Dtype int32
Definition dtype.h:77
constexpr Dtype float32
Definition dtype.h:81
std::vector< ShapeElem > Shape
Definition array.h:21
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15