diff --git a/mlx/random.cpp b/mlx/random.cpp index 5367e7ca9..65ea60535 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -82,39 +82,6 @@ array split(const array& key, int num, StreamOrDevice s /* = {} */) { return bits({num, 2}, 4, key, s); } -// Get the next representable value below 1.0 for half precision -// floating point types (fp16, bf16) -template -T below_one() { - T f = T(1.0); - uint16_t* m = (uint16_t*)&f; - *m -= 1; - return f; -} - -// Get the next representable value above -1.0 for half precision -// floating point types (fp16, bf16) -template -T above_minus_one() { - T f = T(-1.0); - uint16_t* m = (uint16_t*)&f; - *m -= 1; - return f; -} - -// Get the next representable value above -1.0 for half precision -// use std::nextafter as default case. -array above_minus_one_with_default(Dtype dtype) { - switch (dtype) { - case float16: - return array(above_minus_one(), dtype); - case bfloat16: - return array(above_minus_one(), dtype); - default: - return array(std::nextafter(-1.0f, 0.0f), dtype); - } -} - array uniform( const array& low, const array& high, @@ -139,31 +106,11 @@ array uniform( << " from broadcasted shape " << out_shape << "."; throw std::invalid_argument(msg.str()); } - // Get random values between [0, nextafter(maxval, 0.0f)] since samples must - // be in [low, high) - auto get_limits = [&dtype]() { - switch (dtype) { - case float32: - return std::make_pair( - array(std::nextafter(1.0f, 0.0f), float32), - array(std::numeric_limits::max(), float32)); - case float16: - return std::make_pair( - array(below_one(), float16), - array(std::numeric_limits::max(), float32)); - case bfloat16: - return std::make_pair( - array(below_one(), bfloat16), - array(std::numeric_limits::max(), bfloat16)); - default: - throw std::runtime_error("[uniform] Unsupported type."); - } - }; - - auto [upper, maxval] = get_limits(); - auto out = bits(shape, size_of(maxval.dtype()), key, stream); - out = astype(divide(out, maxval, stream), dtype, stream); - out = minimum(out, upper, stream); + auto upper = array(std::nextafter(1.0f, 0.0f), float32); + auto maxval = array(std::numeric_limits::max(), float32); + auto out = bits(shape, size_of(float32), key, stream); + out = divide(out, maxval, stream); + out = astype(minimum(out, upper, stream), dtype, stream); return add(multiply(range, out, stream), lo, stream); } @@ -183,7 +130,7 @@ inline array complex_normal( const std::optional& key, StreamOrDevice s) { auto stream = to_stream(s); - auto low = above_minus_one_with_default(float32); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); auto high = array(1.0f, float32); shape.push_back(2); auto samples = @@ -210,15 +157,16 @@ array normal( } auto stream = to_stream(s); - auto low = above_minus_one_with_default(dtype); - auto high = array(1.0f, dtype); - auto samples = uniform(low, high, shape, dtype, key, stream); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); + auto high = array(1.0f, float32); + auto samples = uniform(low, high, shape, float32, key, stream); auto applied_scale = array(std::sqrt(2.0), dtype); if (scale.has_value()) { applied_scale = multiply(applied_scale, astype(*scale, dtype, stream), stream); } - samples = multiply(applied_scale, erfinv(samples, stream), stream); + samples = astype(erfinv(samples, stream), dtype, stream); + samples = multiply(applied_scale, samples, stream); if (loc.has_value()) { samples = add(astype(*loc, dtype, stream), samples, stream); } @@ -470,15 +418,16 @@ array laplace( const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { auto stream = to_stream(s); - auto low = above_minus_one_with_default(dtype); - auto high = array(1.0f, dtype); - auto samples = uniform(low, high, shape, dtype, key, stream); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); + auto high = array(1.0f, float32); + auto samples = uniform(low, high, shape, float32, key, stream); // Use inverse CDF to generate Laplacian noise samples = multiply( sign(samples, stream), log1p( multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream), stream); + samples = astype(samples, dtype, stream); if (scale != 1.0) { samples = multiply(array(scale, dtype), samples, stream);