diff --git a/mlx/random.cpp b/mlx/random.cpp index 232c458f9..ef11f8c65 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -103,7 +103,9 @@ array uniform( } auto stream = to_stream(s); - auto range = subtract(high, low, stream); + auto lo = astype(low, dtype, stream); + auto hi = astype(high, dtype, stream); + auto range = subtract(hi, lo, stream); auto out_shape = broadcast_shapes(shape, range.shape()); if (out_shape != shape) { std::ostringstream msg; @@ -136,7 +138,7 @@ array uniform( auto out = bits(shape, size_of(dtype), key, stream); out = astype(divide(out, maxval, stream), dtype, stream); out = minimum(out, upper, stream); - return add(multiply(range, out, stream), low, stream); + return add(multiply(range, out, stream), lo, stream); } array uniform( diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 1603371b3..aa01339f4 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase): a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5) self.assertTrue(mx.all((a > -1) < 5).item()) + a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16) + self.assertEqual(a.dtype, mx.bfloat16) + def test_normal(self): key = mx.random.key(0) a = mx.random.normal(key=key) diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 1a387febc..b7793e41c 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -260,6 +260,10 @@ TEST_CASE("test random uniform") { // Non float type throws CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument); + // dtype respected + x = random::uniform(-.1, .1, {0}, bfloat16); + CHECK_EQ(x.dtype(), bfloat16); + // Check broadcasting x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3}); CHECK_EQ(x.shape(), std::vector{3, 3});