Fixing random.normal for half-precision dtype #642 (#904)

* Fixing random.normal for half-precision dtype #642

* Update python/tests/test_random.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Luca Arnaboldi
2024-03-26 17:58:27 +01:00
committed by GitHub
parent 28fcd2b519
commit a3ee03da01
2 changed files with 26 additions and 1 deletions

View File

@@ -96,6 +96,11 @@ class TestRandom(mlx_tests.MLXTestCase):
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
# Test not getting -inf or inf with half precison
for hp in [mx.float16, mx.bfloat16]:
a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
self.assertTrue(mx.all(a < mx.inf))
def test_randint(self):
a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, ())