mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Implement sampling from laplace distribution. (#1279)
This commit is contained in:
@@ -102,6 +102,19 @@ T above_minus_one() {
|
||||
return f;
|
||||
}
|
||||
|
||||
// Get the next representable value above -1.0 for half precision
|
||||
// use std::nextafter as default case.
|
||||
array above_minus_one_with_default(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case float16:
|
||||
return array(above_minus_one<float16_t>(), dtype);
|
||||
case bfloat16:
|
||||
return array(above_minus_one<bfloat16_t>(), dtype);
|
||||
default:
|
||||
return array(std::nextafter(-1.0f, 0.0f), dtype);
|
||||
}
|
||||
}
|
||||
|
||||
array uniform(
|
||||
const array& low,
|
||||
const array& high,
|
||||
@@ -171,17 +184,7 @@ array normal(
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto stream = to_stream(s);
|
||||
auto get_low = [&dtype]() {
|
||||
switch (dtype) {
|
||||
case float16:
|
||||
return array(above_minus_one<float16_t>(), dtype);
|
||||
case bfloat16:
|
||||
return array(above_minus_one<bfloat16_t>(), dtype);
|
||||
default:
|
||||
return array(std::nextafter(-1.0f, 0.0f), dtype);
|
||||
}
|
||||
};
|
||||
auto low = get_low();
|
||||
auto low = above_minus_one_with_default(dtype);
|
||||
auto high = array(1.0f, dtype);
|
||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||
samples =
|
||||
@@ -428,4 +431,30 @@ array categorical(
|
||||
return categorical_impl(logits, axis, shape, key, s);
|
||||
}
|
||||
|
||||
array laplace(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
const float loc /* = 0.0 */,
|
||||
const float scale /* = 1.0 */,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto stream = to_stream(s);
|
||||
auto low = above_minus_one_with_default(dtype);
|
||||
auto high = array(1.0f, dtype);
|
||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||
// Use inverse CDF to generate Laplacian noise
|
||||
samples = multiply(
|
||||
sign(samples),
|
||||
log1p(multiply(array(-1.0f, dtype), abs(samples))),
|
||||
stream);
|
||||
|
||||
if (scale != 1.0) {
|
||||
samples = multiply(array(scale, dtype), samples, stream);
|
||||
}
|
||||
if (loc != 0.0) {
|
||||
samples = add(array(loc, dtype), samples, stream);
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::random
|
||||
|
30
mlx/random.h
30
mlx/random.h
@@ -224,4 +224,34 @@ array categorical(
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Generate samples from the laplace distribution. */
|
||||
array laplace(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
inline array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return laplace(shape, float32, loc, scale, key, s);
|
||||
}
|
||||
inline array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return laplace(shape, dtype, 0.0, 1.0, key, s);
|
||||
}
|
||||
inline array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return laplace(shape, float32, 0.0, 1.0, key, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::random
|
||||
|
Reference in New Issue
Block a user