28 static uint64_t get_current_time_seed() {
29 auto now = std::chrono::system_clock::now();
30 return std::chrono::duration_cast<std::chrono::milliseconds>(
31 now.time_since_epoch())
44 const std::vector<int>& shape,
46 const std::optional<array>&
key = std::nullopt,
49 const std::vector<int>& shape,
50 const std::optional<array>&
key = std::nullopt,
65 const std::vector<int>& shape,
67 const std::optional<array>&
key = std::nullopt,
70template <
typename T,
typename U>
74 const std::vector<int>& shape,
76 const std::optional<array>&
key = std::nullopt,
83 const std::vector<int>& shape,
85 const std::optional<array>&
key = std::nullopt,
88 const std::vector<int>& shape,
89 const std::optional<array>&
key = std::nullopt,
96 const std::vector<int>& shape,
100 const std::optional<array>&
key = std::nullopt,
103 const std::vector<int>& shape,
106 const std::optional<array>&
key = std::nullopt,
111 const std::vector<int>& shape,
113 const std::optional<array>&
key = std::nullopt,
115 return normal(shape, dtype, 0.0, 1.0,
key, s);
118 const std::vector<int>& shape,
119 const std::optional<array>&
key = std::nullopt,
128 const std::vector<int>& shape,
130 const std::optional<array>&
key = std::nullopt,
137 const std::vector<int>& shape,
139 const std::optional<array>&
key = std::nullopt,
142template <
typename T,
typename U>
146 const std::vector<int>& shape,
148 const std::optional<array>&
key = std::nullopt,
156 const std::vector<int>& shape,
157 const std::optional<array>&
key = std::nullopt,
161 const std::optional<array>&
key = std::nullopt,
167 const std::optional<array>&
key = std::nullopt,
175 const std::vector<int>& shape,
176 const std::optional<array>&
key = std::nullopt,
182 const std::optional<array>&
key = std::nullopt,
188 const std::vector<int>& shape,
190 const std::optional<array>&
key = std::nullopt,
197 const std::optional<array>&
key = std::nullopt,
201 const std::vector<int>& shape,
203 const std::optional<array>&
key = std::nullopt,
209 const std::vector<int>& shape,
210 const std::optional<array>&
key = std::nullopt,
214 const array& logits_,
217 const std::optional<array>&
key = std::nullopt,
223 const std::optional<array>&
key = std::nullopt,
KeySequence(uint64_t seed)
static KeySequence & default_()
Definition random.h:21
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
array truncated_normal(const array &lower, const array &upper, const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array categorical(const array &logits, int axis, const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
std::pair< array, array > split(const array &key, StreamOrDevice s={})
Split the rng key into a pair of keys.
array randint(const array &low, const array &high, const std::vector< int > &shape, Dtype dtype=int32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate integer samples uniformly at random.
array multivariate_normal(const array &mean, const array &cov, const std::vector< int > &shape, Dtype dtype, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from a multivariate normal distribution.
array normal(const std::vector< int > &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from a normal distribution with mean loc and standard deviation scale (loc = 0, scale = 1 gives the standard normal distribution).
array gumbel(const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
void seed(uint64_t seed)
Seed the default PRNG key.
array bernoulli(const array &p, const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate binary variables that are true with probability p.
array key(uint64_t seed)
Get a PRNG key from a seed.
array uniform(const array &low, const array &high, const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate uniform random numbers between low and high.
Stream to_stream(StreamOrDevice s)
constexpr Dtype int32
Definition dtype.h:69
constexpr Dtype float32
Definition dtype.h:73
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14