mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add standard normal for complex numbers
This commit is contained in:
parent
3aa9cf3f9e
commit
3d93f799df
@ -176,6 +176,27 @@ array uniform(
|
||||
array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
|
||||
}
|
||||
|
||||
inline array complex_normal(
|
||||
Shape shape,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
auto stream = to_stream(s);
|
||||
auto low = above_minus_one_with_default(float32);
|
||||
auto high = array(1.0f, float32);
|
||||
shape.push_back(2);
|
||||
auto samples =
|
||||
erfinv(uniform(low, high, shape, float32, key, stream), stream);
|
||||
if (scale != 1.0) {
|
||||
samples = multiply(array(scale, float32), samples, stream);
|
||||
}
|
||||
if (loc != 0.0) {
|
||||
samples = add(array(loc, float32), samples, stream);
|
||||
}
|
||||
return squeeze(view(samples, complex64, stream), -1, stream);
|
||||
}
|
||||
|
||||
array normal(
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
@ -183,15 +204,16 @@ array normal(
|
||||
const float scale /* = 1.0 */,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (dtype == complex64) {
|
||||
return complex_normal(shape, loc, scale, key, s);
|
||||
}
|
||||
|
||||
auto stream = to_stream(s);
|
||||
auto low = above_minus_one_with_default(dtype);
|
||||
auto high = array(1.0f, dtype);
|
||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||
samples =
|
||||
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
|
||||
if (scale != 1.0) {
|
||||
samples = multiply(array(scale, dtype), samples, stream);
|
||||
}
|
||||
samples = multiply(
|
||||
array(std::sqrt(2.0) * scale, dtype), erfinv(samples, stream), stream);
|
||||
if (loc != 0.0) {
|
||||
samples = add(array(loc, dtype), samples, stream);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user