mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
random.uniform must respect dtype, even if lower precision than "low" (#280)
Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
This commit is contained in:
parent
8b227fa9af
commit
7365d142a3
@ -103,7 +103,9 @@ array uniform(
|
||||
}
|
||||
|
||||
auto stream = to_stream(s);
|
||||
auto range = subtract(high, low, stream);
|
||||
auto lo = astype(low, dtype, stream);
|
||||
auto hi = astype(high, dtype, stream);
|
||||
auto range = subtract(hi, lo, stream);
|
||||
auto out_shape = broadcast_shapes(shape, range.shape());
|
||||
if (out_shape != shape) {
|
||||
std::ostringstream msg;
|
||||
@ -136,7 +138,7 @@ array uniform(
|
||||
auto out = bits(shape, size_of(dtype), key, stream);
|
||||
out = astype(divide(out, maxval, stream), dtype, stream);
|
||||
out = minimum(out, upper, stream);
|
||||
return add(multiply(range, out, stream), low, stream);
|
||||
return add(multiply(range, out, stream), lo, stream);
|
||||
}
|
||||
|
||||
array uniform(
|
||||
|
@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
|
||||
self.assertTrue(mx.all((a > -1) < 5).item())
|
||||
|
||||
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
|
||||
self.assertEqual(a.dtype, mx.bfloat16)
|
||||
|
||||
def test_normal(self):
|
||||
key = mx.random.key(0)
|
||||
a = mx.random.normal(key=key)
|
||||
|
@ -260,6 +260,10 @@ TEST_CASE("test random uniform") {
|
||||
// Non float type throws
|
||||
CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
|
||||
|
||||
// dtype respected
|
||||
x = random::uniform(-.1, .1, {0}, bfloat16);
|
||||
CHECK_EQ(x.dtype(), bfloat16);
|
||||
|
||||
// Check broadcasting
|
||||
x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{3, 3});
|
||||
|
Loading…
Reference in New Issue
Block a user