random generation fix (#80)

Random generation fix
This commit is contained in:
Awni Hannun 2023-12-08 10:40:57 -08:00 committed by GitHub
parent 86b614afcd
commit 4e3bdb560c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 8 deletions

View File

@ -84,6 +84,10 @@ inline bool is_floating_point(const Dtype& t) {
kindof(t) == Dtype::Kind::c; kindof(t) == Dtype::Kind::c;
} }
inline bool is_complex(const Dtype& t) {
return kindof(t) == Dtype::Kind::c;
}
inline bool is_integral(const Dtype& t) { inline bool is_integral(const Dtype& t) {
return !(is_floating_point(t)); return !(is_floating_point(t));
} }

View File

@ -80,6 +80,16 @@ array split(const array& key, int num, StreamOrDevice s /* = {} */) {
return bits({num, 2}, 4, key, s); return bits({num, 2}, 4, key, s);
} }
// Get the next representable value below 1.0 for half precision
// floating point types (fp16, bf16)
template <typename T>
T below_one() {
T f = T(1.0);
uint16_t* m = (uint16_t*)&f;
*m -= 1;
return f;
}
array uniform( array uniform(
const array& low, const array& low,
const array& high, const array& high,
@ -87,9 +97,9 @@ array uniform(
Dtype dtype /* = float32 */, Dtype dtype /* = float32 */,
const std::optional<array>& key /*= nullopt */, const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (!is_floating_point(dtype)) { if (!is_floating_point(dtype) && !is_complex(dtype)) {
throw std::invalid_argument( throw std::invalid_argument(
"Can only generate uniform numbers with floating point type."); "Can only generate uniform numbers with real floating point type.");
} }
auto stream = to_stream(s); auto stream = to_stream(s);
@ -103,12 +113,29 @@ array uniform(
} }
// Get random values between [0, nextafter(maxval, 0.0f)] since samples must // Get random values between [0, nextafter(maxval, 0.0f)] since samples must
// be in [low, high) // be in [low, high)
// TODO replace minimum with modulo uint32_t(nextafter(maxval, 0.0f)) to avoid auto get_limits = [&dtype]() {
// clipping effects switch (dtype) {
float maxval = std::numeric_limits<uint32_t>::max(); case float32:
auto upper = array(std::nextafter(maxval, 0.0f), dtype); return std::make_pair(
auto out = minimum(bits(shape, size_of(dtype), key, stream), upper, stream); array(std::nextafter(1.0f, 0.0f), float32),
out = divide(out, array(maxval, dtype), stream); array(std::numeric_limits<uint32_t>::max(), float32));
case float16:
return std::make_pair(
array(below_one<float16_t>(), float16),
array(std::numeric_limits<uint16_t>::max(), float32));
case bfloat16:
return std::make_pair(
array(below_one<bfloat16_t>(), bfloat16),
array(std::numeric_limits<uint16_t>::max(), float32));
default:
throw std::runtime_error("[uniform] Unsupported type.");
}
};
auto [upper, maxval] = get_limits();
auto out = bits(shape, size_of(dtype), key, stream);
out = astype(divide(out, maxval, stream), dtype, stream);
out = minimum(out, upper, stream);
return add(multiply(range, out, stream), low, stream); return add(multiply(range, out, stream), low, stream);
} }

View File

@ -344,6 +344,27 @@ TEST_CASE("test random uniform") {
CHECK(all(less(out, array(1.0f))).item<bool>()); CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(-1.0f))).item<bool>()); CHECK(all(greater_equal(out, array(-1.0f))).item<bool>());
} }
// Check float16
{
auto key = random::key(0);
auto out = random::uniform({100}, float16, key);
CHECK_EQ(out.dtype(), float16);
CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
CHECK(!all(equal(out, array(0.0f))).item<bool>());
CHECK(abs(float(mean(out).item<float16_t>()) - 0.5f) < 0.02);
}
{
auto key = random::key(0);
auto out = random::uniform({100}, bfloat16, key);
CHECK_EQ(out.dtype(), bfloat16);
CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
CHECK(!all(equal(out, array(0.0f))).item<bool>());
CHECK(abs(float(mean(out).item<bfloat16_t>()) - 0.5f) < 0.02);
}
} }
TEST_CASE("test random normal") { TEST_CASE("test random normal") {
@ -375,6 +396,25 @@ TEST_CASE("test random normal") {
auto key = random::key(128291); auto key = random::key(128291);
auto out = random::normal({100}, key); auto out = random::normal({100}, key);
CHECK(all(less(abs(out), array(inf))).item<bool>()); CHECK(all(less(abs(out), array(inf))).item<bool>());
CHECK(abs(mean(out).item<float>()) < 0.1);
}
{
constexpr float inf = std::numeric_limits<float>::infinity();
auto key = random::key(128291);
auto out = random::normal({200}, float16, key);
CHECK_EQ(out.dtype(), float16);
CHECK(all(less(abs(out), array(inf))).item<bool>());
CHECK(abs(float(mean(out).item<float16_t>())) < 0.1);
}
{
constexpr float inf = std::numeric_limits<float>::infinity();
auto key = random::key(128291);
auto out = random::normal({200}, bfloat16, key);
CHECK_EQ(out.dtype(), bfloat16);
CHECK(all(less(abs(out), array(inf))).item<bool>());
CHECK(abs(float(mean(out).item<bfloat16_t>())) < 0.1);
} }
} }