Add random normal distribution for complex numbers (#2182)

This commit is contained in:
Angelos Katharopoulos
2025-05-13 22:43:45 -07:00
committed by GitHub
parent 0751263dec
commit 130df35e1b
4 changed files with 109 additions and 24 deletions

View File

@@ -94,12 +94,24 @@ inline array uniform(
/** Generate samples from the standard normal distribution. */
array normal(
const Shape& shape,
Dtype dtype,
const std::optional<array>& loc,
const std::optional<array>& scale,
const std::optional<array>& key,
StreamOrDevice s = {});
inline array normal(
const Shape& shape,
Dtype dtype,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
StreamOrDevice s = {}) {
auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
auto scale_ =
scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
return normal(shape, dtype, loc_, scale_, key, s);
}
inline array normal(
const Shape& shape,
const float loc,
@@ -113,13 +125,13 @@ inline array normal(
const Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, dtype, 0.0, 1.0, key, s);
return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
}
inline array normal(
const Shape& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, float32, 0.0, 1.0, key, s);
return normal(shape, float32, std::nullopt, std::nullopt, key, s);
}
/** Generate samples from a multivariate normal distribution. **/