Implement sampling from laplace distribution. (#1279)

This commit is contained in:
fgranqvist
2024-07-24 15:15:37 +02:00
committed by GitHub
parent c34a5ae7f7
commit 50eff6a10a
6 changed files with 210 additions and 40 deletions

View File

@@ -102,6 +102,19 @@ T above_minus_one() {
return f;
}
// Get the next representable value above -1.0 for half precision
// use std::nextafter as default case.
array above_minus_one_with_default(Dtype dtype) {
switch (dtype) {
case float16:
return array(above_minus_one<float16_t>(), dtype);
case bfloat16:
return array(above_minus_one<bfloat16_t>(), dtype);
default:
return array(std::nextafter(-1.0f, 0.0f), dtype);
}
}
array uniform(
const array& low,
const array& high,
@@ -171,17 +184,7 @@ array normal(
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
auto stream = to_stream(s);
auto get_low = [&dtype]() {
switch (dtype) {
case float16:
return array(above_minus_one<float16_t>(), dtype);
case bfloat16:
return array(above_minus_one<bfloat16_t>(), dtype);
default:
return array(std::nextafter(-1.0f, 0.0f), dtype);
}
};
auto low = get_low();
auto low = above_minus_one_with_default(dtype);
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
samples =
@@ -428,4 +431,30 @@ array categorical(
return categorical_impl(logits, axis, shape, key, s);
}
array laplace(
const std::vector<int>& shape,
Dtype dtype,
const float loc /* = 0.0 */,
const float scale /* = 1.0 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
auto stream = to_stream(s);
auto low = above_minus_one_with_default(dtype);
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
// Use inverse CDF to generate Laplacian noise
samples = multiply(
sign(samples),
log1p(multiply(array(-1.0f, dtype), abs(samples))),
stream);
if (scale != 1.0) {
samples = multiply(array(scale, dtype), samples, stream);
}
if (loc != 0.0) {
samples = add(array(loc, dtype), samples, stream);
}
return samples;
}
} // namespace mlx::core::random

View File

@@ -224,4 +224,34 @@ array categorical(
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
/** Generate samples from the laplace distribution. */
array laplace(
const std::vector<int>& shape,
Dtype dtype,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
inline array laplace(
const std::vector<int>& shape,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return laplace(shape, float32, loc, scale, key, s);
}
inline array laplace(
const std::vector<int>& shape,
const Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return laplace(shape, dtype, 0.0, 1.0, key, s);
}
inline array laplace(
const std::vector<int>& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return laplace(shape, float32, 0.0, 1.0, key, s);
}
} // namespace mlx::core::random