mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
parent
86b614afcd
commit
4e3bdb560c
@ -84,6 +84,10 @@ inline bool is_floating_point(const Dtype& t) {
|
|||||||
kindof(t) == Dtype::Kind::c;
|
kindof(t) == Dtype::Kind::c;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool is_complex(const Dtype& t) {
|
||||||
|
return kindof(t) == Dtype::Kind::c;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool is_integral(const Dtype& t) {
|
inline bool is_integral(const Dtype& t) {
|
||||||
return !(is_floating_point(t));
|
return !(is_floating_point(t));
|
||||||
}
|
}
|
||||||
|
@ -80,6 +80,16 @@ array split(const array& key, int num, StreamOrDevice s /* = {} */) {
|
|||||||
return bits({num, 2}, 4, key, s);
|
return bits({num, 2}, 4, key, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the next representable value below 1.0 for half precision
|
||||||
|
// floating point types (fp16, bf16)
|
||||||
|
template <typename T>
|
||||||
|
T below_one() {
|
||||||
|
T f = T(1.0);
|
||||||
|
uint16_t* m = (uint16_t*)&f;
|
||||||
|
*m -= 1;
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
|
||||||
array uniform(
|
array uniform(
|
||||||
const array& low,
|
const array& low,
|
||||||
const array& high,
|
const array& high,
|
||||||
@ -87,9 +97,9 @@ array uniform(
|
|||||||
Dtype dtype /* = float32 */,
|
Dtype dtype /* = float32 */,
|
||||||
const std::optional<array>& key /*= nullopt */,
|
const std::optional<array>& key /*= nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (!is_floating_point(dtype)) {
|
if (!is_floating_point(dtype) && !is_complex(dtype)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Can only generate uniform numbers with floating point type.");
|
"Can only generate uniform numbers with real floating point type.");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto stream = to_stream(s);
|
auto stream = to_stream(s);
|
||||||
@ -103,12 +113,29 @@ array uniform(
|
|||||||
}
|
}
|
||||||
// Get random values between [0, nextafter(maxval, 0.0f)] since samples must
|
// Get random values between [0, nextafter(maxval, 0.0f)] since samples must
|
||||||
// be in [low, high)
|
// be in [low, high)
|
||||||
// TODO replace minimum with modulo uint32_t(nextafter(maxval, 0.0f)) to avoid
|
auto get_limits = [&dtype]() {
|
||||||
// clipping effects
|
switch (dtype) {
|
||||||
float maxval = std::numeric_limits<uint32_t>::max();
|
case float32:
|
||||||
auto upper = array(std::nextafter(maxval, 0.0f), dtype);
|
return std::make_pair(
|
||||||
auto out = minimum(bits(shape, size_of(dtype), key, stream), upper, stream);
|
array(std::nextafter(1.0f, 0.0f), float32),
|
||||||
out = divide(out, array(maxval, dtype), stream);
|
array(std::numeric_limits<uint32_t>::max(), float32));
|
||||||
|
case float16:
|
||||||
|
return std::make_pair(
|
||||||
|
array(below_one<float16_t>(), float16),
|
||||||
|
array(std::numeric_limits<uint16_t>::max(), float32));
|
||||||
|
case bfloat16:
|
||||||
|
return std::make_pair(
|
||||||
|
array(below_one<bfloat16_t>(), bfloat16),
|
||||||
|
array(std::numeric_limits<uint16_t>::max(), float32));
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("[uniform] Unsupported type.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto [upper, maxval] = get_limits();
|
||||||
|
auto out = bits(shape, size_of(dtype), key, stream);
|
||||||
|
out = astype(divide(out, maxval, stream), dtype, stream);
|
||||||
|
out = minimum(out, upper, stream);
|
||||||
return add(multiply(range, out, stream), low, stream);
|
return add(multiply(range, out, stream), low, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -344,6 +344,27 @@ TEST_CASE("test random uniform") {
|
|||||||
CHECK(all(less(out, array(1.0f))).item<bool>());
|
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||||
CHECK(all(greater_equal(out, array(-1.0f))).item<bool>());
|
CHECK(all(greater_equal(out, array(-1.0f))).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check float16
|
||||||
|
{
|
||||||
|
auto key = random::key(0);
|
||||||
|
auto out = random::uniform({100}, float16, key);
|
||||||
|
CHECK_EQ(out.dtype(), float16);
|
||||||
|
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||||
|
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||||
|
CHECK(!all(equal(out, array(0.0f))).item<bool>());
|
||||||
|
CHECK(abs(float(mean(out).item<float16_t>()) - 0.5f) < 0.02);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto key = random::key(0);
|
||||||
|
auto out = random::uniform({100}, bfloat16, key);
|
||||||
|
CHECK_EQ(out.dtype(), bfloat16);
|
||||||
|
CHECK(all(less(out, array(1.0f))).item<bool>());
|
||||||
|
CHECK(all(greater_equal(out, array(0.0f))).item<bool>());
|
||||||
|
CHECK(!all(equal(out, array(0.0f))).item<bool>());
|
||||||
|
CHECK(abs(float(mean(out).item<bfloat16_t>()) - 0.5f) < 0.02);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test random normal") {
|
TEST_CASE("test random normal") {
|
||||||
@ -375,6 +396,25 @@ TEST_CASE("test random normal") {
|
|||||||
auto key = random::key(128291);
|
auto key = random::key(128291);
|
||||||
auto out = random::normal({100}, key);
|
auto out = random::normal({100}, key);
|
||||||
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||||
|
CHECK(abs(mean(out).item<float>()) < 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||||
|
auto key = random::key(128291);
|
||||||
|
auto out = random::normal({200}, float16, key);
|
||||||
|
CHECK_EQ(out.dtype(), float16);
|
||||||
|
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||||
|
CHECK(abs(float(mean(out).item<float16_t>())) < 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||||
|
auto key = random::key(128291);
|
||||||
|
auto out = random::normal({200}, bfloat16, key);
|
||||||
|
CHECK_EQ(out.dtype(), bfloat16);
|
||||||
|
CHECK(all(less(abs(out), array(inf))).item<bool>());
|
||||||
|
CHECK(abs(float(mean(out).item<bfloat16_t>())) < 0.1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user