Update the loc and scale to be arrays

This commit is contained in:
Angelos Katharopoulos
2025-05-13 15:20:47 -07:00
parent c0cac3755c
commit 49878758e5
4 changed files with 72 additions and 29 deletions

View File

@@ -365,6 +365,22 @@ class TestRandom(mlx_tests.MLXTestCase):
self.assertEqual(sample.shape, (1, 2, 3, 4))
self.assertEqual(sample.dtype, mx.complex64)
sample = mx.random.normal(
(1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j
)
self.assertEqual(sample.shape, (1, 2, 3, 4))
self.assertEqual(sample.dtype, mx.complex64)
def test_broadcastable_scale_loc(self):
b = mx.random.normal((10, 2))
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
mx.eval(sample)
self.assertEqual(sample.shape, (2, 10, 2))
with self.assertRaises(ValueError):
b = mx.random.normal((10,))
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
if __name__ == "__main__":
unittest.main()