29 static uint64_t get_current_time_seed() {
30 auto now = std::chrono::system_clock::now();
31 return std::chrono::duration_cast<std::chrono::milliseconds>(
32 now.time_since_epoch())
45 const std::vector<int>& shape,
47 const std::optional<array>&
key = std::nullopt,
50 const std::vector<int>& shape,
51 const std::optional<array>&
key = std::nullopt,
66 const std::vector<int>& shape,
68 const std::optional<array>&
key = std::nullopt,
71template <
typename T,
typename U>
75 const std::vector<int>& shape,
77 const std::optional<array>&
key = std::nullopt,
84 const std::vector<int>& shape,
86 const std::optional<array>&
key = std::nullopt,
89 const std::vector<int>& shape,
90 const std::optional<array>&
key = std::nullopt,
97 const std::vector<int>& shape,
101 const std::optional<array>&
key = std::nullopt,
104 const std::vector<int>& shape,
107 const std::optional<array>&
key = std::nullopt,
112 const std::vector<int>& shape,
114 const std::optional<array>&
key = std::nullopt,
116 return normal(shape, dtype, 0.0, 1.0,
key, s);
119 const std::vector<int>& shape,
120 const std::optional<array>&
key = std::nullopt,
129 const std::vector<int>& shape,
131 const std::optional<array>&
key = std::nullopt,
138 const std::vector<int>& shape,
140 const std::optional<array>&
key = std::nullopt,
143template <
typename T,
typename U>
147 const std::vector<int>& shape,
149 const std::optional<array>&
key = std::nullopt,
157 const std::vector<int>& shape,
158 const std::optional<array>&
key = std::nullopt,
162 const std::optional<array>&
key = std::nullopt,
168 const std::optional<array>&
key = std::nullopt,
176 const std::vector<int>& shape,
177 const std::optional<array>&
key = std::nullopt,
183 const std::optional<array>&
key = std::nullopt,
189 const std::vector<int>& shape,
191 const std::optional<array>&
key = std::nullopt,
198 const std::optional<array>&
key = std::nullopt,
202 const std::vector<int>& shape,
204 const std::optional<array>&
key = std::nullopt,
210 const std::vector<int>& shape,
211 const std::optional<array>&
key = std::nullopt,
215 const array& logits_,
218 const std::optional<array>&
key = std::nullopt,
224 const std::optional<array>&
key = std::nullopt,
229 const std::vector<int>& shape,
233 const std::optional<array>&
key = std::nullopt,
236 const std::vector<int>& shape,
239 const std::optional<array>&
key = std::nullopt,
244 const std::vector<int>& shape,
246 const std::optional<array>&
key = std::nullopt,
248 return laplace(shape, dtype, 0.0, 1.0,
key, s);
251 const std::vector<int>& shape,
252 const std::optional<array>&
key = std::nullopt,
KeySequence(uint64_t seed)
static KeySequence & default_()
Definition random.h:22
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
array truncated_normal(const array &lower, const array &upper, const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array categorical(const array &logits, int axis, const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
std::pair< array, array > split(const array &key, StreamOrDevice s={})
Split the rng key into a pair of keys.
array randint(const array &low, const array &high, const std::vector< int > &shape, Dtype dtype=int32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate integer samples uniformly at random.
array multivariate_normal(const array &mean, const array &cov, const std::vector< int > &shape, Dtype dtype, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from a multivariate normal distribution.
array normal(const std::vector< int > &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from the standard normal distribution.
array gumbel(const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
void seed(uint64_t seed)
Seed the default PRNG key.
array bernoulli(const array &p, const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate binary variables with probability to be true equal to p.
array key(uint64_t seed)
Get a PRNG key from a seed.
array laplace(const std::vector< int > &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from the laplace distribution.
array uniform(const array &low, const array &high, const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate uniform random numbers between low and high.
Stream to_stream(StreamOrDevice s)
constexpr Dtype int32
Definition dtype.h:67
constexpr Dtype float32
Definition dtype.h:71
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14