Implement sampling from laplace distribution. (#1279)

This commit is contained in:
fgranqvist
2024-07-24 15:15:37 +02:00
committed by GitHub
parent c34a5ae7f7
commit 50eff6a10a
6 changed files with 210 additions and 40 deletions

View File

@@ -224,4 +224,34 @@ array categorical(
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
/** Generate samples from the laplace distribution. */
array laplace(
const std::vector<int>& shape,
Dtype dtype,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
inline array laplace(
const std::vector<int>& shape,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return laplace(shape, float32, loc, scale, key, s);
}
inline array laplace(
const std::vector<int>& shape,
const Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return laplace(shape, dtype, 0.0, 1.0, key, s);
}
inline array laplace(
const std::vector<int>& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return laplace(shape, float32, 0.0, 1.0, key, s);
}
} // namespace mlx::core::random