mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
random.uniform must respect dtype, even if lower precision than "low" (#280)
Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
This commit is contained in:
@@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
|
||||
self.assertTrue(mx.all((a > -1) < 5).item())
|
||||
|
||||
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
|
||||
self.assertEqual(a.dtype, mx.bfloat16)
|
||||
|
||||
def test_normal(self):
|
||||
key = mx.random.key(0)
|
||||
a = mx.random.normal(key=key)
|
||||
|
Reference in New Issue
Block a user