Implement sampling from laplace distribution. (#1279)

This commit is contained in:
fgranqvist
2024-07-24 15:15:37 +02:00
committed by GitHub
parent c34a5ae7f7
commit 50eff6a10a
6 changed files with 210 additions and 40 deletions

View File

@@ -640,3 +640,74 @@ TEST_CASE("test categorical") {
CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector<int>{5, 3, 7});
CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector<int>{4, 3, 7});
}
TEST_CASE("test laplace") {
// Test shapes, types, and sizes
{
auto x = random::laplace({});
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float32);
// Non float type throws
CHECK_THROWS_AS(random::laplace({}, int32), std::invalid_argument);
// Check wrong key type or shape
auto key = array({0, 0});
CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);
key = array({0, 0}, {1, 2});
CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);
key = array({0u, 0u, 0u}, {3, 1});
CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);
key = array({0u, 0u}, {2, 1});
CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument);
}
{
constexpr float inf = std::numeric_limits<float>::infinity();
auto key = random::key(128291);
auto out = random::laplace({1000000}, key);
float sample_mean = mean(out).item<float>();
float sample_variance = var(out).item<float>();
CHECK(all(less(abs(out), array(inf))).item<bool>());
CHECK(abs(sample_mean) < 0.1);
// Chebyshev's inequality.
for (int k = 1; k <= 5; ++k) {
float prob_above =
mean(greater_equal(out, array(k * std::sqrt(sample_variance))))
.item<float>();
float bound = 1 / std::pow(k, 2);
CHECK(prob_above < bound);
}
// Expected variance for Laplace distribution is 2*scale^2.
float expected_variance = 2.0;
CHECK(std::abs(sample_variance - expected_variance) < 0.01);
// Expected kurtosis of Laplace distribution is 3.
array fourth_pows = power(out - sample_mean, {4});
float sample_kurtosis =
mean(fourth_pows).item<float>() / std::pow(sample_variance, 2) - 3;
float expected_kurtosis = 3.0;
CHECK(std::abs(sample_kurtosis - expected_kurtosis) < 0.1);
}
{
constexpr float inf = std::numeric_limits<float>::infinity();
auto key = random::key(128291);
auto out = random::laplace({10000}, float16, key);
CHECK_EQ(out.dtype(), float16);
CHECK(all(less(abs(out), array(inf))).item<bool>());
CHECK(abs(float(mean(out).item<float16_t>())) < 0.1);
}
{
constexpr float inf = std::numeric_limits<float>::infinity();
auto key = random::key(128291);
auto out = random::laplace({10000}, bfloat16, key);
CHECK_EQ(out.dtype(), bfloat16);
CHECK(all(less(abs(out), array(inf))).item<bool>());
CHECK(abs(float(mean(out).item<bfloat16_t>())) < 0.1);
}
}