Add standard normal for complex numbers

This commit is contained in:
Angelos Katharopoulos 2025-05-12 23:50:38 -07:00
parent 3aa9cf3f9e
commit 3d93f799df

View File

@ -176,6 +176,27 @@ array uniform(
array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
}
inline array complex_normal(
Shape shape,
const float loc,
const float scale,
const std::optional<array>& key,
StreamOrDevice s) {
auto stream = to_stream(s);
auto low = above_minus_one_with_default(float32);
auto high = array(1.0f, float32);
shape.push_back(2);
auto samples =
erfinv(uniform(low, high, shape, float32, key, stream), stream);
if (scale != 1.0) {
samples = multiply(array(scale, float32), samples, stream);
}
if (loc != 0.0) {
samples = add(array(loc, float32), samples, stream);
}
return squeeze(view(samples, complex64, stream), -1, stream);
}
array normal(
const Shape& shape,
Dtype dtype,
@ -183,15 +204,16 @@ array normal(
const float scale /* = 1.0 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
if (dtype == complex64) {
return complex_normal(shape, loc, scale, key, s);
}
auto stream = to_stream(s);
auto low = above_minus_one_with_default(dtype);
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
samples =
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
if (scale != 1.0) {
samples = multiply(array(scale, dtype), samples, stream);
}
samples = multiply(
array(std::sqrt(2.0) * scale, dtype), erfinv(samples, stream), stream);
if (loc != 0.0) {
samples = add(array(loc, dtype), samples, stream);
}