Update the loc and scale to be arrays

This commit is contained in:
Angelos Katharopoulos
2025-05-13 15:20:47 -07:00
parent c0cac3755c
commit 49878758e5
4 changed files with 72 additions and 29 deletions

View File

@@ -152,31 +152,42 @@ void init_random(nb::module_& parent_module) {
"normal",
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
float loc,
float scale,
const std::optional<ScalarOrArray>& loc_,
const std::optional<ScalarOrArray>& scale_,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto dtype = type.value_or(mx::float32);
auto key = key_ ? key_.value() : default_key().next();
return mx::random::normal(
shape, type.value_or(mx::float32), loc, scale, key, s);
auto loc =
loc_ ? std::make_optional(to_array(*loc_, dtype)) : std::nullopt;
auto scale = scale_ ? std::make_optional(to_array(*scale_, dtype))
: std::nullopt;
return mx::random::normal(shape, dtype, loc, scale, key, s);
},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
"loc"_a = nb::none(),
"scale"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate normally distributed random numbers.
If ``loc`` and ``scale`` are not provided the "standard" normal
distribution is used. That means $x \sim \mathcal{N}(0, 1)$ for
real numbers and $\text{Re}(x),\text{Im}(x) \sim \mathcal{N}(0,
\frac{1}{2})$ for complex numbers.
Args:
shape (list(int), optional): Shape of the output. Default is ``()``.
dtype (Dtype, optional): Type of the output. Default is ``float32``.
loc (float, optional): Mean of the distribution. Default is ``0.0``.
scale (float, optional): Standard deviation of the distribution. Default is ``1.0``.
key (array, optional): A PRNG key. Default: None.
shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default: ``float32``.
loc (scalar or array, optional): Mean of the distribution.
Default: ``None``.
scale (scalar or array, optional): Standard deviation of the
distribution. Default: ``None``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.

View File

@@ -365,6 +365,22 @@ class TestRandom(mlx_tests.MLXTestCase):
self.assertEqual(sample.shape, (1, 2, 3, 4))
self.assertEqual(sample.dtype, mx.complex64)
sample = mx.random.normal(
(1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j
)
self.assertEqual(sample.shape, (1, 2, 3, 4))
self.assertEqual(sample.dtype, mx.complex64)
def test_broadcastable_scale_loc(self):
b = mx.random.normal((10, 2))
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
mx.eval(sample)
self.assertEqual(sample.shape, (2, 10, 2))
with self.assertRaises(ValueError):
b = mx.random.normal((10,))
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
if __name__ == "__main__":
unittest.main()