Implementation of mlx.random.multivariate_normal (#502) (#877)

* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Luca Arnaboldi
2024-04-09 22:50:12 +02:00
committed by GitHub
parent a1a31eed27
commit fffe072028
6 changed files with 270 additions and 1 deletions

View File

@@ -121,6 +121,15 @@ inline array normal(
return normal(shape, float32, 0.0, 1.0, key, s);
}
/** Generate samples from a multivariate normal distribution. **/
array multivariate_normal(
const array& mean,
const array& cov,
const std::vector<int>& shape,
Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
/** Generate integer samples uniformly at random */
array randint(
const array& low,