mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Cast loc and scale to requested dtype
This commit is contained in:
@@ -381,6 +381,12 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
b = mx.random.normal((10,))
|
||||
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
||||
|
||||
b = mx.random.normal((3, 1, 2))
|
||||
sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
|
||||
mx.eval(sample)
|
||||
self.assertEqual(sample.shape, (3, 4, 2))
|
||||
self.assertEqual(sample.dtype, mx.float16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user