Add random normal distribution for complex numbers (#2182)

This commit is contained in:
Angelos Katharopoulos
2025-05-13 22:43:45 -07:00
committed by GitHub
parent 0751263dec
commit 130df35e1b
4 changed files with 109 additions and 24 deletions

View File

@@ -352,6 +352,41 @@ class TestRandom(mlx_tests.MLXTestCase):
x = mx.random.permutation(mx.array([[1]]))
self.assertEqual(x.shape, (1, 1))
def test_complex_normal(self):
sample = mx.random.normal(tuple(), dtype=mx.complex64)
self.assertEqual(sample.shape, tuple())
self.assertEqual(sample.dtype, mx.complex64)
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64)
self.assertEqual(sample.shape, (1, 2, 3, 4))
self.assertEqual(sample.dtype, mx.complex64)
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0)
self.assertEqual(sample.shape, (1, 2, 3, 4))
self.assertEqual(sample.dtype, mx.complex64)
sample = mx.random.normal(
(1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j
)
self.assertEqual(sample.shape, (1, 2, 3, 4))
self.assertEqual(sample.dtype, mx.complex64)
def test_broadcastable_scale_loc(self):
b = mx.random.normal((10, 2))
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
mx.eval(sample)
self.assertEqual(sample.shape, (2, 10, 2))
with self.assertRaises(ValueError):
b = mx.random.normal((10,))
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
b = mx.random.normal((3, 1, 2))
sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
mx.eval(sample)
self.assertEqual(sample.shape, (3, 4, 2))
self.assertEqual(sample.dtype, mx.float16)
if __name__ == "__main__":
unittest.main()