mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
random.uniform must respect dtype, even if lower precision than "low" (#280)
Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
This commit is contained in:
@@ -103,7 +103,9 @@ array uniform(
|
||||
}
|
||||
|
||||
auto stream = to_stream(s);
|
||||
auto range = subtract(high, low, stream);
|
||||
auto lo = astype(low, dtype, stream);
|
||||
auto hi = astype(high, dtype, stream);
|
||||
auto range = subtract(hi, lo, stream);
|
||||
auto out_shape = broadcast_shapes(shape, range.shape());
|
||||
if (out_shape != shape) {
|
||||
std::ostringstream msg;
|
||||
@@ -136,7 +138,7 @@ array uniform(
|
||||
auto out = bits(shape, size_of(dtype), key, stream);
|
||||
out = astype(divide(out, maxval, stream), dtype, stream);
|
||||
out = minimum(out, upper, stream);
|
||||
return add(multiply(range, out, stream), low, stream);
|
||||
return add(multiply(range, out, stream), lo, stream);
|
||||
}
|
||||
|
||||
array uniform(
|
||||
|
||||
Reference in New Issue
Block a user