diff --git a/mlx/random.cpp b/mlx/random.cpp index 65ea60535..def3169cb 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -82,6 +82,16 @@ array split(const array& key, int num, StreamOrDevice s /* = {} */) { return bits({num, 2}, 4, key, s); } +// Get the next representable value below 1.0 for half precision +// floating point types (fp16, bf16) +template +T below_one() { + T f = T(1.0); + uint16_t* m = (uint16_t*)&f; + *m -= 1; + return f; +} + array uniform( const array& low, const array& high, @@ -106,7 +116,23 @@ array uniform( << " from broadcasted shape " << out_shape << "."; throw std::invalid_argument(msg.str()); } - auto upper = array(std::nextafter(1.0f, 0.0f), float32); + + // Get random values between [0, nextafter(1.0, 0.0)] since samples must + // be in [low, high) + auto get_upper = [&dtype]() { + switch (dtype) { + case float32: + return array(std::nextafter(1.0f, 0.0f), float32); + case float16: + return array(below_one(), float32); + case bfloat16: + return array(below_one(), float32); + default: + throw std::runtime_error("[uniform] Unsupported type."); + } + }; + + auto upper = get_upper(); auto maxval = array(std::numeric_limits::max(), float32); auto out = bits(shape, size_of(float32), key, stream); out = divide(out, maxval, stream); @@ -154,6 +180,10 @@ array normal( StreamOrDevice s /* = {} */) { if (dtype == complex64) { return complex_normal(shape, loc, scale, key, s); + } else if (!issubdtype(dtype, floating)) { + throw std::invalid_argument( + "[normal] Can only generate uniform numbers with " + "floating point type."); } auto stream = to_stream(s); @@ -417,6 +447,12 @@ array laplace( const float scale /* = 1.0 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { + if (!issubdtype(dtype, floating)) { + throw std::invalid_argument( + "[laplace] Can only generate uniform numbers with real" + "floating point type."); + } + auto stream = to_stream(s); auto low = array(std::nextafter(-1.0f, 0.0f), float32); auto high = array(1.0f, float32); diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 49f1f300b..6ddd37104 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -350,7 +350,7 @@ TEST_CASE("test random uniform") { // Check float16 { auto key = random::key(0); - auto out = random::uniform({100}, float16, key); + auto out = random::uniform({1000}, float16, key); CHECK_EQ(out.dtype(), float16); CHECK(all(less(out, array(1.0f))).item()); CHECK(all(greater_equal(out, array(0.0f))).item()); @@ -360,7 +360,7 @@ TEST_CASE("test random uniform") { { auto key = random::key(0); - auto out = random::uniform({100}, bfloat16, key); + auto out = random::uniform({1000}, bfloat16, key); CHECK_EQ(out.dtype(), bfloat16); CHECK(all(less(out, array(1.0f))).item()); CHECK(all(greater_equal(out, array(0.0f))).item());