Random
Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional key keyword argument for when more
fine-grained control or explicit state management is needed.
For example, you can generate random numbers with:
import mlx.core as mx

for _ in range(3):
    print(mx.random.uniform())
which prints three different pseudorandom numbers, because each call advances the global PRNG state. Alternatively, you can pass an explicit key:
key = mx.random.key(0)
for _ in range(3):
    print(mx.random.uniform(key=key))
which prints the same pseudorandom number on every iteration, since the same key always produces the same value.
Following JAX’s PRNG design, we use a splittable version of Threefry, which is a counter-based PRNG.
Function | Description
seed | Seed the global PRNG.
key | Get a PRNG key from a seed.
split | Split a PRNG key into subkeys.
bernoulli | Generate Bernoulli random values.
categorical | Sample from a categorical distribution.
gumbel | Sample from the standard Gumbel distribution.
normal | Generate normally distributed random numbers.
randint | Generate random integers from the given interval.
uniform | Generate uniformly distributed random numbers.
truncated_normal | Generate values from a truncated normal distribution.