mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
* Fixing random.normal for half-precision dtype #642 * Update python/tests/test_random.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -96,6 +96,11 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
|
||||
|
||||
# Test not getting -inf or inf with half precison
|
||||
for hp in [mx.float16, mx.bfloat16]:
|
||||
a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
|
||||
self.assertTrue(mx.all(a < mx.inf))
|
||||
|
||||
def test_randint(self):
|
||||
a = mx.random.randint(0, 1, [])
|
||||
self.assertEqual(a.shape, ())
|
||||
|
Reference in New Issue
Block a user