mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 06:38:38 +08:00
lower memory uniform sampling (#2361)
* lower memory uniform * use fp32 * fix
This commit is contained in:
@@ -350,7 +350,7 @@ TEST_CASE("test random uniform") {
|
||||
// Check float16
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto out = random::uniform({100}, float16, key);
|
||||
auto out = random::uniform({1000}, float16, key);
|
||||
CHECK_EQ(out.dtype(), float16);
|
||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||
@@ -360,7 +360,7 @@ TEST_CASE("test random uniform") {
|
||||
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto out = random::uniform({100}, bfloat16, key);
|
||||
auto out = random::uniform({1000}, bfloat16, key);
|
||||
CHECK_EQ(out.dtype(), bfloat16);
|
||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||
|
Reference in New Issue
Block a user