MLX
Loading...
Searching...
No Matches
random.h File Reference
#include <chrono>
#include <optional>
#include "mlx/array.h"
#include "mlx/stream.h"
#include "mlx/utils.h"

Go to the source code of this file.

Classes

class  mlx::core::random::KeySequence
 

Namespaces

namespace  mlx
 
namespace  mlx::core
 
namespace  mlx::core::random
 

Functions

array mlx::core::random::key (uint64_t seed)
 Get a PRNG key from a seed.
 
void mlx::core::random::seed (uint64_t seed)
 Seed the default PRNG key.
 
array mlx::core::random::bits (const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 Generate an array with type uint32 filled with random bits.
 
array mlx::core::random::bits (const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
std::pair< array, array > mlx::core::random::split (const array &key, StreamOrDevice s={})
 Split the RNG key into a pair of keys.
 
array mlx::core::random::split (const array &key, int num, StreamOrDevice s={})
 Split the RNG key into `num` keys.
 
array mlx::core::random::uniform (const array &low, const array &high, const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 Generate uniform random numbers between low and high.
 
template<typename T , typename U >
array mlx::core::random::uniform (T low, U high, const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::uniform (const std::vector< int > &shape, Dtype dtype, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 Generate uniform random numbers between 0 and 1.
 
array mlx::core::random::uniform (const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::normal (const std::vector< int > &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 Generate samples from a normal distribution parameterized by `loc` and `scale` (the standard normal distribution when `loc` = 0 and `scale` = 1).
 
array mlx::core::random::normal (const std::vector< int > &shape, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::normal (const std::vector< int > &shape, const Dtype dtype, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::normal (const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::multivariate_normal (const array &mean, const array &cov, const std::vector< int > &shape, Dtype dtype, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 Generate samples from a multivariate normal distribution.
 
array mlx::core::random::randint (const array &low, const array &high, const std::vector< int > &shape, Dtype dtype=int32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 Generate integer samples uniformly at random.
 
template<typename T , typename U >
array mlx::core::random::randint (T low, U high, const std::vector< int > &shape, Dtype dtype=int32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::bernoulli (const array &p, const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 Generate boolean samples, each of which is true with probability `p`.
 
array mlx::core::random::bernoulli (const array &p, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
template<typename T >
array mlx::core::random::bernoulli (T p, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
template<typename T >
array mlx::core::random::bernoulli (T p, const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::bernoulli (const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::truncated_normal (const array &lower, const array &upper, const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::truncated_normal (const array &lower, const array &upper, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::gumbel (const std::vector< int > &shape, Dtype dtype=float32, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::categorical (const array &logits, int axis, const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::categorical (const array &logits_, int axis, int num_samples, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::categorical (const array &logits, int axis=-1, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::laplace (const std::vector< int > &shape, Dtype dtype, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 Generate samples from the Laplace distribution.
 
array mlx::core::random::laplace (const std::vector< int > &shape, const float loc, const float scale, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::laplace (const std::vector< int > &shape, const Dtype dtype, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::laplace (const std::vector< int > &shape, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::permutation (const array &x, int axis=0, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
 
array mlx::core::random::permutation (int x, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})