mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Implement sampling from laplace distribution. (#1279)
This commit is contained in:
@@ -640,3 +640,74 @@ TEST_CASE("test categorical") {
|
||||
CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector<int>{5, 3, 7});
|
||||
CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector<int>{4, 3, 7});
|
||||
}
|
||||
|
||||
TEST_CASE("test laplace") {
|
||||
// Test shapes, types, and sizes
|
||||
{
|
||||
auto x = random::laplace({});
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
// Non float type throws
|
||||
CHECK_THROWS_AS(random::laplace({}, int32), std::invalid_argument);
|
||||
|
||||
// Check wrong key type or shape
|
||||
auto key = array({0, 0});
|
||||
CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);
|
||||
key = array({0, 0}, {1, 2});
|
||||
CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);
|
||||
key = array({0u, 0u, 0u}, {3, 1});
|
||||
CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);
|
||||
key = array({0u, 0u}, {2, 1});
|
||||
CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);
|
||||
}
|
||||
|
||||
{
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
auto key = random::key(128291);
|
||||
auto out = random::laplace({1000000}, key);
|
||||
float sample_mean = mean(out).item<float>();
|
||||
float sample_variance = var(out).item<float>();
|
||||
|
||||
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||
CHECK(abs(sample_mean) < 0.1);
|
||||
|
||||
// Chebyshev's inequality.
|
||||
for (int k = 1; k <= 5; ++k) {
|
||||
float prob_above =
|
||||
mean(greater_equal(out, array(k * std::sqrt(sample_variance))))
|
||||
.item<float>();
|
||||
float bound = 1 / std::pow(k, 2);
|
||||
CHECK(prob_above < bound);
|
||||
}
|
||||
|
||||
// Expected variance for Laplace distribution is 2*scale^2.
|
||||
float expected_variance = 2.0;
|
||||
CHECK(std::abs(sample_variance - expected_variance) < 0.01);
|
||||
|
||||
// Expected kurtosis of Laplace distribution is 3.
|
||||
array fourth_pows = power(out - sample_mean, {4});
|
||||
float sample_kurtosis =
|
||||
mean(fourth_pows).item<float>() / std::pow(sample_variance, 2) - 3;
|
||||
float expected_kurtosis = 3.0;
|
||||
CHECK(std::abs(sample_kurtosis - expected_kurtosis) < 0.1);
|
||||
}
|
||||
|
||||
{
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
auto key = random::key(128291);
|
||||
auto out = random::laplace({10000}, float16, key);
|
||||
CHECK_EQ(out.dtype(), float16);
|
||||
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||
CHECK(abs(float(mean(out).item<float16_t>())) < 0.1);
|
||||
}
|
||||
|
||||
{
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
auto key = random::key(128291);
|
||||
auto out = random::laplace({10000}, bfloat16, key);
|
||||
CHECK_EQ(out.dtype(), bfloat16);
|
||||
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||
CHECK(abs(float(mean(out).item<bfloat16_t>())) < 0.1);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user