Fixing random.normal for half-precision dtype #642 (#904)

* Fixing random.normal for half-precision dtype #642

* Update python/tests/test_random.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Luca Arnaboldi
2024-03-26 17:58:27 +01:00
committed by GitHub
parent 28fcd2b519
commit a3ee03da01
2 changed files with 26 additions and 1 deletions

View File

@@ -90,6 +90,16 @@ T below_one() {
return f;
}
// Get the next representable value above -1.0 for half precision
// floating point types (fp16, bf16)
template <typename T>
T above_minus_one() {
T f = T(-1.0);
uint16_t* m = (uint16_t*)&f;
*m -= 1;
return f;
}
array uniform(
const array& low,
const array& high,
@@ -158,7 +168,17 @@ array normal(
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
auto stream = to_stream(s);
auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
auto get_low = [&dtype]() {
switch (dtype) {
case float16:
return array(above_minus_one<float16_t>(), dtype);
case bfloat16:
return array(above_minus_one<bfloat16_t>(), dtype);
default:
return array(std::nextafter(-1.0f, 0.0f), dtype);
}
};
auto low = get_low();
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
samples =