mlx/mlx/random.h

271 lines
7.0 KiB
C
Raw Normal View History

2023-12-01 03:12:53 +08:00
// Copyright © 2023 Apple Inc.
2023-11-30 04:38:32 +08:00
#pragma once
#include <chrono>
2023-11-30 04:38:32 +08:00
#include <optional>
#include "mlx/array.h"
#include "mlx/stream.h"
2024-07-15 23:20:24 +08:00
#include "mlx/utils.h"
2023-11-30 04:38:32 +08:00
namespace mlx::core::random {
class KeySequence {
public:
explicit KeySequence(uint64_t seed);
void seed(uint64_t seed);
array next();
Spelling (#342) * spelling: accumulates Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: across Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: additional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: against Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: among Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: array Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: at least Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: available Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: axes Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: basically Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bfloat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bounds Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: broadcast Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: buffer Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: class Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: coefficients Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: collision Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: combinations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: committing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: computation Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: consider Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: constructing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: 
conversions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: correctly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: corresponding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: declaration Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: default Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dependency Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destination Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destructor Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dimensions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: divided Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: element-wise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: elements Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: endianness Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: equivalent Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: explicitly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: github Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: indices Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: irregularly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: memory Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: metallib Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: negative Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: notable Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: optional 
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: otherwise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: overridden Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partially Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partition Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perform Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perturbations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: positively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: primitive Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeats Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respect Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respectively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: result Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: rounding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: separate Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: skipping Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: structure Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: the Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: transpose Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unnecessary Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unneeded Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unsupported Signed-off-by: Josh Soref 
<2119212+jsoref@users.noreply.github.com> --------- Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-02 13:08:17 +08:00
// static default
2023-11-30 04:38:32 +08:00
static KeySequence& default_() {
static KeySequence ks(get_current_time_seed());
2023-11-30 04:38:32 +08:00
return ks;
}
private:
array key_;
static uint64_t get_current_time_seed() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
}
2023-11-30 04:38:32 +08:00
};
/**
 * Create a PRNG key from an integer seed.
 *
 * @param seed Seed value the key is derived from.
 * @return The key as an `array`.
 */
array key(uint64_t seed);

/**
 * Seed the default PRNG key.
 *
 * @param seed New seed for the default key.
 */
void seed(uint64_t seed);
/**
 * Generate an array with type uint32 filled with random bits.
 *
 * @param shape Output shape.
 * @param width Element width selector (presumably bytes per element; the
 *              convenience overload below passes 4) — TODO confirm.
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array bits(
    const std::vector<int>& shape,
    int width,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** Random bits with the default width of 4. */
inline array bits(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  constexpr int kDefaultWidth = 4;
  return bits(shape, kDefaultWidth, key, s);
}
/**
 * Split a PRNG key into two keys.
 *
 * @param key Key to split.
 * @param s   Stream or device to run on.
 * @return The two resulting keys.
 */
std::pair<array, array> split(const array& key, StreamOrDevice s = {});

/**
 * Split a PRNG key into `num` keys.
 *
 * @param key Key to split.
 * @param num Number of keys to produce.
 * @param s   Stream or device to run on.
 * @return The resulting keys as a single `array`.
 */
array split(const array& key, int num, StreamOrDevice s = {});
/**
 * Generate uniform random numbers between `low` and `high`.
 *
 * @param low   Lower bound.
 * @param high  Upper bound.
 * @param shape Output shape.
 * @param dtype Output dtype (float32 by default).
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array uniform(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** Scalar-bounds overload: wraps `low` and `high` in arrays and forwards. */
template <typename T, typename U>
array uniform(
    T low,
    U high,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  const array lo_arr(low);
  const array hi_arr(high);
  return uniform(lo_arr, hi_arr, shape, dtype, key, to_stream(s));
}
/**
 * Generate uniform random numbers between 0 and 1.
 *
 * @param shape Output shape.
 * @param dtype Output dtype.
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array uniform(
    const std::vector<int>& shape,
    Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * float32 uniform random numbers between 0 and 1.
 *
 * Forwards every argument, including the stream/device `s`, so a caller's
 * explicit stream or device choice is honored rather than replaced by the
 * default.
 */
inline array uniform(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return uniform(shape, float32, key, s);
}
/**
 * Generate samples from a normal distribution parameterized by `loc` and
 * `scale`.
 *
 * These are conventionally the mean and standard deviation; confirm against
 * the implementation. The overloads below without `loc`/`scale` use
 * loc = 0 and scale = 1 (the standard normal).
 *
 * @param shape Output shape.
 * @param dtype Output dtype.
 * @param loc   Location parameter.
 * @param scale Scale parameter.
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array normal(
    const std::vector<int>& shape,
    Dtype dtype,
    const float loc,
    const float scale,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** float32 normal samples with the given `loc` and `scale`. */
inline array normal(
    const std::vector<int>& shape,
    const float loc,
    const float scale,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return normal(shape, float32, loc, scale, key, s);
}

/** Standard normal samples (loc = 0, scale = 1) of the given dtype. */
inline array normal(
    const std::vector<int>& shape,
    const Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  // float literals match the `float` parameters exactly (no double->float).
  return normal(shape, dtype, 0.0f, 1.0f, key, s);
}

/** Standard normal float32 samples (loc = 0, scale = 1). */
inline array normal(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return normal(shape, float32, 0.0f, 1.0f, key, s);
}
/**
 * Generate samples from a multivariate normal distribution.
 *
 * @param mean  Mean of the distribution.
 * @param cov   Covariance of the distribution.
 * @param shape Sample shape.
 * @param dtype Output dtype.
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array multivariate_normal(
    const array& mean,
    const array& cov,
    const std::vector<int>& shape,
    Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
/**
 * Generate integer samples uniformly at random between `low` and `high`.
 *
 * Bound inclusivity is defined by the implementation (presumably the
 * half-open range [low, high); confirm).
 *
 * @param low   Lower bound.
 * @param high  Upper bound.
 * @param shape Output shape.
 * @param dtype Output dtype (int32 by default).
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array randint(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    Dtype dtype = int32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** Scalar-bounds overload: wraps `low` and `high` in arrays and forwards. */
template <typename T, typename U>
array randint(
    T low,
    U high,
    const std::vector<int>& shape,
    Dtype dtype = int32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return randint(array(low), array(high), shape, dtype, key, to_stream(s));
}
/**
 * Generate binary variables that are true with probability `p`.
 *
 * @param p     Probability of each variable being true.
 * @param shape Output shape.
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array bernoulli(
    const array& p,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Bernoulli samples without an explicit output shape (presumably shaped like
 * `p` — confirm against the implementation).
 */
array bernoulli(
    const array& p,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

// NOTE: the two non-template overloads above must stay declared before the
// templates below. The calls inside the template bodies are not
// type-dependent, so they are resolved where the templates are defined.

/** Scalar-probability overload: wraps `p` in an `array` and forwards. */
template <typename T>
array bernoulli(
    T p,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return bernoulli(array(p), key, s);
}

/** Scalar-probability overload with an explicit output shape. */
template <typename T>
array bernoulli(
    T p,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return bernoulli(array(p), shape, key, s);
}

/**
 * Bernoulli sample using the implementation's default probability.
 *
 * The default is not visible in this header (presumably 0.5 — confirm).
 */
array bernoulli(
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
/**
 * Generate samples from a normal distribution truncated to the interval
 * bounded by `lower` and `upper`.
 *
 * @param lower Lower truncation bound.
 * @param upper Upper truncation bound.
 * @param shape Output shape.
 * @param dtype Output dtype (float32 by default).
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array truncated_normal(
    const array& lower,
    const array& upper,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Truncated-normal samples without an explicit shape (presumably the
 * broadcast shape of `lower` and `upper` — confirm).
 */
array truncated_normal(
    const array& lower,
    const array& upper,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
/**
 * Generate samples from the Gumbel distribution.
 *
 * The location/scale used are defined by the implementation (presumably the
 * standard Gumbel — confirm).
 *
 * @param shape Output shape.
 * @param dtype Output dtype (float32 by default).
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array gumbel(
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
/**
 * Sample categories from `logits` along `axis`.
 *
 * `logits` are presumably unnormalized log-probabilities (confirm against
 * the implementation).
 *
 * @param logits Category scores.
 * @param axis   Axis holding the categories.
 * @param shape  Output shape.
 * @param key    Optional PRNG key.
 * @param s      Stream or device to run on.
 */
array categorical(
    const array& logits,
    int axis,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** Draw `num_samples` categorical samples from `logits` along `axis`. */
array categorical(
    const array& logits,
    int axis,
    int num_samples,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** Categorical samples along `axis` (last axis by default). */
array categorical(
    const array& logits,
    int axis = -1,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
/**
 * Generate samples from the Laplace distribution.
 *
 * The overloads below without `loc`/`scale` use loc = 0 and scale = 1.
 *
 * @param shape Output shape.
 * @param dtype Output dtype.
 * @param loc   Location parameter.
 * @param scale Scale parameter.
 * @param key   Optional PRNG key.
 * @param s     Stream or device to run on.
 */
array laplace(
    const std::vector<int>& shape,
    Dtype dtype,
    const float loc,
    const float scale,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** float32 Laplace samples with the given `loc` and `scale`. */
inline array laplace(
    const std::vector<int>& shape,
    const float loc,
    const float scale,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return laplace(shape, float32, loc, scale, key, s);
}

/** Laplace samples (loc = 0, scale = 1) of the given dtype. */
inline array laplace(
    const std::vector<int>& shape,
    const Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  constexpr float kLoc = 0.0f;
  constexpr float kScale = 1.0f;
  return laplace(shape, dtype, kLoc, kScale, key, s);
}

/** float32 Laplace samples (loc = 0, scale = 1). */
inline array laplace(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  constexpr float kLoc = 0.0f;
  constexpr float kScale = 1.0f;
  return laplace(shape, float32, kLoc, kScale, key, s);
}
/**
 * Randomly permute the elements of `x` along `axis`.
 *
 * @param x    Array to permute.
 * @param axis Axis along which to permute (0 by default).
 * @param key  Optional PRNG key.
 * @param s    Stream or device to run on.
 */
array permutation(
    const array& x,
    int axis = 0,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Return a random permutation of `arange(x)`, i.e. of the integers 0 .. x-1.
 *
 * @param x   Number of elements to permute.
 * @param key Optional PRNG key.
 * @param s   Stream or device to run on.
 */
array permutation(
    int x,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
} // namespace mlx::core::random