mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
@@ -344,6 +344,27 @@ TEST_CASE("test random uniform") {
|
||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||
CHECK(all(greater_equal(out, array(-1.0f))).item<bool>());
|
||||
}
|
||||
|
||||
// Check float16
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto out = random::uniform({100}, float16, key);
|
||||
CHECK_EQ(out.dtype(), float16);
|
||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||
CHECK(!all(equal(out, array(0.0f))).item<bool>());
|
||||
CHECK(abs(float(mean(out).item<float16_t>()) - 0.5f) < 0.02);
|
||||
}
|
||||
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto out = random::uniform({100}, bfloat16, key);
|
||||
CHECK_EQ(out.dtype(), bfloat16);
|
||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||
CHECK(!all(equal(out, array(0.0f))).item<bool>());
|
||||
CHECK(abs(float(mean(out).item<bfloat16_t>()) - 0.5f) < 0.02);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test random normal") {
|
||||
@@ -375,6 +396,25 @@ TEST_CASE("test random normal") {
|
||||
auto key = random::key(128291);
|
||||
auto out = random::normal({100}, key);
|
||||
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||
CHECK(abs(mean(out).item<float>()) < 0.1);
|
||||
}
|
||||
|
||||
{
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
auto key = random::key(128291);
|
||||
auto out = random::normal({200}, float16, key);
|
||||
CHECK_EQ(out.dtype(), float16);
|
||||
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||
CHECK(abs(float(mean(out).item<float16_t>())) < 0.1);
|
||||
}
|
||||
|
||||
{
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
auto key = random::key(128291);
|
||||
auto out = random::normal({200}, bfloat16, key);
|
||||
CHECK_EQ(out.dtype(), bfloat16);
|
||||
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||
CHECK(abs(float(mean(out).item<bfloat16_t>())) < 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user