random.uniform must respect dtype, even if lower precision than "low" (#280)

Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
This commit is contained in:
Daniel Strobusch
2023-12-24 16:04:43 +01:00
committed by GitHub
parent 8b227fa9af
commit 7365d142a3
3 changed files with 11 additions and 2 deletions

View File

@@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase):
a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
self.assertTrue(mx.all((a > -1) < 5).item())
a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
self.assertEqual(a.dtype, mx.bfloat16)
def test_normal(self):
key = mx.random.key(0)
a = mx.random.normal(key=key)