random.uniform must respect dtype, even if lower precision than "low" (#280)

Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
This commit is contained in:
Daniel Strobusch
2023-12-24 16:04:43 +01:00
committed by GitHub
parent 8b227fa9af
commit 7365d142a3
3 changed files with 11 additions and 2 deletions

View File

@@ -103,7 +103,9 @@ array uniform(
}
auto stream = to_stream(s);
auto range = subtract(high, low, stream);
auto lo = astype(low, dtype, stream);
auto hi = astype(high, dtype, stream);
auto range = subtract(hi, lo, stream);
auto out_shape = broadcast_shapes(shape, range.shape());
if (out_shape != shape) {
std::ostringstream msg;
@@ -136,7 +138,7 @@ array uniform(
auto out = bits(shape, size_of(dtype), key, stream);
out = astype(divide(out, maxval, stream), dtype, stream);
out = minimum(out, upper, stream);
return add(multiply(range, out, stream), low, stream);
return add(multiply(range, out, stream), lo, stream);
}
array uniform(