diff --git a/mlx/random.cpp b/mlx/random.cpp
index 1490be159..fae2e592c 100644
--- a/mlx/random.cpp
+++ b/mlx/random.cpp
@@ -90,6 +90,20 @@ T below_one() {
   return f;
 }
 
+// Get the next representable value above -1.0 for half precision
+// floating point types (fp16, bf16).
+//
+// T must be a 16-bit type: the value is reinterpreted as a uint16_t.
+// Because -1.0 has its sign bit set, decrementing the raw bits shrinks
+// the magnitude, which moves the value one step toward zero.
+template <typename T>
+T above_minus_one() {
+  T f = T(-1.0);
+  uint16_t* m = (uint16_t*)&f;
+  *m -= 1;
+  return f;
+}
+
 array uniform(
     const array& low,
     const array& high,
@@ -158,7 +172,19 @@ array normal(
     const std::optional<array>& key /*= nullopt */,
     StreamOrDevice s /* = {} */) {
   auto stream = to_stream(s);
-  auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
+  // Build the lower bound in the requested dtype: the float32 neighbour
+  // of -1 rounds back to -1 once it is narrowed to fp16 / bf16.
+  auto get_low = [&dtype]() {
+    switch (dtype) {
+      case float16:
+        return array(above_minus_one<float16_t>(), dtype);
+      case bfloat16:
+        return array(above_minus_one<bfloat16_t>(), dtype);
+      default:
+        return array(std::nextafter(-1.0f, 0.0f), dtype);
+    }
+  };
+  auto low = get_low();
   auto high = array(1.0f, dtype);
   auto samples = uniform(low, high, shape, dtype, key, stream);
   samples =
diff --git a/python/tests/test_random.py b/python/tests/test_random.py
index 892db37df..7515cf468 100644
--- a/python/tests/test_random.py
+++ b/python/tests/test_random.py
@@ -96,6 +96,11 @@ class TestRandom(mlx_tests.MLXTestCase):
 
         self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
 
+        # Test not getting -inf or inf with half precision
+        for hp in [mx.float16, mx.bfloat16]:
+            a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
+            self.assertTrue(mx.all(a < mx.inf))
+
     def test_randint(self):
         a = mx.random.randint(0, 1, [])
         self.assertEqual(a.shape, ())