lower memory uniform sampling (#2361)

* lower memory uniform

* use fp32

* fix
This commit is contained in:
Awni Hannun 2025-07-15 14:22:07 -07:00 committed by GitHub
parent cb349a291c
commit 2ba69bc8fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 48 deletions

View File

@ -92,29 +92,6 @@ T below_one() {
return f; return f;
} }
// Get the next representable value above -1.0 for half precision
// floating point types (fp16, bf16)
template <typename T>
T above_minus_one() {
T f = T(-1.0);
uint16_t* m = (uint16_t*)&f;
*m -= 1;
return f;
}
// Get the next representable value above -1.0 for half precision
// use std::nextafter as default case.
array above_minus_one_with_default(Dtype dtype) {
switch (dtype) {
case float16:
return array(above_minus_one<float16_t>(), dtype);
case bfloat16:
return array(above_minus_one<bfloat16_t>(), dtype);
default:
return array(std::nextafter(-1.0f, 0.0f), dtype);
}
}
array uniform( array uniform(
const array& low, const array& low,
const array& high, const array& high,
@ -139,31 +116,27 @@ array uniform(
<< " from broadcasted shape " << out_shape << "."; << " from broadcasted shape " << out_shape << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
// Get random values between [0, nextafter(maxval, 0.0f)] since samples must
// Get random values between [0, nextafter(1.0, 0.0)] since samples must
// be in [low, high) // be in [low, high)
auto get_limits = [&dtype]() { auto get_upper = [&dtype]() {
switch (dtype) { switch (dtype) {
case float32: case float32:
return std::make_pair( return array(std::nextafter(1.0f, 0.0f), float32);
array(std::nextafter(1.0f, 0.0f), float32),
array(std::numeric_limits<uint32_t>::max(), float32));
case float16: case float16:
return std::make_pair( return array(below_one<float16_t>(), float32);
array(below_one<float16_t>(), float16),
array(std::numeric_limits<uint16_t>::max(), float32));
case bfloat16: case bfloat16:
return std::make_pair( return array(below_one<bfloat16_t>(), float32);
array(below_one<bfloat16_t>(), bfloat16),
array(std::numeric_limits<uint16_t>::max(), float32));
default: default:
throw std::runtime_error("[uniform] Unsupported type."); throw std::runtime_error("[uniform] Unsupported type.");
} }
}; };
auto [upper, maxval] = get_limits(); auto upper = get_upper();
auto out = bits(shape, size_of(dtype), key, stream); auto maxval = array(std::numeric_limits<uint32_t>::max(), float32);
out = astype(divide(out, maxval, stream), dtype, stream); auto out = bits(shape, size_of(float32), key, stream);
out = minimum(out, upper, stream); out = divide(out, maxval, stream);
out = astype(minimum(out, upper, stream), dtype, stream);
return add(multiply(range, out, stream), lo, stream); return add(multiply(range, out, stream), lo, stream);
} }
@ -183,7 +156,7 @@ inline array complex_normal(
const std::optional<array>& key, const std::optional<array>& key,
StreamOrDevice s) { StreamOrDevice s) {
auto stream = to_stream(s); auto stream = to_stream(s);
auto low = above_minus_one_with_default(float32); auto low = array(std::nextafter(-1.0f, 0.0f), float32);
auto high = array(1.0f, float32); auto high = array(1.0f, float32);
shape.push_back(2); shape.push_back(2);
auto samples = auto samples =
@ -207,18 +180,23 @@ array normal(
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (dtype == complex64) { if (dtype == complex64) {
return complex_normal(shape, loc, scale, key, s); return complex_normal(shape, loc, scale, key, s);
} else if (!issubdtype(dtype, floating)) {
throw std::invalid_argument(
"[normal] Can only generate uniform numbers with "
"floating point type.");
} }
auto stream = to_stream(s); auto stream = to_stream(s);
auto low = above_minus_one_with_default(dtype); auto low = array(std::nextafter(-1.0f, 0.0f), float32);
auto high = array(1.0f, dtype); auto high = array(1.0f, float32);
auto samples = uniform(low, high, shape, dtype, key, stream); auto samples = uniform(low, high, shape, float32, key, stream);
auto applied_scale = array(std::sqrt(2.0), dtype); auto applied_scale = array(std::sqrt(2.0), dtype);
if (scale.has_value()) { if (scale.has_value()) {
applied_scale = applied_scale =
multiply(applied_scale, astype(*scale, dtype, stream), stream); multiply(applied_scale, astype(*scale, dtype, stream), stream);
} }
samples = multiply(applied_scale, erfinv(samples, stream), stream); samples = astype(erfinv(samples, stream), dtype, stream);
samples = multiply(applied_scale, samples, stream);
if (loc.has_value()) { if (loc.has_value()) {
samples = add(astype(*loc, dtype, stream), samples, stream); samples = add(astype(*loc, dtype, stream), samples, stream);
} }
@ -469,16 +447,23 @@ array laplace(
const float scale /* = 1.0 */, const float scale /* = 1.0 */,
const std::optional<array>& key /*= nullopt */, const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (!issubdtype(dtype, floating)) {
throw std::invalid_argument(
"[laplace] Can only generate uniform numbers with real"
"floating point type.");
}
auto stream = to_stream(s); auto stream = to_stream(s);
auto low = above_minus_one_with_default(dtype); auto low = array(std::nextafter(-1.0f, 0.0f), float32);
auto high = array(1.0f, dtype); auto high = array(1.0f, float32);
auto samples = uniform(low, high, shape, dtype, key, stream); auto samples = uniform(low, high, shape, float32, key, stream);
// Use inverse CDF to generate Laplacian noise // Use inverse CDF to generate Laplacian noise
samples = multiply( samples = multiply(
sign(samples, stream), sign(samples, stream),
log1p( log1p(
multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream), multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream),
stream); stream);
samples = astype(samples, dtype, stream);
if (scale != 1.0) { if (scale != 1.0) {
samples = multiply(array(scale, dtype), samples, stream); samples = multiply(array(scale, dtype), samples, stream);

View File

@ -350,7 +350,7 @@ TEST_CASE("test random uniform") {
// Check float16 // Check float16
{ {
auto key = random::key(0); auto key = random::key(0);
auto out = random::uniform({100}, float16, key); auto out = random::uniform({1000}, float16, key);
CHECK_EQ(out.dtype(), float16); CHECK_EQ(out.dtype(), float16);
CHECK(all(less(out, array(1.0f))).item<bool>()); CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(0.0f))).item<bool>()); CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
@ -360,7 +360,7 @@ TEST_CASE("test random uniform") {
{ {
auto key = random::key(0); auto key = random::key(0);
auto out = random::uniform({100}, bfloat16, key); auto out = random::uniform({1000}, bfloat16, key);
CHECK_EQ(out.dtype(), bfloat16); CHECK_EQ(out.dtype(), bfloat16);
CHECK(all(less(out, array(1.0f))).item<bool>()); CHECK(all(less(out, array(1.0f))).item<bool>());
CHECK(all(greater_equal(out, array(0.0f))).item<bool>()); CHECK(all(greater_equal(out, array(0.0f))).item<bool>());