lower memory uniform sampling (#2361)

* lower memory uniform

* use fp32

* fix
This commit is contained in:
Awni Hannun
2025-07-15 14:22:07 -07:00
committed by GitHub
parent cb349a291c
commit 2ba69bc8fa
2 changed files with 33 additions and 48 deletions

View File

@@ -350,7 +350,7 @@ TEST_CASE("test random uniform") {
// Check float16
{
auto key = random::key(0);
auto out = random::uniform({100}, float16, key);
auto out = random::uniform({1000}, float16, key);
CHECK_EQ(out.dtype(), float16);
CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
@@ -360,7 +360,7 @@ TEST_CASE("test random uniform") {
{
auto key = random::key(0);
auto out = random::uniform({100}, bfloat16, key);
auto out = random::uniform({1000}, bfloat16, key);
CHECK_EQ(out.dtype(), bfloat16);
CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());