Implementation of mlx.random.multivariate_normal (#502) (#877)

* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Luca Arnaboldi
2024-04-09 22:50:12 +02:00
committed by GitHub
parent a1a31eed27
commit fffe072028
6 changed files with 270 additions and 1 deletions

View File

@@ -179,6 +179,48 @@ void init_random(nb::module_& parent_module) {
array: The output array of random values.
)pbdoc");
m.def(
"multivariate_normal",
[](const array& mean,
const array& cov,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return multivariate_normal(
mean, cov, shape, type.value_or(float32), key, s);
},
"mean"_a,
"cov"_a,
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate jointly-normal random samples given a mean and covariance.
The matrix ``cov`` must be positive semi-definite. The behavior is
undefined if it is not. The only supported ``dtype`` is ``float32``.
Args:
mean (array): array of shape ``(..., n)``, the mean of the
distribution.
cov (array): array of shape ``(..., n, n)``, the covariance
matrix of the distribution. The batch shape ``...`` must be
broadcast-compatible with that of ``mean``.
shape (list(int), optional): The output shape must be
broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
If empty, the result shape is determined by broadcasting the batch
shapes of ``mean`` and ``cov``. Default: ``[]``.
dtype (Dtype, optional): The output type. Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.
)pbdoc");
m.def(
"randint",
[](const ScalarOrArray& low,
const ScalarOrArray& high,