mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
@@ -84,6 +84,10 @@ inline bool is_floating_point(const Dtype& t) {
|
||||
kindof(t) == Dtype::Kind::c;
|
||||
}
|
||||
|
||||
inline bool is_complex(const Dtype& t) {
|
||||
return kindof(t) == Dtype::Kind::c;
|
||||
}
|
||||
|
||||
inline bool is_integral(const Dtype& t) {
|
||||
return !(is_floating_point(t));
|
||||
}
|
||||
|
@@ -80,6 +80,16 @@ array split(const array& key, int num, StreamOrDevice s /* = {} */) {
|
||||
return bits({num, 2}, 4, key, s);
|
||||
}
|
||||
|
||||
// Get the next representable value below 1.0 for half precision
|
||||
// floating point types (fp16, bf16)
|
||||
template <typename T>
|
||||
T below_one() {
|
||||
T f = T(1.0);
|
||||
uint16_t* m = (uint16_t*)&f;
|
||||
*m -= 1;
|
||||
return f;
|
||||
}
|
||||
|
||||
array uniform(
|
||||
const array& low,
|
||||
const array& high,
|
||||
@@ -87,9 +97,9 @@ array uniform(
|
||||
Dtype dtype /* = float32 */,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!is_floating_point(dtype)) {
|
||||
if (!is_floating_point(dtype) && !is_complex(dtype)) {
|
||||
throw std::invalid_argument(
|
||||
"Can only generate uniform numbers with floating point type.");
|
||||
"Can only generate uniform numbers with real floating point type.");
|
||||
}
|
||||
|
||||
auto stream = to_stream(s);
|
||||
@@ -103,12 +113,29 @@ array uniform(
|
||||
}
|
||||
// Get random values between [0, nextafter(maxval, 0.0f)] since samples must
|
||||
// be in [low, high)
|
||||
// TODO replace minimum with modulo uint32_t(nextafter(maxval, 0.0f)) to avoid
|
||||
// clipping effects
|
||||
float maxval = std::numeric_limits<uint32_t>::max();
|
||||
auto upper = array(std::nextafter(maxval, 0.0f), dtype);
|
||||
auto out = minimum(bits(shape, size_of(dtype), key, stream), upper, stream);
|
||||
out = divide(out, array(maxval, dtype), stream);
|
||||
auto get_limits = [&dtype]() {
|
||||
switch (dtype) {
|
||||
case float32:
|
||||
return std::make_pair(
|
||||
array(std::nextafter(1.0f, 0.0f), float32),
|
||||
array(std::numeric_limits<uint32_t>::max(), float32));
|
||||
case float16:
|
||||
return std::make_pair(
|
||||
array(below_one<float16_t>(), float16),
|
||||
array(std::numeric_limits<uint16_t>::max(), float32));
|
||||
case bfloat16:
|
||||
return std::make_pair(
|
||||
array(below_one<bfloat16_t>(), bfloat16),
|
||||
array(std::numeric_limits<uint16_t>::max(), float32));
|
||||
default:
|
||||
throw std::runtime_error("[uniform] Unsupported type.");
|
||||
}
|
||||
};
|
||||
|
||||
auto [upper, maxval] = get_limits();
|
||||
auto out = bits(shape, size_of(dtype), key, stream);
|
||||
out = astype(divide(out, maxval, stream), dtype, stream);
|
||||
out = minimum(out, upper, stream);
|
||||
return add(multiply(range, out, stream), low, stream);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user