diff --git a/mlx/random.cpp b/mlx/random.cpp index 6c6d1eb95..def3169cb 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -92,29 +92,6 @@ T below_one() { return f; } -// Get the next representable value above -1.0 for half precision -// floating point types (fp16, bf16) -template -T above_minus_one() { - T f = T(-1.0); - uint16_t* m = (uint16_t*)&f; - *m -= 1; - return f; -} - -// Get the next representable value above -1.0 for half precision -// use std::nextafter as default case. -array above_minus_one_with_default(Dtype dtype) { - switch (dtype) { - case float16: - return array(above_minus_one(), dtype); - case bfloat16: - return array(above_minus_one(), dtype); - default: - return array(std::nextafter(-1.0f, 0.0f), dtype); - } -} - array uniform( const array& low, const array& high, @@ -139,31 +116,27 @@ array uniform( << " from broadcasted shape " << out_shape << "."; throw std::invalid_argument(msg.str()); } - // Get random values between [0, nextafter(maxval, 0.0f)] since samples must + + // Get random values between [0, nextafter(1.0, 0.0)] since samples must // be in [low, high) - auto get_limits = [&dtype]() { + auto get_upper = [&dtype]() { switch (dtype) { case float32: - return std::make_pair( - array(std::nextafter(1.0f, 0.0f), float32), - array(std::numeric_limits::max(), float32)); + return array(std::nextafter(1.0f, 0.0f), float32); case float16: - return std::make_pair( - array(below_one(), float16), - array(std::numeric_limits::max(), float32)); + return array(below_one(), float32); case bfloat16: - return std::make_pair( - array(below_one(), bfloat16), - array(std::numeric_limits::max(), float32)); + return array(below_one(), float32); default: throw std::runtime_error("[uniform] Unsupported type."); } }; - auto [upper, maxval] = get_limits(); - auto out = bits(shape, size_of(dtype), key, stream); - out = astype(divide(out, maxval, stream), dtype, stream); - out = minimum(out, upper, stream); + auto upper = get_upper(); + auto maxval = 
array(std::numeric_limits::max(), float32); + auto out = bits(shape, size_of(float32), key, stream); + out = divide(out, maxval, stream); + out = astype(minimum(out, upper, stream), dtype, stream); return add(multiply(range, out, stream), lo, stream); } @@ -183,7 +156,7 @@ inline array complex_normal( const std::optional& key, StreamOrDevice s) { auto stream = to_stream(s); - auto low = above_minus_one_with_default(float32); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); auto high = array(1.0f, float32); shape.push_back(2); auto samples = @@ -207,18 +180,23 @@ array normal( StreamOrDevice s /* = {} */) { if (dtype == complex64) { return complex_normal(shape, loc, scale, key, s); + } else if (!issubdtype(dtype, floating)) { + throw std::invalid_argument( + "[normal] Can only generate uniform numbers with " + "floating point type."); } auto stream = to_stream(s); - auto low = above_minus_one_with_default(dtype); - auto high = array(1.0f, dtype); - auto samples = uniform(low, high, shape, dtype, key, stream); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); + auto high = array(1.0f, float32); + auto samples = uniform(low, high, shape, float32, key, stream); auto applied_scale = array(std::sqrt(2.0), dtype); if (scale.has_value()) { applied_scale = multiply(applied_scale, astype(*scale, dtype, stream), stream); } - samples = multiply(applied_scale, erfinv(samples, stream), stream); + samples = astype(erfinv(samples, stream), dtype, stream); + samples = multiply(applied_scale, samples, stream); if (loc.has_value()) { samples = add(astype(*loc, dtype, stream), samples, stream); } @@ -469,16 +447,23 @@ array laplace( const float scale /* = 1.0 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { + if (!issubdtype(dtype, floating)) { + throw std::invalid_argument( + "[laplace] Can only generate uniform numbers with real " + "floating point type."); + } + auto stream = to_stream(s); - auto low = above_minus_one_with_default(dtype); 
- auto high = array(1.0f, dtype); - auto samples = uniform(low, high, shape, dtype, key, stream); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); + auto high = array(1.0f, float32); + auto samples = uniform(low, high, shape, float32, key, stream); // Use inverse CDF to generate Laplacian noise samples = multiply( sign(samples, stream), log1p( multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream), stream); + samples = astype(samples, dtype, stream); if (scale != 1.0) { samples = multiply(array(scale, dtype), samples, stream); diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 49f1f300b..6ddd37104 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -350,7 +350,7 @@ TEST_CASE("test random uniform") { // Check float16 { auto key = random::key(0); - auto out = random::uniform({100}, float16, key); + auto out = random::uniform({1000}, float16, key); CHECK_EQ(out.dtype(), float16); CHECK(all(less(out, array(1.0f))).item()); CHECK(all(greater_equal(out, array(0.0f))).item()); @@ -360,7 +360,7 @@ TEST_CASE("test random uniform") { { auto key = random::key(0); - auto out = random::uniform({100}, bfloat16, key); + auto out = random::uniform({1000}, bfloat16, key); CHECK_EQ(out.dtype(), bfloat16); CHECK(all(less(out, array(1.0f))).item()); CHECK(all(greater_equal(out, array(0.0f))).item());