Implementation of mlx.random.multivariate_normal (#502) (#877)

* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Luca Arnaboldi
2024-04-09 22:50:12 +02:00
committed by GitHub
parent a1a31eed27
commit fffe072028
6 changed files with 270 additions and 1 deletions

View File

@@ -420,6 +420,63 @@ TEST_CASE("test random normal") {
}
}
TEST_CASE("test random multivariate_normal") {
{
auto mean = zeros({3});
auto cov = eye(3);
auto x = random::multivariate_normal(mean, cov, {1000}, float32);
CHECK_EQ(x.shape(), std::vector<int>({1000, 3}));
CHECK_EQ(x.dtype(), float32);
}
// Limit case
{
auto mean = array({0, 0});
auto cov = array({1., -1, -.1, 1.});
cov = reshape(cov, {2, 2});
auto x = random::multivariate_normal(mean, cov, {1}, float32);
CHECK_EQ(x.shape(), std::vector<int>({1, 2}));
CHECK_EQ(x.dtype(), float32);
}
// Check wrong shapes
{
auto mean = zeros({3, 1});
auto cov = eye(3);
CHECK_THROWS_AS(
random::multivariate_normal(
mean,
cov,
{
1000,
},
float32),
std::invalid_argument);
}
{
auto mean = zeros({3});
auto cov = zeros({1, 2, 3, 3});
auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32);
CHECK_EQ(x.shape(), std::vector<int>({1000, 2, 3}));
}
{
auto mean = zeros({3});
auto cov = eye(4);
CHECK_THROWS_AS(
random::multivariate_normal(mean, cov, {1000, 3}, float32),
std::invalid_argument);
}
// Check wrong type
{
auto mean = zeros({3});
auto cov = eye(3);
CHECK_THROWS_AS(
random::multivariate_normal(mean, cov, {1000, 3}, float16),
std::invalid_argument);
}
}
TEST_CASE("test random randint") {
CHECK_THROWS_AS(
random::randint(array(3), array(5), {1}, float32), std::invalid_argument);