random.uniform must respect dtype, even if lower precision than "low" (#280)

Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
This commit is contained in:
Daniel Strobusch 2023-12-24 16:04:43 +01:00 committed by GitHub
parent 8b227fa9af
commit 7365d142a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 2 deletions

View File

@ -103,7 +103,9 @@ array uniform(
}
auto stream = to_stream(s);
auto range = subtract(high, low, stream);
auto lo = astype(low, dtype, stream);
auto hi = astype(high, dtype, stream);
auto range = subtract(hi, lo, stream);
auto out_shape = broadcast_shapes(shape, range.shape());
if (out_shape != shape) {
std::ostringstream msg;
@ -136,7 +138,7 @@ array uniform(
auto out = bits(shape, size_of(dtype), key, stream);
out = astype(divide(out, maxval, stream), dtype, stream);
out = minimum(out, upper, stream);
return add(multiply(range, out, stream), low, stream);
return add(multiply(range, out, stream), lo, stream);
}
array uniform(

View File

@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
self.assertTrue(mx.all((a > -1) < 5).item())
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
self.assertEqual(a.dtype, mx.bfloat16)
def test_normal(self):
key = mx.random.key(0)
a = mx.random.normal(key=key)

View File

@ -260,6 +260,10 @@ TEST_CASE("test random uniform") {
// Non float type throws
CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
// dtype respected
x = random::uniform(-.1, .1, {0}, bfloat16);
CHECK_EQ(x.dtype(), bfloat16);
// Check broadcasting
x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
CHECK_EQ(x.shape(), std::vector<int>{3, 3});