mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Implement sampling from laplace distribution. (#1279)
This commit is contained in:
30
mlx/random.h
30
mlx/random.h
@@ -224,4 +224,34 @@ array categorical(
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Generate samples from the laplace distribution. */
|
||||
array laplace(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
inline array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return laplace(shape, float32, loc, scale, key, s);
|
||||
}
|
||||
inline array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return laplace(shape, dtype, 0.0, 1.0, key, s);
|
||||
}
|
||||
inline array laplace(
|
||||
const std::vector<int>& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return laplace(shape, float32, 0.0, 1.0, key, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::random
|
||||
|
||||
Reference in New Issue
Block a user