MLX
Loading...
Searching...
No Matches
random.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <chrono>
6#include <optional>
7
8#include "mlx/array.h"
9#include "mlx/stream.h"
10
11namespace mlx::core::random {
12
14 public:
15 explicit KeySequence(uint64_t seed);
16
17 void seed(uint64_t seed);
19
20 // static default
22 static KeySequence ks(get_current_time_seed());
23 return ks;
24 }
25
26 private:
27 array key_;
/**
 * Returns the current wall-clock time (std::chrono::system_clock) as whole
 * milliseconds since the clock's epoch.
 *
 * Used to seed the function-local static `ks` of the default key sequence.
 */
static uint64_t get_current_time_seed() {
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto millis =
      std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch);
  // milliseconds::count() yields a signed tick count; convert explicitly to
  // the unsigned seed type rather than relying on an implicit conversion.
  return static_cast<uint64_t>(millis.count());
}
34};
35
37array key(uint64_t seed);
38
40void seed(uint64_t seed);
41
44 const std::vector<int>& shape,
45 int width,
46 const std::optional<array>& key = std::nullopt,
47 StreamOrDevice s = {});
48inline array bits(
49 const std::vector<int>& shape,
50 const std::optional<array>& key = std::nullopt,
51 StreamOrDevice s = {}) {
52 return bits(shape, 4, key, s);
53}
54
56std::pair<array, array> split(const array& key, StreamOrDevice s = {});
57
59array split(const array& key, int num, StreamOrDevice s = {});
60
63 const array& low,
64 const array& high,
65 const std::vector<int>& shape,
66 Dtype dtype = float32,
67 const std::optional<array>& key = std::nullopt,
68 StreamOrDevice s = {});
69
70template <typename T, typename U>
72 T low,
73 U high,
74 const std::vector<int>& shape,
75 Dtype dtype = float32,
76 const std::optional<array>& key = std::nullopt,
77 StreamOrDevice s = {}) {
78 return uniform(array(low), array(high), shape, dtype, key, to_stream(s));
79}
80
83 const std::vector<int>& shape,
84 Dtype dtype,
85 const std::optional<array>& key = std::nullopt,
86 StreamOrDevice s = {});
88 const std::vector<int>& shape,
89 const std::optional<array>& key = std::nullopt,
90 StreamOrDevice s = {}) {
91 return uniform(shape, float32, key);
92}
93
96 const std::vector<int>& shape,
97 Dtype dtype,
98 const float loc,
99 const float scale,
100 const std::optional<array>& key = std::nullopt,
101 StreamOrDevice s = {});
103 const std::vector<int>& shape,
104 const float loc,
105 const float scale,
106 const std::optional<array>& key = std::nullopt,
107 StreamOrDevice s = {}) {
108 return normal(shape, float32, loc, scale, key, s);
109}
111 const std::vector<int>& shape,
112 const Dtype dtype,
113 const std::optional<array>& key = std::nullopt,
114 StreamOrDevice s = {}) {
115 return normal(shape, dtype, 0.0, 1.0, key, s);
116}
118 const std::vector<int>& shape,
119 const std::optional<array>& key = std::nullopt,
120 StreamOrDevice s = {}) {
121 return normal(shape, float32, 0.0, 1.0, key, s);
122}
123
126 const array& mean,
127 const array& cov,
128 const std::vector<int>& shape,
129 Dtype dtype,
130 const std::optional<array>& key = std::nullopt,
131 StreamOrDevice s = {});
132
135 const array& low,
136 const array& high,
137 const std::vector<int>& shape,
138 Dtype dtype = int32,
139 const std::optional<array>& key = std::nullopt,
140 StreamOrDevice s = {});
141
142template <typename T, typename U>
144 T low,
145 U high,
146 const std::vector<int>& shape,
147 Dtype dtype = int32,
148 const std::optional<array>& key = std::nullopt,
149 StreamOrDevice s = {}) {
150 return randint(array(low), array(high), shape, dtype, key, to_stream(s));
151};
152
155 const array& p,
156 const std::vector<int>& shape,
157 const std::optional<array>& key = std::nullopt,
158 StreamOrDevice s = {});
160 const array& p,
161 const std::optional<array>& key = std::nullopt,
162 StreamOrDevice s = {});
163
164template <typename T>
166 T p,
167 const std::optional<array>& key = std::nullopt,
168 StreamOrDevice s = {}) {
169 return bernoulli(array(p), key, s);
170};
171
172template <typename T>
174 T p,
175 const std::vector<int>& shape,
176 const std::optional<array>& key = std::nullopt,
177 StreamOrDevice s = {}) {
178 return bernoulli(array(p), shape, key, s);
179};
180
182 const std::optional<array>& key = std::nullopt,
183 StreamOrDevice s = {});
184
186 const array& lower,
187 const array& upper,
188 const std::vector<int>& shape,
189 Dtype dtype = float32,
190 const std::optional<array>& key = std::nullopt,
191 StreamOrDevice s = {});
192
194 const array& lower,
195 const array& upper,
196 Dtype dtype = float32,
197 const std::optional<array>& key = std::nullopt,
198 StreamOrDevice s = {});
199
201 const std::vector<int>& shape,
202 Dtype dtype = float32,
203 const std::optional<array>& key = std::nullopt,
204 StreamOrDevice s = {});
205
207 const array& logits,
208 int axis,
209 const std::vector<int>& shape,
210 const std::optional<array>& key = std::nullopt,
211 StreamOrDevice s = {});
212
214 const array& logits_,
215 int axis,
216 int num_samples,
217 const std::optional<array>& key = std::nullopt,
218 StreamOrDevice s = {});
219
221 const array& logits,
222 int axis = -1,
223 const std::optional<array>& key = std::nullopt,
224 StreamOrDevice s = {});
225
226} // namespace mlx::core::random
Definition array.h:20
Definition random.h:13
void seed(uint64_t seed)
static KeySequence & default_()
Definition random.h:21
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
Definition threefry.h:8
array truncated_normal(const array &lower, const array &upper, const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array categorical(const array &logits, int axis, const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
std::pair< array, array > split(const array &key, StreamOrDevice s={})
Split the rng key into a pair of keys.
array randint(const array &low, const array &high, const std::vector< int > &shape, Dtype dtype=int32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate integer samples uniformly at random.
array multivariate_normal(const array &mean, const array &cov, const std::vector< int > &shape, Dtype dtype, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from a multivariate normal distribution.
array normal(const std::vector< int > &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate samples from a normal distribution with location `loc` and scale `scale`.
array gumbel(const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
void seed(uint64_t seed)
Seed the default PRNG key.
array bernoulli(const array &p, const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate binary variables with probability to be true equal to p.
array key(uint64_t seed)
Get a PRNG key from a seed.
array uniform(const array &low, const array &high, const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate uniform random numbers between low and high.
Stream to_stream(StreamOrDevice s)
constexpr Dtype int32
Definition dtype.h:69
constexpr Dtype float32
Definition dtype.h:73
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14
Definition dtype.h:15