random generation fix (#80)

Random generation fix
This commit is contained in:
Awni Hannun
2023-12-08 10:40:57 -08:00
committed by GitHub
parent 86b614afcd
commit 4e3bdb560c
3 changed files with 79 additions and 8 deletions

View File

@@ -344,6 +344,27 @@ TEST_CASE("test random uniform") {
CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(-1.0f))).item<bool>());
}
// Check float16
{
auto key = random::key(0);
auto out = random::uniform({100}, float16, key);
CHECK_EQ(out.dtype(), float16);
CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
CHECK(!all(equal(out, array(0.0f))).item<bool>());
CHECK(abs(float(mean(out).item<float16_t>()) - 0.5f) < 0.02);
}
{
auto key = random::key(0);
auto out = random::uniform({100}, bfloat16, key);
CHECK_EQ(out.dtype(), bfloat16);
CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
CHECK(!all(equal(out, array(0.0f))).item<bool>());
CHECK(abs(float(mean(out).item<bfloat16_t>()) - 0.5f) < 0.02);
}
}
TEST_CASE("test random normal") {
@@ -375,6 +396,25 @@ TEST_CASE("test random normal") {
auto key = random::key(128291);
auto out = random::normal({100}, key);
CHECK(all(less(abs(out), array(inf))).item<bool>());
CHECK(abs(mean(out).item<float>()) < 0.1);
}
{
constexpr float inf = std::numeric_limits<float>::infinity();
auto key = random::key(128291);
auto out = random::normal({200}, float16, key);
CHECK_EQ(out.dtype(), float16);
CHECK(all(less(abs(out), array(inf))).item<bool>());
CHECK(abs(float(mean(out).item<float16_t>())) < 0.1);
}
{
constexpr float inf = std::numeric_limits<float>::infinity();
auto key = random::key(128291);
auto out = random::normal({200}, bfloat16, key);
CHECK_EQ(out.dtype(), bfloat16);
CHECK(all(less(abs(out), array(inf))).item<bool>());
CHECK(abs(float(mean(out).item<bfloat16_t>())) < 0.1);
}
}