mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Add random normal distribution for complex numbers (#2182)
This commit is contained in:
parent
0751263dec
commit
130df35e1b
@ -176,24 +176,51 @@ array uniform(
|
||||
array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s));
|
||||
}
|
||||
|
||||
inline array complex_normal(
|
||||
Shape shape,
|
||||
const std::optional<array>& loc,
|
||||
const std::optional<array>& scale,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s) {
|
||||
auto stream = to_stream(s);
|
||||
auto low = above_minus_one_with_default(float32);
|
||||
auto high = array(1.0f, float32);
|
||||
shape.push_back(2);
|
||||
auto samples =
|
||||
erfinv(uniform(low, high, shape, float32, key, stream), stream);
|
||||
samples = squeeze(view(samples, complex64, stream), -1, stream);
|
||||
if (scale.has_value()) {
|
||||
samples = multiply(*scale, samples, stream);
|
||||
}
|
||||
if (loc.has_value()) {
|
||||
samples = add(*loc, samples, stream);
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
||||
array normal(
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const float loc /* = 0.0 */,
|
||||
const float scale /* = 1.0 */,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
const std::optional<array>& loc,
|
||||
const std::optional<array>& scale,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (dtype == complex64) {
|
||||
return complex_normal(shape, loc, scale, key, s);
|
||||
}
|
||||
|
||||
auto stream = to_stream(s);
|
||||
auto low = above_minus_one_with_default(dtype);
|
||||
auto high = array(1.0f, dtype);
|
||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
||||
samples =
|
||||
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
|
||||
if (scale != 1.0) {
|
||||
samples = multiply(array(scale, dtype), samples, stream);
|
||||
auto applied_scale = array(std::sqrt(2.0), dtype);
|
||||
if (scale.has_value()) {
|
||||
applied_scale =
|
||||
multiply(applied_scale, astype(*scale, dtype, stream), stream);
|
||||
}
|
||||
if (loc != 0.0) {
|
||||
samples = add(array(loc, dtype), samples, stream);
|
||||
samples = multiply(applied_scale, erfinv(samples, stream), stream);
|
||||
if (loc.has_value()) {
|
||||
samples = add(astype(*loc, dtype, stream), samples, stream);
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
18
mlx/random.h
18
mlx/random.h
@ -94,12 +94,24 @@ inline array uniform(
|
||||
|
||||
/** Generate samples from the standard normal distribution. */
|
||||
array normal(
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const std::optional<array>& loc,
|
||||
const std::optional<array>& scale,
|
||||
const std::optional<array>& key,
|
||||
StreamOrDevice s = {});
|
||||
inline array normal(
|
||||
const Shape& shape,
|
||||
Dtype dtype,
|
||||
const float loc,
|
||||
const float scale,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
StreamOrDevice s = {}) {
|
||||
auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
|
||||
auto scale_ =
|
||||
scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
|
||||
return normal(shape, dtype, loc_, scale_, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const Shape& shape,
|
||||
const float loc,
|
||||
@ -113,13 +125,13 @@ inline array normal(
|
||||
const Dtype dtype,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, dtype, 0.0, 1.0, key, s);
|
||||
return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
|
||||
}
|
||||
inline array normal(
|
||||
const Shape& shape,
|
||||
const std::optional<array>& key = std::nullopt,
|
||||
StreamOrDevice s = {}) {
|
||||
return normal(shape, float32, 0.0, 1.0, key, s);
|
||||
return normal(shape, float32, std::nullopt, std::nullopt, key, s);
|
||||
}
|
||||
|
||||
/** Generate samples from a multivariate normal distribution. **/
|
||||
|
@ -152,31 +152,42 @@ void init_random(nb::module_& parent_module) {
|
||||
"normal",
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
const std::optional<ScalarOrArray>& loc_,
|
||||
const std::optional<ScalarOrArray>& scale_,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto dtype = type.value_or(mx::float32);
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return mx::random::normal(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
auto loc =
|
||||
loc_ ? std::make_optional(to_array(*loc_, dtype)) : std::nullopt;
|
||||
auto scale = scale_ ? std::make_optional(to_array(*scale_, dtype))
|
||||
: std::nullopt;
|
||||
return mx::random::normal(shape, dtype, loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
"loc"_a = nb::none(),
|
||||
"scale"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate normally distributed random numbers.
|
||||
|
||||
If ``loc`` and ``scale`` are not provided the "standard" normal
|
||||
distribution is used. That means $x \sim \mathcal{N}(0, 1)$ for
|
||||
real numbers and $\text{Re}(x),\text{Im}(x) \sim \mathcal{N}(0,
|
||||
\frac{1}{2})$ for complex numbers.
|
||||
|
||||
Args:
|
||||
shape (list(int), optional): Shape of the output. Default is ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default is ``float32``.
|
||||
loc (float, optional): Mean of the distribution. Default is ``0.0``.
|
||||
scale (float, optional): Standard deviation of the distribution. Default is ``1.0``.
|
||||
key (array, optional): A PRNG key. Default: None.
|
||||
shape (list(int), optional): Shape of the output. Default: ``()``.
|
||||
dtype (Dtype, optional): Type of the output. Default: ``float32``.
|
||||
loc (scalar or array, optional): Mean of the distribution.
|
||||
Default: ``None``.
|
||||
scale (scalar or array, optional): Standard deviation of the
|
||||
distribution. Default: ``None``.
|
||||
key (array, optional): A PRNG key. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The output array of random values.
|
||||
|
@ -352,6 +352,41 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
x = mx.random.permutation(mx.array([[1]]))
|
||||
self.assertEqual(x.shape, (1, 1))
|
||||
|
||||
def test_complex_normal(self):
|
||||
sample = mx.random.normal(tuple(), dtype=mx.complex64)
|
||||
self.assertEqual(sample.shape, tuple())
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64)
|
||||
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0)
|
||||
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
sample = mx.random.normal(
|
||||
(1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j
|
||||
)
|
||||
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
def test_broadcastable_scale_loc(self):
|
||||
b = mx.random.normal((10, 2))
|
||||
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
||||
mx.eval(sample)
|
||||
self.assertEqual(sample.shape, (2, 10, 2))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
b = mx.random.normal((10,))
|
||||
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
||||
|
||||
b = mx.random.normal((3, 1, 2))
|
||||
sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
|
||||
mx.eval(sample)
|
||||
self.assertEqual(sample.shape, (3, 4, 2))
|
||||
self.assertEqual(sample.dtype, mx.float16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user