mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
* Implementation of mlx.random.multivariate_normal (#502) * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Updated typo in docstring * Restricted multivariate_normal to float32 * Generic mean and variance shapes * Review edits * Update mlx/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Test for ndim of mean and cov * nits * smaller size for test * fix broadcasted sampling --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -420,6 +420,63 @@ TEST_CASE("test random normal") {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test random multivariate_normal") {
|
||||
{
|
||||
auto mean = zeros({3});
|
||||
auto cov = eye(3);
|
||||
auto x = random::multivariate_normal(mean, cov, {1000}, float32);
|
||||
CHECK_EQ(x.shape(), std::vector<int>({1000, 3}));
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
}
|
||||
|
||||
// Limit case
|
||||
{
|
||||
auto mean = array({0, 0});
|
||||
auto cov = array({1., -1, -.1, 1.});
|
||||
cov = reshape(cov, {2, 2});
|
||||
auto x = random::multivariate_normal(mean, cov, {1}, float32);
|
||||
CHECK_EQ(x.shape(), std::vector<int>({1, 2}));
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
}
|
||||
|
||||
// Check wrong shapes
|
||||
{
|
||||
auto mean = zeros({3, 1});
|
||||
auto cov = eye(3);
|
||||
CHECK_THROWS_AS(
|
||||
random::multivariate_normal(
|
||||
mean,
|
||||
cov,
|
||||
{
|
||||
1000,
|
||||
},
|
||||
float32),
|
||||
std::invalid_argument);
|
||||
}
|
||||
{
|
||||
auto mean = zeros({3});
|
||||
auto cov = zeros({1, 2, 3, 3});
|
||||
auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32);
|
||||
CHECK_EQ(x.shape(), std::vector<int>({1000, 2, 3}));
|
||||
}
|
||||
{
|
||||
auto mean = zeros({3});
|
||||
auto cov = eye(4);
|
||||
CHECK_THROWS_AS(
|
||||
random::multivariate_normal(mean, cov, {1000, 3}, float32),
|
||||
std::invalid_argument);
|
||||
}
|
||||
|
||||
// Check wrong type
|
||||
{
|
||||
auto mean = zeros({3});
|
||||
auto cov = eye(3);
|
||||
CHECK_THROWS_AS(
|
||||
random::multivariate_normal(mean, cov, {1000, 3}, float16),
|
||||
std::invalid_argument);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test random randint") {
|
||||
CHECK_THROWS_AS(
|
||||
random::randint(array(3), array(5), {1}, float32), std::invalid_argument);
|
||||
|
Reference in New Issue
Block a user