mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 07:01:13 +08:00
lower memory uniform sampling (#2361)
* lower memory uniform * use fp32 * fix
This commit is contained in:
parent
cb349a291c
commit
2ba69bc8fa
@ -92,29 +92,6 @@ T below_one() {
|
|||||||
return f;
|
return f;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the next representable value above -1.0 for half precision
|
|
||||||
// floating point types (fp16, bf16)
|
|
||||||
template <typename T>
|
|
||||||
T above_minus_one() {
|
|
||||||
T f = T(-1.0);
|
|
||||||
uint16_t* m = (uint16_t*)&f;
|
|
||||||
*m -= 1;
|
|
||||||
return f;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the next representable value above -1.0 for the given dtype:
|
|
||||||
// fp16 and bf16 use above_minus_one, any other dtype falls back to std::nextafter.
|
|
||||||
array above_minus_one_with_default(Dtype dtype) {
|
|
||||||
switch (dtype) {
|
|
||||||
case float16:
|
|
||||||
return array(above_minus_one<float16_t>(), dtype);
|
|
||||||
case bfloat16:
|
|
||||||
return array(above_minus_one<bfloat16_t>(), dtype);
|
|
||||||
default:
|
|
||||||
return array(std::nextafter(-1.0f, 0.0f), dtype);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
array uniform(
|
array uniform(
|
||||||
const array& low,
|
const array& low,
|
||||||
const array& high,
|
const array& high,
|
||||||
@ -139,31 +116,27 @@ array uniform(
|
|||||||
<< " from broadcasted shape " << out_shape << ".";
|
<< " from broadcasted shape " << out_shape << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
// Get random values between [0, nextafter(maxval, 0.0f)] since samples must
|
|
||||||
|
// Get random values between [0, nextafter(1.0, 0.0)] since samples must
|
||||||
// be in [low, high)
|
// be in [low, high)
|
||||||
auto get_limits = [&dtype]() {
|
auto get_upper = [&dtype]() {
|
||||||
switch (dtype) {
|
switch (dtype) {
|
||||||
case float32:
|
case float32:
|
||||||
return std::make_pair(
|
return array(std::nextafter(1.0f, 0.0f), float32);
|
||||||
array(std::nextafter(1.0f, 0.0f), float32),
|
|
||||||
array(std::numeric_limits<uint32_t>::max(), float32));
|
|
||||||
case float16:
|
case float16:
|
||||||
return std::make_pair(
|
return array(below_one<float16_t>(), float32);
|
||||||
array(below_one<float16_t>(), float16),
|
|
||||||
array(std::numeric_limits<uint16_t>::max(), float32));
|
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return std::make_pair(
|
return array(below_one<bfloat16_t>(), float32);
|
||||||
array(below_one<bfloat16_t>(), bfloat16),
|
|
||||||
array(std::numeric_limits<uint16_t>::max(), float32));
|
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("[uniform] Unsupported type.");
|
throw std::runtime_error("[uniform] Unsupported type.");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto [upper, maxval] = get_limits();
|
auto upper = get_upper();
|
||||||
auto out = bits(shape, size_of(dtype), key, stream);
|
auto maxval = array(std::numeric_limits<uint32_t>::max(), float32);
|
||||||
out = astype(divide(out, maxval, stream), dtype, stream);
|
auto out = bits(shape, size_of(float32), key, stream);
|
||||||
out = minimum(out, upper, stream);
|
out = divide(out, maxval, stream);
|
||||||
|
out = astype(minimum(out, upper, stream), dtype, stream);
|
||||||
return add(multiply(range, out, stream), lo, stream);
|
return add(multiply(range, out, stream), lo, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -183,7 +156,7 @@ inline array complex_normal(
|
|||||||
const std::optional<array>& key,
|
const std::optional<array>& key,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
auto stream = to_stream(s);
|
auto stream = to_stream(s);
|
||||||
auto low = above_minus_one_with_default(float32);
|
auto low = array(std::nextafter(-1.0f, 0.0f), float32);
|
||||||
auto high = array(1.0f, float32);
|
auto high = array(1.0f, float32);
|
||||||
shape.push_back(2);
|
shape.push_back(2);
|
||||||
auto samples =
|
auto samples =
|
||||||
@ -207,18 +180,23 @@ array normal(
|
|||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (dtype == complex64) {
|
if (dtype == complex64) {
|
||||||
return complex_normal(shape, loc, scale, key, s);
|
return complex_normal(shape, loc, scale, key, s);
|
||||||
|
} else if (!issubdtype(dtype, floating)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[normal] Can only generate uniform numbers with "
|
||||||
|
"floating point type.");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto stream = to_stream(s);
|
auto stream = to_stream(s);
|
||||||
auto low = above_minus_one_with_default(dtype);
|
auto low = array(std::nextafter(-1.0f, 0.0f), float32);
|
||||||
auto high = array(1.0f, dtype);
|
auto high = array(1.0f, float32);
|
||||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
auto samples = uniform(low, high, shape, float32, key, stream);
|
||||||
auto applied_scale = array(std::sqrt(2.0), dtype);
|
auto applied_scale = array(std::sqrt(2.0), dtype);
|
||||||
if (scale.has_value()) {
|
if (scale.has_value()) {
|
||||||
applied_scale =
|
applied_scale =
|
||||||
multiply(applied_scale, astype(*scale, dtype, stream), stream);
|
multiply(applied_scale, astype(*scale, dtype, stream), stream);
|
||||||
}
|
}
|
||||||
samples = multiply(applied_scale, erfinv(samples, stream), stream);
|
samples = astype(erfinv(samples, stream), dtype, stream);
|
||||||
|
samples = multiply(applied_scale, samples, stream);
|
||||||
if (loc.has_value()) {
|
if (loc.has_value()) {
|
||||||
samples = add(astype(*loc, dtype, stream), samples, stream);
|
samples = add(astype(*loc, dtype, stream), samples, stream);
|
||||||
}
|
}
|
||||||
@ -469,16 +447,23 @@ array laplace(
|
|||||||
const float scale /* = 1.0 */,
|
const float scale /* = 1.0 */,
|
||||||
const std::optional<array>& key /*= nullopt */,
|
const std::optional<array>& key /*= nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (!issubdtype(dtype, floating)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[laplace] Can only generate uniform numbers with real"
|
||||||
|
"floating point type.");
|
||||||
|
}
|
||||||
|
|
||||||
auto stream = to_stream(s);
|
auto stream = to_stream(s);
|
||||||
auto low = above_minus_one_with_default(dtype);
|
auto low = array(std::nextafter(-1.0f, 0.0f), float32);
|
||||||
auto high = array(1.0f, dtype);
|
auto high = array(1.0f, float32);
|
||||||
auto samples = uniform(low, high, shape, dtype, key, stream);
|
auto samples = uniform(low, high, shape, float32, key, stream);
|
||||||
// Use inverse CDF to generate Laplacian noise
|
// Use inverse CDF to generate Laplacian noise
|
||||||
samples = multiply(
|
samples = multiply(
|
||||||
sign(samples, stream),
|
sign(samples, stream),
|
||||||
log1p(
|
log1p(
|
||||||
multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream),
|
multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream),
|
||||||
stream);
|
stream);
|
||||||
|
samples = astype(samples, dtype, stream);
|
||||||
|
|
||||||
if (scale != 1.0) {
|
if (scale != 1.0) {
|
||||||
samples = multiply(array(scale, dtype), samples, stream);
|
samples = multiply(array(scale, dtype), samples, stream);
|
||||||
|
@ -350,7 +350,7 @@ TEST_CASE("test random uniform") {
|
|||||||
// Check float16
|
// Check float16
|
||||||
{
|
{
|
||||||
auto key = random::key(0);
|
auto key = random::key(0);
|
||||||
auto out = random::uniform({100}, float16, key);
|
auto out = random::uniform({1000}, float16, key);
|
||||||
CHECK_EQ(out.dtype(), float16);
|
CHECK_EQ(out.dtype(), float16);
|
||||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||||
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||||
@ -360,7 +360,7 @@ TEST_CASE("test random uniform") {
|
|||||||
|
|
||||||
{
|
{
|
||||||
auto key = random::key(0);
|
auto key = random::key(0);
|
||||||
auto out = random::uniform({100}, bfloat16, key);
|
auto out = random::uniform({1000}, bfloat16, key);
|
||||||
CHECK_EQ(out.dtype(), bfloat16);
|
CHECK_EQ(out.dtype(), bfloat16);
|
||||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||||
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||||
|
Loading…
Reference in New Issue
Block a user