Fixing random.normal for half-precision dtype #642 (#904)

* Fixing random.normal for half-precision dtype #642

* Update python/tests/test_random.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Luca Arnaboldi 2024-03-26 17:58:27 +01:00 committed by GitHub
parent 28fcd2b519
commit a3ee03da01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 1 deletions

View File

@ -90,6 +90,16 @@ T below_one() {
return f;
}
// Get the next representable value above -1.0 for half precision
// floating point types (fp16, bf16)
template <typename T>
T above_minus_one() {
T f = T(-1.0);
uint16_t* m = (uint16_t*)&f;
*m -= 1;
return f;
}
array uniform(
const array& low,
const array& high,
@ -158,7 +168,17 @@ array normal(
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
auto stream = to_stream(s);
auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
auto get_low = [&dtype]() {
switch (dtype) {
case float16:
return array(above_minus_one<float16_t>(), dtype);
case bfloat16:
return array(above_minus_one<bfloat16_t>(), dtype);
default:
return array(std::nextafter(-1.0f, 0.0f), dtype);
}
};
auto low = get_low();
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
samples =

View File

@ -96,6 +96,11 @@ class TestRandom(mlx_tests.MLXTestCase):
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
# Test not getting -inf or inf with half precison
for hp in [mx.float16, mx.bfloat16]:
a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
self.assertTrue(mx.all(a < mx.inf))
def test_randint(self):
a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, ())