diff --git a/mlx/random.cpp b/mlx/random.cpp index 89a027b17..fd3c2d502 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -176,6 +176,27 @@ array uniform( array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s)); } +inline array complex_normal( + Shape shape, + const float loc, + const float scale, + const std::optional& key, + StreamOrDevice s) { + auto stream = to_stream(s); + auto low = above_minus_one_with_default(float32); + auto high = array(1.0f, float32); + shape.push_back(2); + auto samples = + erfinv(uniform(low, high, shape, float32, key, stream), stream); + if (scale != 1.0) { + samples = multiply(array(scale, float32), samples, stream); + } + if (loc != 0.0) { + samples = add(array(loc, float32), samples, stream); + } + return squeeze(view(samples, complex64, stream), -1, stream); +} + array normal( const Shape& shape, Dtype dtype, @@ -183,15 +204,16 @@ array normal( const float scale /* = 1.0 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { + if (dtype == complex64) { + return complex_normal(shape, loc, scale, key, s); + } + auto stream = to_stream(s); auto low = above_minus_one_with_default(dtype); auto high = array(1.0f, dtype); auto samples = uniform(low, high, shape, dtype, key, stream); - samples = - multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); - if (scale != 1.0) { - samples = multiply(array(scale, dtype), samples, stream); - } + samples = multiply( + array(std::sqrt(2.0) * scale, dtype), erfinv(samples, stream), stream); if (loc != 0.0) { samples = add(array(loc, dtype), samples, stream); }