From 130df35e1b520061a053c052fba07122dc390c6a Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 13 May 2025 22:43:45 -0700
Subject: [PATCH] Add random normal distribution for complex numbers (#2182)

---
Notes (text between the `---` line and the first `diff --git` is not part of
the commit message):

* For dtype complex64, normal() forwards to complex_normal(), which draws
  float32 uniforms with an extra trailing axis of size 2, applies erfinv, and
  views each pair as one complex64 value. No sqrt(2) factor is applied on
  that path, which is what yields Re(x), Im(x) ~ N(0, 1/2) as stated in the
  Python docstring.
* loc/scale are now std::optional<array>. The float overload in random.h maps
  loc == 0 and scale == 1 to std::nullopt so the add/multiply is skipped.
* This copy was restored from a flattened version of the patch: template
  arguments (std::optional<array>, std::optional<mx::Dtype>,
  std::optional<mx::array>, std::optional<ScalarOrArray>) and all line
  breaks/indentation were reconstructed. The ScalarOrArray element type and
  the exact whitespace are assumptions -- confirm against the tree, and use
  `git apply --ignore-whitespace` if context lines fail to match.
* The From: header carries no e-mail address, so `git am` may reject it;
  `git apply` is unaffected.
* nb::sig now spells the loc/scale types as Union[scalar, array, None];
  typing.Optional accepts exactly one type argument.

 mlx/random.cpp              | 45 +++++++++++++++++++++++++++++--------
 mlx/random.h                | 18 ++++++++++++---
 python/src/random.cpp       | 35 +++++++++++++++++++----------
 python/tests/test_random.py | 35 +++++++++++++++++++++++++++++
 4 files changed, 109 insertions(+), 24 deletions(-)

diff --git a/mlx/random.cpp b/mlx/random.cpp
index 89a027b17..6c6d1eb95 100644
--- a/mlx/random.cpp
+++ b/mlx/random.cpp
@@ -176,24 +176,51 @@ array uniform(
       array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
 }
 
+inline array complex_normal(
+    Shape shape,
+    const std::optional<array>& loc,
+    const std::optional<array>& scale,
+    const std::optional<array>& key,
+    StreamOrDevice s) {
+  auto stream = to_stream(s);
+  auto low = above_minus_one_with_default(float32);
+  auto high = array(1.0f, float32);
+  shape.push_back(2);
+  auto samples =
+      erfinv(uniform(low, high, shape, float32, key, stream), stream);
+  samples = squeeze(view(samples, complex64, stream), -1, stream);
+  if (scale.has_value()) {
+    samples = multiply(*scale, samples, stream);
+  }
+  if (loc.has_value()) {
+    samples = add(*loc, samples, stream);
+  }
+  return samples;
+}
+
 array normal(
     const Shape& shape,
     Dtype dtype,
-    const float loc /* = 0.0 */,
-    const float scale /* = 1.0 */,
-    const std::optional<array>& key /*= nullopt */,
+    const std::optional<array>& loc,
+    const std::optional<array>& scale,
+    const std::optional<array>& key,
     StreamOrDevice s /* = {} */) {
+  if (dtype == complex64) {
+    return complex_normal(shape, loc, scale, key, s);
+  }
+
   auto stream = to_stream(s);
   auto low = above_minus_one_with_default(dtype);
   auto high = array(1.0f, dtype);
   auto samples = uniform(low, high, shape, dtype, key, stream);
-  samples =
-      multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
-  if (scale != 1.0) {
-    samples = multiply(array(scale, dtype), samples, stream);
+  auto applied_scale = array(std::sqrt(2.0), dtype);
+  if (scale.has_value()) {
+    applied_scale =
+        multiply(applied_scale, astype(*scale, dtype, stream), stream);
   }
-  if (loc != 0.0) {
-    samples = add(array(loc, dtype), samples, stream);
+  samples = multiply(applied_scale, erfinv(samples, stream), stream);
+  if (loc.has_value()) {
+    samples = add(astype(*loc, dtype, stream), samples, stream);
   }
   return samples;
 }
diff --git a/mlx/random.h b/mlx/random.h
index b2c821736..0dfdab7a1 100644
--- a/mlx/random.h
+++ b/mlx/random.h
@@ -94,12 +94,24 @@ inline array uniform(
 
 /** Generate samples from the standard normal distribution. */
 array normal(
+    const Shape& shape,
+    Dtype dtype,
+    const std::optional<array>& loc,
+    const std::optional<array>& scale,
+    const std::optional<array>& key,
+    StreamOrDevice s = {});
+inline array normal(
     const Shape& shape,
     Dtype dtype,
     const float loc,
     const float scale,
     const std::optional<array>& key = std::nullopt,
-    StreamOrDevice s = {});
+    StreamOrDevice s = {}) {
+  auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
+  auto scale_ =
+      scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
+  return normal(shape, dtype, loc_, scale_, key, s);
+}
 inline array normal(
     const Shape& shape,
     const float loc,
@@ -113,13 +125,13 @@ inline array normal(
     const Dtype dtype,
     const std::optional<array>& key = std::nullopt,
     StreamOrDevice s = {}) {
-  return normal(shape, dtype, 0.0, 1.0, key, s);
+  return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
 }
 inline array normal(
     const Shape& shape,
     const std::optional<array>& key = std::nullopt,
     StreamOrDevice s = {}) {
-  return normal(shape, float32, 0.0, 1.0, key, s);
+  return normal(shape, float32, std::nullopt, std::nullopt, key, s);
 }
 
 /** Generate samples from a multivariate normal distribution. **/
diff --git a/python/src/random.cpp b/python/src/random.cpp
index 22b706174..837f91616 100644
--- a/python/src/random.cpp
+++ b/python/src/random.cpp
@@ -152,31 +152,42 @@ void init_random(nb::module_& parent_module) {
       "normal",
       [](const mx::Shape& shape,
          std::optional<mx::Dtype> type,
-         float loc,
-         float scale,
+         const std::optional<ScalarOrArray>& loc_,
+         const std::optional<ScalarOrArray>& scale_,
          const std::optional<mx::array>& key_,
          mx::StreamOrDevice s) {
+        auto dtype = type.value_or(mx::float32);
         auto key = key_ ? key_.value() : default_key().next();
-        return mx::random::normal(
-            shape, type.value_or(mx::float32), loc, scale, key, s);
+        auto loc =
+            loc_ ? std::make_optional(to_array(*loc_, dtype)) : std::nullopt;
+        auto scale = scale_ ? std::make_optional(to_array(*scale_, dtype))
+                            : std::nullopt;
+        return mx::random::normal(shape, dtype, loc, scale, key, s);
       },
       "shape"_a = mx::Shape{},
       "dtype"_a.none() = mx::float32,
-      "loc"_a = 0.0,
-      "scale"_a = 1.0,
+      "loc"_a = nb::none(),
+      "scale"_a = nb::none(),
       "key"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Generate normally distributed random numbers.
 
+        If ``loc`` and ``scale`` are not provided the "standard" normal
+        distribution is used. That means $x \sim \mathcal{N}(0, 1)$ for
+        real numbers and $\text{Re}(x),\text{Im}(x) \sim \mathcal{N}(0,
+        \frac{1}{2})$ for complex numbers.
+
         Args:
-            shape (list(int), optional): Shape of the output. Default is ``()``.
-            dtype (Dtype, optional): Type of the output. Default is ``float32``.
-            loc (float, optional): Mean of the distribution. Default is ``0.0``.
-            scale (float, optional): Standard deviation of the distribution. Default is ``1.0``.
-            key (array, optional): A PRNG key. Default: None.
+            shape (list(int), optional): Shape of the output. Default: ``()``.
+            dtype (Dtype, optional): Type of the output. Default: ``float32``.
+            loc (scalar or array, optional): Mean of the distribution.
+              Default: ``None``.
+            scale (scalar or array, optional): Standard deviation of the
+              distribution. Default: ``None``.
+            key (array, optional): A PRNG key. Default: ``None``.
 
         Returns:
             array: The output array of random values.
diff --git a/python/tests/test_random.py b/python/tests/test_random.py
index 9efbfb5f6..2fc768651 100644
--- a/python/tests/test_random.py
+++ b/python/tests/test_random.py
@@ -352,6 +352,41 @@ class TestRandom(mlx_tests.MLXTestCase):
         x = mx.random.permutation(mx.array([[1]]))
         self.assertEqual(x.shape, (1, 1))
 
+    def test_complex_normal(self):
+        sample = mx.random.normal(tuple(), dtype=mx.complex64)
+        self.assertEqual(sample.shape, tuple())
+        self.assertEqual(sample.dtype, mx.complex64)
+
+        sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64)
+        self.assertEqual(sample.shape, (1, 2, 3, 4))
+        self.assertEqual(sample.dtype, mx.complex64)
+
+        sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0)
+        self.assertEqual(sample.shape, (1, 2, 3, 4))
+        self.assertEqual(sample.dtype, mx.complex64)
+
+        sample = mx.random.normal(
+            (1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j
+        )
+        self.assertEqual(sample.shape, (1, 2, 3, 4))
+        self.assertEqual(sample.dtype, mx.complex64)
+
+    def test_broadcastable_scale_loc(self):
+        b = mx.random.normal((10, 2))
+        sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
+        mx.eval(sample)
+        self.assertEqual(sample.shape, (2, 10, 2))
+
+        with self.assertRaises(ValueError):
+            b = mx.random.normal((10,))
+            sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
+
+        b = mx.random.normal((3, 1, 2))
+        sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
+        mx.eval(sample)
+        self.assertEqual(sample.shape, (3, 4, 2))
+        self.assertEqual(sample.dtype, mx.float16)
+
 
 if __name__ == "__main__":
     unittest.main()