Add random normal distribution for complex numbers (#2182)

This commit is contained in:
Angelos Katharopoulos 2025-05-13 22:43:45 -07:00 committed by GitHub
parent 0751263dec
commit 130df35e1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 109 additions and 24 deletions

View File

@ -176,24 +176,51 @@ array uniform(
array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s)); array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
} }
/**
 * Draw complex64 samples for the normal distribution.
 *
 * An axis of size 2 is appended to `shape`, float32 values are drawn uniformly
 * between `above_minus_one_with_default(float32)` and 1 and passed through
 * `erfinv`; the result is then viewed as complex64 and axis -1 is squeezed.
 * When `scale` holds a value the samples are multiplied by it, and when `loc`
 * holds a value it is added afterwards. Neither is cast before being applied.
 */
inline array complex_normal(
    Shape shape,
    const std::optional<array>& loc,
    const std::optional<array>& scale,
    const std::optional<array>& key,
    StreamOrDevice s) {
  auto stream = to_stream(s);
  auto lower = above_minus_one_with_default(float32);
  auto upper = array(1.0f, float32);

  // Two float32 values are generated per requested output element.
  shape.push_back(2);
  auto draws = uniform(lower, upper, shape, float32, key, stream);
  draws = erfinv(draws, stream);

  // Reinterpret the float32 data as complex64, then squeeze the last axis.
  auto result = view(draws, complex64, stream);
  result = squeeze(result, -1, stream);

  if (scale) {
    result = multiply(*scale, result, stream);
  }
  if (loc) {
    result = add(*loc, result, stream);
  }
  return result;
}
// NOTE(review): this region is a side-by-side diff rendering, so most lines
// carry the previous text followed by the revised text. The description below
// follows the revised (right-hand) column.
//
// Generate normally distributed samples of the given shape and dtype.
//   - complex64 requests are forwarded unchanged to complex_normal().
//   - Otherwise values are drawn with uniform() between
//     above_minus_one_with_default(dtype) and 1, passed through erfinv, and
//     multiplied by sqrt(2) (times `scale` cast to dtype, when provided).
//   - `loc`, when provided, is cast to dtype and added last.
// The previous (left-hand) column took `loc` / `scale` as plain floats and
// skipped the multiply / add when they equalled 1.0 / 0.0.
array normal( array normal(
const Shape& shape, const Shape& shape,
Dtype dtype, Dtype dtype,
const float loc /* = 0.0 */, const std::optional<array>& loc,
const float scale /* = 1.0 */, const std::optional<array>& scale,
const std::optional<array>& key /*= nullopt */, const std::optional<array>& key,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
// Complex output has its own sampling path.
if (dtype == complex64) {
return complex_normal(shape, loc, scale, key, s);
}
auto stream = to_stream(s); auto stream = to_stream(s);
auto low = above_minus_one_with_default(dtype); auto low = above_minus_one_with_default(dtype);
auto high = array(1.0f, dtype); auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream); auto samples = uniform(low, high, shape, dtype, key, stream);
// Revised column: fold the optional user scale into the sqrt(2) factor so a
// single multiply is applied to the erfinv output.
samples = auto applied_scale = array(std::sqrt(2.0), dtype);
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); if (scale.has_value()) {
if (scale != 1.0) { applied_scale =
samples = multiply(array(scale, dtype), samples, stream); multiply(applied_scale, astype(*scale, dtype, stream), stream);
} }
if (loc != 0.0) { samples = multiply(applied_scale, erfinv(samples, stream), stream);
samples = add(array(loc, dtype), samples, stream); if (loc.has_value()) {
samples = add(astype(*loc, dtype, stream), samples, stream);
} }
return samples; return samples;
} }

View File

@ -94,12 +94,24 @@ inline array uniform(
/** Generate samples from the standard normal distribution. */ /** Generate samples from the standard normal distribution. */
// Overload taking `loc`, `scale` and `key` as optional arrays. Only the
// stream argument has a default here. The inline overloads in this header
// forward std::nullopt for `loc` / `scale` when no shift / scaling is wanted.
// NOTE(review): the first line below is fused with the previous declaration's
// opening line by the side-by-side diff rendering.
array normal( array normal(
const Shape& shape,
Dtype dtype,
const std::optional<array>& loc,
const std::optional<array>& scale,
const std::optional<array>& key,
StreamOrDevice s = {});
// Float convenience overload (revised column): wraps scalar `loc` / `scale`
// into optional arrays of `dtype` and forwards to the optional-array overload.
// NOTE(review): the parameter lines are fused with the previous out-of-line
// declaration by the side-by-side diff rendering.
inline array normal(
const Shape& shape, const Shape& shape,
Dtype dtype, Dtype dtype,
const float loc, const float loc,
const float scale, const float scale,
const std::optional<array>& key = std::nullopt, const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}); StreamOrDevice s = {}) {
// loc == 0 and scale == 1 are passed on as "not provided" (std::nullopt).
auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
auto scale_ =
scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
return normal(shape, dtype, loc_, scale_, key, s);
}
inline array normal( inline array normal(
const Shape& shape, const Shape& shape,
const float loc, const float loc,
@ -113,13 +125,13 @@ inline array normal(
const Dtype dtype, const Dtype dtype,
const std::optional<array>& key = std::nullopt, const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) { StreamOrDevice s = {}) {
return normal(shape, dtype, 0.0, 1.0, key, s); return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
} }
// Shape-only overload: requests float32 samples with no `loc` / `scale`.
// The revised (right-hand) column forwards std::nullopt for both, where the
// previous column forwarded the literals 0.0 and 1.0.
inline array normal( inline array normal(
const Shape& shape, const Shape& shape,
const std::optional<array>& key = std::nullopt, const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) { StreamOrDevice s = {}) {
return normal(shape, float32, 0.0, 1.0, key, s); return normal(shape, float32, std::nullopt, std::nullopt, key, s);
} }
/** Generate samples from a multivariate normal distribution. **/ /** Generate samples from a multivariate normal distribution. **/

View File

@ -152,31 +152,42 @@ void init_random(nb::module_& parent_module) {
"normal", "normal",
[](const mx::Shape& shape, [](const mx::Shape& shape,
std::optional<mx::Dtype> type, std::optional<mx::Dtype> type,
float loc, const std::optional<ScalarOrArray>& loc_,
float scale, const std::optional<ScalarOrArray>& scale_,
const std::optional<mx::array>& key_, const std::optional<mx::array>& key_,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
auto dtype = type.value_or(mx::float32);
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
return mx::random::normal( auto loc =
shape, type.value_or(mx::float32), loc, scale, key, s); loc_ ? std::make_optional(to_array(*loc_, dtype)) : std::nullopt;
auto scale = scale_ ? std::make_optional(to_array(*scale_, dtype))
: std::nullopt;
return mx::random::normal(shape, dtype, loc, scale, key, s);
}, },
"shape"_a = mx::Shape{}, "shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32, "dtype"_a.none() = mx::float32,
"loc"_a = 0.0, "loc"_a = nb::none(),
"scale"_a = 1.0, "scale"_a = nb::none(),
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Generate normally distributed random numbers. Generate normally distributed random numbers.
If ``loc`` and ``scale`` are not provided the "standard" normal
distribution is used. That means $x \sim \mathcal{N}(0, 1)$ for
real numbers and $\text{Re}(x),\text{Im}(x) \sim \mathcal{N}(0,
\frac{1}{2})$ for complex numbers.
Args: Args:
shape (list(int), optional): Shape of the output. Default is ``()``. shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default is ``float32``. dtype (Dtype, optional): Type of the output. Default: ``float32``.
loc (float, optional): Mean of the distribution. Default is ``0.0``. loc (scalar or array, optional): Mean of the distribution.
scale (float, optional): Standard deviation of the distribution. Default is ``1.0``. Default: ``None``.
key (array, optional): A PRNG key. Default: None. scale (scalar or array, optional): Standard deviation of the
distribution. Default: ``None``.
key (array, optional): A PRNG key. Default: ``None``.
Returns: Returns:
array: The output array of random values. array: The output array of random values.

View File

@ -352,6 +352,41 @@ class TestRandom(mlx_tests.MLXTestCase):
x = mx.random.permutation(mx.array([[1]])) x = mx.random.permutation(mx.array([[1]]))
self.assertEqual(x.shape, (1, 1)) self.assertEqual(x.shape, (1, 1))
def test_complex_normal(self):
    # Each case is (requested shape, extra keyword arguments). Every draw must
    # come back as complex64 with exactly the requested shape, whether or not
    # a real or complex loc / scale is supplied.
    cases = [
        (tuple(), {}),
        ((1, 2, 3, 4), {}),
        ((1, 2, 3, 4), {"scale": 2.0, "loc": 3.0}),
        ((1, 2, 3, 4), {"scale": 2.0, "loc": 3.0 + 1j}),
    ]
    for expected_shape, extra in cases:
        drawn = mx.random.normal(expected_shape, dtype=mx.complex64, **extra)
        self.assertEqual(drawn.shape, expected_shape)
        self.assertEqual(drawn.dtype, mx.complex64)
# Checks mx.random.normal when loc and scale are given as arrays rather than scalars.
def test_broadcastable_scale_loc(self):
# A (10, 2) array used as loc/scale for a (2, 10, 2) request: the result
# must keep the requested shape.
b = mx.random.normal((10, 2))
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
mx.eval(sample)
self.assertEqual(sample.shape, (2, 10, 2))
# A (10,) array used as loc/scale for a (2, 10, 2) request is expected to
# raise ValueError.
# NOTE(review): indentation is lost in this view; presumably only the next
# two statements belong to the `with` body -- confirm against the real file.
with self.assertRaises(ValueError):
b = mx.random.normal((10,))
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
# Same check with a (3, 1, 2) loc/scale and a float16 output: both the
# requested shape and the requested dtype must be preserved.
b = mx.random.normal((3, 1, 2))
sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
mx.eval(sample)
self.assertEqual(sample.shape, (3, 4, 2))
self.assertEqual(sample.dtype, mx.float16)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()