random.uniform must respect dtype, even if lower precision than "low" (#280)

Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
This commit is contained in:
Daniel Strobusch
2023-12-24 16:04:43 +01:00
committed by GitHub
parent 8b227fa9af
commit 7365d142a3
3 changed files with 11 additions and 2 deletions

View File

@@ -260,6 +260,10 @@ TEST_CASE("test random uniform") {
// Non float type throws
CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
// dtype respected
x = random::uniform(-.1, .1, {0}, bfloat16);
CHECK_EQ(x.dtype(), bfloat16);
// Check broadcasting
x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
CHECK_EQ(x.shape(), std::vector<int>{3, 3});