From 7365d142a3a4461d383342d899846e9252fdfc81 Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Sun, 24 Dec 2023 16:04:43 +0100 Subject: [PATCH] random.uniform must respect dtype, even if lower precision than "low" (#280) Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array. --- mlx/random.cpp | 6 ++++-- python/tests/test_random.py | 3 +++ tests/random_tests.cpp | 4 ++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mlx/random.cpp b/mlx/random.cpp index 232c458f9..ef11f8c65 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -103,7 +103,9 @@ array uniform( } auto stream = to_stream(s); - auto range = subtract(high, low, stream); + auto lo = astype(low, dtype, stream); + auto hi = astype(high, dtype, stream); + auto range = subtract(hi, lo, stream); auto out_shape = broadcast_shapes(shape, range.shape()); if (out_shape != shape) { std::ostringstream msg; @@ -136,7 +138,7 @@ array uniform( auto out = bits(shape, size_of(dtype), key, stream); out = astype(divide(out, maxval, stream), dtype, stream); out = minimum(out, upper, stream); - return add(multiply(range, out, stream), low, stream); + return add(multiply(range, out, stream), lo, stream); } array uniform( diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 1603371b3..aa01339f4 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase): a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5) self.assertTrue(mx.all((a > -1) < 5).item()) + a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16) + self.assertEqual(a.dtype, mx.bfloat16) + def test_normal(self): key = mx.random.key(0) a = mx.random.normal(key=key) diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 1a387febc..b7793e41c 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -260,6 +260,10 @@ TEST_CASE("test random uniform") { // Non float type throws CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument); + // dtype respected + x = random::uniform(-.1, .1, {0}, bfloat16); + CHECK_EQ(x.dtype(), bfloat16); + // Check broadcasting x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3}); CHECK_EQ(x.shape(), std::vector{3, 3});