Add random normal distribution for complex numbers (#2182)

This commit is contained in:
Angelos Katharopoulos
2025-05-13 22:43:45 -07:00
committed by GitHub
parent 0751263dec
commit 130df35e1b
4 changed files with 109 additions and 24 deletions

View File

@@ -176,24 +176,51 @@ array uniform(
array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
}
inline array complex_normal(
Shape shape,
const std::optional<array>& loc,
const std::optional<array>& scale,
const std::optional<array>& key,
StreamOrDevice s) {
auto stream = to_stream(s);
auto low = above_minus_one_with_default(float32);
auto high = array(1.0f, float32);
shape.push_back(2);
auto samples =
erfinv(uniform(low, high, shape, float32, key, stream), stream);
samples = squeeze(view(samples, complex64, stream), -1, stream);
if (scale.has_value()) {
samples = multiply(*scale, samples, stream);
}
if (loc.has_value()) {
samples = add(*loc, samples, stream);
}
return samples;
}
array normal(
const Shape& shape,
Dtype dtype,
const float loc /* = 0.0 */,
const float scale /* = 1.0 */,
const std::optional<array>& key /*= nullopt */,
const std::optional<array>& loc,
const std::optional<array>& scale,
const std::optional<array>& key,
StreamOrDevice s /* = {} */) {
if (dtype == complex64) {
return complex_normal(shape, loc, scale, key, s);
}
auto stream = to_stream(s);
auto low = above_minus_one_with_default(dtype);
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
samples =
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
if (scale != 1.0) {
samples = multiply(array(scale, dtype), samples, stream);
auto applied_scale = array(std::sqrt(2.0), dtype);
if (scale.has_value()) {
applied_scale =
multiply(applied_scale, astype(*scale, dtype, stream), stream);
}
if (loc != 0.0) {
samples = add(array(loc, dtype), samples, stream);
samples = multiply(applied_scale, erfinv(samples, stream), stream);
if (loc.has_value()) {
samples = add(astype(*loc, dtype, stream), samples, stream);
}
return samples;
}

View File

@@ -94,12 +94,24 @@ inline array uniform(
/** Generate samples from the standard normal distribution. */
array normal(
const Shape& shape,
Dtype dtype,
const std::optional<array>& loc,
const std::optional<array>& scale,
const std::optional<array>& key,
StreamOrDevice s = {});
inline array normal(
const Shape& shape,
Dtype dtype,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
StreamOrDevice s = {}) {
auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
auto scale_ =
scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
return normal(shape, dtype, loc_, scale_, key, s);
}
inline array normal(
const Shape& shape,
const float loc,
@@ -113,13 +125,13 @@ inline array normal(
const Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, dtype, 0.0, 1.0, key, s);
return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
}
inline array normal(
const Shape& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, float32, 0.0, 1.0, key, s);
return normal(shape, float32, std::nullopt, std::nullopt, key, s);
}
/** Generate samples from a multivariate normal distribution. **/