mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix
This commit is contained in:
@@ -82,6 +82,16 @@ array split(const array& key, int num, StreamOrDevice s /* = {} */) {
|
||||
return bits({num, 2}, 4, key, s);
|
||||
}
|
||||
|
||||
// Get the next representable value below 1.0 for half precision
|
||||
// floating point types (fp16, bf16)
|
||||
template <typename T>
|
||||
T below_one() {
|
||||
T f = T(1.0);
|
||||
uint16_t* m = (uint16_t*)&f;
|
||||
*m -= 1;
|
||||
return f;
|
||||
}
|
||||
|
||||
array uniform(
|
||||
const array& low,
|
||||
const array& high,
|
||||
@@ -106,7 +116,23 @@ array uniform(
|
||||
<< " from broadcasted shape " << out_shape << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto upper = array(std::nextafter(1.0f, 0.0f), float32);
|
||||
|
||||
// Get random values between [0, nextafter(1.0, 0.0)] since samples must
|
||||
// be in [low, high)
|
||||
auto get_upper = [&dtype]() {
|
||||
switch (dtype) {
|
||||
case float32:
|
||||
return array(std::nextafter(1.0f, 0.0f), float32);
|
||||
case float16:
|
||||
return array(below_one<float16_t>(), float32);
|
||||
case bfloat16:
|
||||
return array(below_one<bfloat16_t>(), float32);
|
||||
default:
|
||||
throw std::runtime_error("[uniform] Unsupported type.");
|
||||
}
|
||||
};
|
||||
|
||||
auto upper = get_upper();
|
||||
auto maxval = array(std::numeric_limits<uint32_t>::max(), float32);
|
||||
auto out = bits(shape, size_of(float32), key, stream);
|
||||
out = divide(out, maxval, stream);
|
||||
@@ -154,6 +180,10 @@ array normal(
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (dtype == complex64) {
|
||||
return complex_normal(shape, loc, scale, key, s);
|
||||
} else if (!issubdtype(dtype, floating)) {
|
||||
throw std::invalid_argument(
|
||||
"[normal] Can only generate uniform numbers with "
|
||||
"floating point type.");
|
||||
}
|
||||
|
||||
auto stream = to_stream(s);
|
||||
@@ -417,6 +447,12 @@ array laplace(
|
||||
const float scale /* = 1.0 */,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
throw std::invalid_argument(
|
||||
"[laplace] Can only generate uniform numbers with real"
|
||||
"floating point type.");
|
||||
}
|
||||
|
||||
auto stream = to_stream(s);
|
||||
auto low = array(std::nextafter(-1.0f, 0.0f), float32);
|
||||
auto high = array(1.0f, float32);
|
||||
|
||||
@@ -350,7 +350,7 @@ TEST_CASE("test random uniform") {
|
||||
// Check float16
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto out = random::uniform({100}, float16, key);
|
||||
auto out = random::uniform({1000}, float16, key);
|
||||
CHECK_EQ(out.dtype(), float16);
|
||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||
@@ -360,7 +360,7 @@ TEST_CASE("test random uniform") {
|
||||
|
||||
{
|
||||
auto key = random::key(0);
|
||||
auto out = random::uniform({100}, bfloat16, key);
|
||||
auto out = random::uniform({1000}, bfloat16, key);
|
||||
CHECK_EQ(out.dtype(), bfloat16);
|
||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||
|
||||
Reference in New Issue
Block a user