mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Update the loc and scale to be arrays
This commit is contained in:
@@ -365,6 +365,22 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
sample = mx.random.normal(
|
||||
(1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j
|
||||
)
|
||||
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
def test_broadcastable_scale_loc(self):
|
||||
b = mx.random.normal((10, 2))
|
||||
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
||||
mx.eval(sample)
|
||||
self.assertEqual(sample.shape, (2, 10, 2))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
b = mx.random.normal((10,))
|
||||
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user