mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 23:15:09 +08:00
random.uniform must respect dtype, even if lower precision than "low" (#280)
Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
This commit is contained in:
@@ -260,6 +260,10 @@ TEST_CASE("test random uniform") {
|
||||
// Non float type throws
|
||||
CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
|
||||
|
||||
// dtype respected
|
||||
x = random::uniform(-.1, .1, {0}, bfloat16);
|
||||
CHECK_EQ(x.dtype(), bfloat16);
|
||||
|
||||
// Check broadcasting
|
||||
x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{3, 3});
|
||||
|
||||
Reference in New Issue
Block a user