mlx/mlx/random.h
Josh Soref 44c1ce5e6a
Spelling (#342)
* spelling: accumulates

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: across

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: additional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: against

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: among

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: array

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: at least

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: available

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: axes

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: basically

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bfloat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bounds

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: broadcast

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: buffer

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: class

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: coefficients

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: collision

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: combinations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: committing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: computation

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: consider

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: constructing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: conversions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: correctly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: corresponding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: declaration

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: default

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dependency

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destination

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destructor

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dimensions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: divided

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: element-wise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: elements

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: endianness

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: equivalent

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: explicitly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: github

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: indices

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: irregularly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: memory

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: metallib

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: negative

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: notable

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: optional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: otherwise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: overridden

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partially

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partition

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perform

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perturbations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: positively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: primitive

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeats

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respect

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respectively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: result

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: rounding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: separate

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: skipping

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: structure

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: the

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: transpose

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unnecessary

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unneeded

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unsupported

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

---------

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00

194 lines
4.8 KiB
C++

// Copyright © 2023 Apple Inc.
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::random {
/**
 * A stateful source of PRNG keys.
 *
 * The current key lives in `key_`. `next()` hands out keys one at a time
 * (presumably by splitting/advancing `key_` — confirm in the implementation)
 * and `seed()` restarts the sequence from a new seed.
 */
class KeySequence {
 public:
  /** Start a new sequence from `seed`. */
  explicit KeySequence(uint64_t seed);

  /** Restart this sequence from `seed`. */
  void seed(uint64_t seed);

  /** Produce the next key of the sequence. */
  array next();

  /**
   * Shared default sequence, created on first use and seeded with 0.
   *
   * NOTE(review): initialization of the function-local static is thread-safe,
   * but nothing in this header synchronizes later seed()/next() calls on the
   * shared instance — confirm whether callers must serialize access.
   */
  static KeySequence& default_() {
    static KeySequence shared_sequence{0};
    return shared_sequence;
  }

 private:
  array key_;
};
/**
 * Build a PRNG key from an integer seed.
 *
 * @param seed Seed the key is derived from.
 * @return The key, as an `array`.
 */
array key(uint64_t seed);

/**
 * Seed the default PRNG key.
 *
 * Presumably reseeds `KeySequence::default_()`, which is what the sampling
 * functions use when no explicit `key` is given — confirm in the
 * implementation.
 *
 * @param seed New seed for the default key.
 */
void seed(uint64_t seed);
/**
 * Generate an array filled with random bits (documented as uint32).
 *
 * @param shape Shape of the output.
 * @param width Element width. The convenience overload below passes 4, which,
 *   together with the uint32 description, suggests a width in bytes —
 *   NOTE(review): confirm the accepted values in the implementation.
 * @param key   Optional PRNG key; `std::nullopt` by default.
 * @param s     Stream or device to run on.
 */
array bits(
    const std::vector<int>& shape,
    int width,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** Convenience overload of `bits` with a fixed `width` of 4. */
inline array bits(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  constexpr int default_width = 4;
  return bits(shape, default_width, key, s);
}
/**
 * Split the PRNG `key` into two new keys.
 *
 * @param key Key to split.
 * @param s   Stream or device to run on.
 * @return The pair of derived keys.
 */
std::pair<array, array> split(const array& key, StreamOrDevice s = {});

/**
 * Split the PRNG `key` into `num` new keys.
 *
 * @param key Key to split.
 * @param num Number of keys to produce.
 * @param s   Stream or device to run on.
 * @return All `num` keys in one array (presumably stacked along the first
 *   axis — confirm in the implementation).
 */
array split(const array& key, int num, StreamOrDevice s = {});
/**
 * Generate uniform random numbers between `low` and `high`.
 *
 * @param low   Lower bound (presumably broadcast against `shape` — confirm).
 * @param high  Upper bound (presumably broadcast against `shape` — confirm).
 * @param shape Shape of the output.
 * @param dtype Output dtype; `float32` by default.
 * @param key   Optional PRNG key; `std::nullopt` by default.
 * @param s     Stream or device to run on.
 */
array uniform(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Scalar-bounds convenience: wraps `low` and `high` in arrays, then forwards
 * to the array overload above.
 */
template <typename T, typename U>
array uniform(
    T low,
    U high,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  const array low_arr(low);
  const array high_arr(high);
  return uniform(low_arr, high_arr, shape, dtype, key, to_stream(s));
}
/**
 * Generate uniform random numbers between 0 and 1.
 *
 * @param shape Shape of the output.
 * @param dtype Output dtype.
 * @param key   Optional PRNG key; `std::nullopt` by default.
 * @param s     Stream or device to run on.
 */
array uniform(
    const std::vector<int>& shape,
    Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * `float32` convenience overload of the function above.
 *
 * Forwards `s` so the caller's stream/device choice is honored, rather than
 * being replaced by the callee's `{}` default (consistent with `normal`).
 */
inline array uniform(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return uniform(shape, float32, key, s);
}
/**
 * Sample from the standard normal distribution.
 *
 * @param shape Shape of the output.
 * @param dtype Output dtype.
 * @param key   Optional PRNG key; `std::nullopt` by default.
 * @param s     Stream or device to run on.
 */
array normal(
    const std::vector<int>& shape,
    Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** `float32` convenience overload; forwards `key` and `s` unchanged. */
inline array normal(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  const Dtype default_dtype = float32;
  return normal(shape, default_dtype, key, s);
}
/**
 * Generate integer samples uniformly at random between `low` and `high`.
 *
 * NOTE(review): whether `high` is inclusive is not visible in this header —
 * confirm in the implementation.
 *
 * @param low   Lower bound.
 * @param high  Upper bound.
 * @param shape Shape of the output.
 * @param dtype Output dtype; `int32` by default.
 * @param key   Optional PRNG key; `std::nullopt` by default.
 * @param s     Stream or device to run on.
 */
array randint(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    Dtype dtype = int32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Scalar-bounds convenience: wraps `low` and `high` in arrays, then forwards
 * to the array overload above.
 */
template <typename T, typename U>
array randint(
    T low,
    U high,
    const std::vector<int>& shape,
    Dtype dtype = int32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return randint(array(low), array(high), shape, dtype, key, to_stream(s));
}
/**
 * Generate binary variables that are true with probability `p`.
 *
 * @param p     Probability of `true` (presumably broadcast against `shape` —
 *   confirm in the implementation).
 * @param shape Shape of the output.
 * @param key   Optional PRNG key; `std::nullopt` by default.
 * @param s     Stream or device to run on.
 */
array bernoulli(
    const array& p,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Bernoulli samples without an explicit `shape` (presumably the output takes
 * the shape of `p` — confirm in the implementation).
 */
array bernoulli(
    const array& p,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** Scalar-`p` convenience: wraps `p` in an array and forwards. */
template <typename T>
array bernoulli(
    T p,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return bernoulli(array(p), key, s);
}

/** Scalar-`p` convenience with an explicit output `shape`. */
template <typename T>
array bernoulli(
    T p,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return bernoulli(array(p), shape, key, s);
}

/**
 * Bernoulli samples without an explicit `p`. The probability used is chosen
 * by the implementation (presumably 0.5 — confirm).
 */
array bernoulli(
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
/**
 * Sample from a normal distribution truncated to the range given by `lower`
 * and `upper`.
 *
 * NOTE(review): bound inclusivity and the mean/scale used are not visible in
 * this header — confirm in the implementation.
 *
 * @param lower Lower truncation bound.
 * @param upper Upper truncation bound.
 * @param shape Shape of the output.
 * @param dtype Output dtype; `float32` by default.
 * @param key   Optional PRNG key; `std::nullopt` by default.
 * @param s     Stream or device to run on.
 */
array truncated_normal(
    const array& lower,
    const array& upper,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Truncated normal samples without an explicit `shape` (presumably derived
 * from `lower`/`upper` — confirm in the implementation).
 */
array truncated_normal(
    const array& lower,
    const array& upper,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
/**
 * Sample from the Gumbel distribution (presumably the standard one, with
 * location 0 and scale 1 — confirm in the implementation).
 *
 * @param shape Shape of the output.
 * @param dtype Output dtype; `float32` by default.
 * @param key   Optional PRNG key; `std::nullopt` by default.
 * @param s     Stream or device to run on.
 */
array gumbel(
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
/**
 * Sample category indices from the categorical distribution given by
 * `logits` (presumably unnormalized log-probabilities — confirm).
 *
 * NOTE(review): a single-element or empty braced list such as `{n}` binds to
 * the `int num_samples` overload below, not to this `shape` overload, because
 * the conversion to `int` ranks better than the conversion to
 * `std::vector<int>`. Pass an explicit `std::vector<int>` to select this one.
 *
 * @param logits Per-category scores.
 * @param axis   Axis of `logits` that presumably indexes the categories.
 * @param shape  Shape of the output.
 * @param key    Optional PRNG key; `std::nullopt` by default.
 * @param s      Stream or device to run on.
 */
array categorical(
    const array& logits,
    int axis,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Draw `num_samples` category samples per distribution in `logits`.
 * The output layout is not visible in this header — confirm in the
 * implementation.
 */
array categorical(
    const array& logits,
    int axis,
    int num_samples,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Draw one category sample per distribution in `logits`. The categories lie
 * along `axis`, which defaults to the last axis (-1).
 */
array categorical(
    const array& logits,
    int axis = -1,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
} // namespace mlx::core::random