diff --git a/mlx/dtype.h b/mlx/dtype.h index cdb8aa591..d52830485 100644 --- a/mlx/dtype.h +++ b/mlx/dtype.h @@ -84,6 +84,10 @@ inline bool is_floating_point(const Dtype& t) { kindof(t) == Dtype::Kind::c; } +inline bool is_complex(const Dtype& t) { + return kindof(t) == Dtype::Kind::c; +} + inline bool is_integral(const Dtype& t) { return !(is_floating_point(t)); } diff --git a/mlx/random.cpp b/mlx/random.cpp index e976c2b89..232c458f9 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -80,6 +80,16 @@ array split(const array& key, int num, StreamOrDevice s /* = {} */) { return bits({num, 2}, 4, key, s); } +// Get the next representable value below 1.0 for half precision +// floating point types (fp16, bf16) +template +T below_one() { + T f = T(1.0); + uint16_t* m = (uint16_t*)&f; + *m -= 1; + return f; +} + array uniform( const array& low, const array& high, @@ -87,9 +97,9 @@ array uniform( Dtype dtype /* = float32 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { - if (!is_floating_point(dtype)) { + if (!is_floating_point(dtype) && !is_complex(dtype)) { throw std::invalid_argument( - "Can only generate uniform numbers with floating point type."); + "Can only generate uniform numbers with real floating point type."); } auto stream = to_stream(s); @@ -103,12 +113,29 @@ array uniform( } // Get random values between [0, nextafter(maxval, 0.0f)] since samples must // be in [low, high) - // TODO replace minimum with modulo uint32_t(nextafter(maxval, 0.0f)) to avoid - // clipping effects - float maxval = std::numeric_limits::max(); - auto upper = array(std::nextafter(maxval, 0.0f), dtype); - auto out = minimum(bits(shape, size_of(dtype), key, stream), upper, stream); - out = divide(out, array(maxval, dtype), stream); + auto get_limits = [&dtype]() { + switch (dtype) { + case float32: + return std::make_pair( + array(std::nextafter(1.0f, 0.0f), float32), + array(std::numeric_limits::max(), float32)); + case float16: + return std::make_pair( + array(below_one(), float16), + array(std::numeric_limits::max(), float32)); + case bfloat16: + return std::make_pair( + array(below_one(), bfloat16), + array(std::numeric_limits::max(), float32)); + default: + throw std::runtime_error("[uniform] Unsupported type."); + } + }; + + auto [upper, maxval] = get_limits(); + auto out = bits(shape, size_of(dtype), key, stream); + out = astype(divide(out, maxval, stream), dtype, stream); + out = minimum(out, upper, stream); return add(multiply(range, out, stream), low, stream); } diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index c878205e0..1a387febc 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -344,6 +344,27 @@ TEST_CASE("test random uniform") { CHECK(all(less(out, array(1.0f))).item()); CHECK(all(greater_equal(out, array(-1.0f))).item()); } + + // Check float16 + { + auto key = random::key(0); + auto out = random::uniform({100}, float16, key); + CHECK_EQ(out.dtype(), float16); + CHECK(all(less(out, array(1.0f))).item()); + CHECK(all(greater_equal(out, array(0.0f))).item()); + CHECK(!all(equal(out, array(0.0f))).item()); + CHECK(abs(float(mean(out).item()) - 0.5f) < 0.02); + } + + { + auto key = random::key(0); + auto out = random::uniform({100}, bfloat16, key); + CHECK_EQ(out.dtype(), bfloat16); + CHECK(all(less(out, array(1.0f))).item()); + CHECK(all(greater_equal(out, array(0.0f))).item()); + CHECK(!all(equal(out, array(0.0f))).item()); + CHECK(abs(float(mean(out).item()) - 0.5f) < 0.02); + } } TEST_CASE("test random normal") { @@ -375,6 +396,25 @@ TEST_CASE("test random normal") { auto key = random::key(128291); auto out = random::normal({100}, key); CHECK(all(less(abs(out), array(inf))).item()); + CHECK(abs(mean(out).item()) < 0.1); + } + + { + constexpr float inf = std::numeric_limits::infinity(); + auto key = random::key(128291); + auto out = random::normal({200}, float16, key); + CHECK_EQ(out.dtype(), float16); + CHECK(all(less(abs(out), array(inf))).item()); + CHECK(abs(float(mean(out).item())) < 0.1); + } + + { + constexpr float inf = std::numeric_limits::infinity(); + auto key = random::key(128291); + auto out = random::normal({200}, bfloat16, key); + CHECK_EQ(out.dtype(), bfloat16); + CHECK(all(less(abs(out), array(inf))).item()); + CHECK(abs(float(mean(out).item())) < 0.1); } }