MLX
 
Loading...
Searching...
No Matches
random.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <chrono>
6#include <optional>
7
8#include "mlx/array.h"
9#include "mlx/stream.h"
10#include "mlx/utils.h"
11
12namespace mlx::core::random {
13
15 public:
16 explicit KeySequence(uint64_t seed);
17
18 void seed(uint64_t seed);
20
21 // static default
23 static KeySequence ks(get_current_time_seed());
24 return ks;
25 }
26
27 private:
28 array key_;
29 static uint64_t get_current_time_seed() { // time-based seed; used to construct the static default KeySequence
30 auto now = std::chrono::system_clock::now(); // current system_clock time point
31 return std::chrono::duration_cast<std::chrono::milliseconds>(
32 now.time_since_epoch())
33 .count(); // milliseconds since the clock's epoch; converted to uint64_t on return
34 }
35};
36
38array key(uint64_t seed); // Get a PRNG key from a seed.
39
41void seed(uint64_t seed); // Seed the default PRNG key.
42
45 const Shape& shape,
46 int width,
47 const std::optional<array>& key = std::nullopt,
48 StreamOrDevice s = {});
49inline array bits( // convenience overload that omits the width argument
50 const Shape& shape,
51 const std::optional<array>& key = std::nullopt,
52 StreamOrDevice s = {}) {
53 return bits(shape, 4, key, s); // delegates with width fixed at 4 (presumably bytes per element — TODO confirm)
54}
55
57std::pair<array, array> split(const array& key, StreamOrDevice s = {}); // Split one PRNG key into a pair of keys.
58
60array split(const array& key, int num, StreamOrDevice s = {}); // presumably splits key into num keys returned in a single array — TODO confirm
61
64 const array& low,
65 const array& high,
66 const Shape& shape,
67 Dtype dtype = float32,
68 const std::optional<array>& key = std::nullopt,
69 StreamOrDevice s = {});
70
71template <typename T, typename U>
73 T low,
74 U high,
75 const Shape& shape,
76 Dtype dtype = float32,
77 const std::optional<array>& key = std::nullopt,
78 StreamOrDevice s = {}) {
79 return uniform(array(low), array(high), shape, dtype, key, to_stream(s));
80}
81
84 const Shape& shape,
85 Dtype dtype,
86 const std::optional<array>& key = std::nullopt,
87 StreamOrDevice s = {});
89 const Shape& shape,
90 const std::optional<array>& key = std::nullopt,
91 StreamOrDevice s = {}) {
92 return uniform(shape, float32, key, s); // float32 is the default dtype; forward s so the caller's stream/device is honored
93}
94
97 const Shape& shape,
98 Dtype dtype,
99 const float loc,
100 const float scale,
101 const std::optional<array>& key = std::nullopt,
102 StreamOrDevice s = {});
104 const Shape& shape,
105 const float loc,
106 const float scale,
107 const std::optional<array>& key = std::nullopt,
108 StreamOrDevice s = {}) {
109 return normal(shape, float32, loc, scale, key, s);
110}
112 const Shape& shape,
113 const Dtype dtype,
114 const std::optional<array>& key = std::nullopt,
115 StreamOrDevice s = {}) {
116 return normal(shape, dtype, 0.0, 1.0, key, s);
117}
119 const Shape& shape,
120 const std::optional<array>& key = std::nullopt,
121 StreamOrDevice s = {}) {
122 return normal(shape, float32, 0.0, 1.0, key, s);
123}
124
127 const array& mean,
128 const array& cov,
129 const Shape& shape,
130 Dtype dtype,
131 const std::optional<array>& key = std::nullopt,
132 StreamOrDevice s = {});
133
136 const array& low,
137 const array& high,
138 const Shape& shape,
139 Dtype dtype = int32,
140 const std::optional<array>& key = std::nullopt,
141 StreamOrDevice s = {});
142
143template <typename T, typename U>
145 T low,
146 U high,
147 const Shape& shape,
148 Dtype dtype = int32,
149 const std::optional<array>& key = std::nullopt,
150 StreamOrDevice s = {}) {
151 return randint(array(low), array(high), shape, dtype, key, to_stream(s));
152}
153
156 const array& p,
157 const Shape& shape,
158 const std::optional<array>& key = std::nullopt,
159 StreamOrDevice s = {});
161 const array& p,
162 const std::optional<array>& key = std::nullopt,
163 StreamOrDevice s = {});
164
165template <typename T>
167 T p,
168 const std::optional<array>& key = std::nullopt,
169 StreamOrDevice s = {}) {
170 return bernoulli(array(p), key, s);
171}
172
173template <typename T>
175 T p,
176 const Shape& shape,
177 const std::optional<array>& key = std::nullopt,
178 StreamOrDevice s = {}) {
179 return bernoulli(array(p), shape, key, s);
180}
181
183 const std::optional<array>& key = std::nullopt,
184 StreamOrDevice s = {});
185
187 const array& lower,
188 const array& upper,
189 const Shape& shape,
190 Dtype dtype = float32,
191 const std::optional<array>& key = std::nullopt,
192 StreamOrDevice s = {});
193
195 const array& lower,
196 const array& upper,
197 Dtype dtype = float32,
198 const std::optional<array>& key = std::nullopt,
199 StreamOrDevice s = {});
200
202 const Shape& shape,
203 Dtype dtype = float32,
204 const std::optional<array>& key = std::nullopt,
205 StreamOrDevice s = {});
206
208 const array& logits,
209 int axis,
210 const Shape& shape,
211 const std::optional<array>& key = std::nullopt,
212 StreamOrDevice s = {});
213
215 const array& logits_,
216 int axis,
217 int num_samples,
218 const std::optional<array>& key = std::nullopt,
219 StreamOrDevice s = {});
220
222 const array& logits,
223 int axis = -1,
224 const std::optional<array>& key = std::nullopt,
225 StreamOrDevice s = {});
226
229 const Shape& shape,
230 Dtype dtype,
231 const float loc,
232 const float scale,
233 const std::optional<array>& key = std::nullopt,
234 StreamOrDevice s = {});
236 const Shape& shape,
237 const float loc,
238 const float scale,
239 const std::optional<array>& key = std::nullopt,
240 StreamOrDevice s = {}) {
241 return laplace(shape, float32, loc, scale, key, s);
242}
244 const Shape& shape,
245 const Dtype dtype,
246 const std::optional<array>& key = std::nullopt,
247 StreamOrDevice s = {}) {
248 return laplace(shape, dtype, 0.0, 1.0, key, s);
249}
251 const Shape& shape,
252 const std::optional<array>& key = std::nullopt,
253 StreamOrDevice s = {}) {
254 return laplace(shape, float32, 0.0, 1.0, key, s);
255}
256
257/* Randomly permute the elements of x along the given axis. */
259 const array& x,
260 int axis = 0,
261 const std::optional<array>& key = std::nullopt,
262 StreamOrDevice s = {});
263
264/* A random permutation of `arange(x)` */
266 int x,
267 const std::optional<array>& key = std::nullopt,
268 StreamOrDevice s = {});
269
270} // namespace mlx::core::random
Definition array.h:24
void seed(uint64_t seed)
static KeySequence & default_()
Definition random.h:22
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
Definition threefry.h:8
array bernoulli(const array &p, const Shape &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate binary variables with probability to be true equal to p.
array categorical(const array &logits, int axis, const Shape &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
std::pair< array, array > split(const array &key, StreamOrDevice s={})
Split the rng key into a pair of keys.
array normal(const Shape &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from a normal distribution with the given loc and scale.
array gumbel(const Shape &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array laplace(const Shape &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from the Laplace distribution.
array uniform(const array &low, const array &high, const Shape &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate uniform random numbers between low and high.
void seed(uint64_t seed)
Seed the default PRNG key.
array key(uint64_t seed)
Get a PRNG key from a seed.
array bits(const Shape &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
array randint(const array &low, const array &high, const Shape &shape, Dtype dtype=int32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate integer samples uniformly at random.
array permutation(const array &x, int axis=0, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array multivariate_normal(const array &mean, const array &cov, const Shape &shape, Dtype dtype, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from a multivariate normal distribution.
array truncated_normal(const array &lower, const array &upper, const Shape &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Stream to_stream(StreamOrDevice s)
constexpr Dtype int32
Definition dtype.h:77
constexpr Dtype float32
Definition dtype.h:81
std::vector< ShapeElem > Shape
Definition array.h:21
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15
Definition dtype.h:13