mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-09 08:34:18 +08:00
Add random normal distribution for complex numbers (#2182)
This commit is contained in:

committed by
GitHub

parent
0751263dec
commit
130df35e1b
@@ -352,6 +352,41 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
x = mx.random.permutation(mx.array([[1]]))
|
||||
self.assertEqual(x.shape, (1, 1))
|
||||
|
||||
def test_complex_normal(self):
|
||||
sample = mx.random.normal(tuple(), dtype=mx.complex64)
|
||||
self.assertEqual(sample.shape, tuple())
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64)
|
||||
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0)
|
||||
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
sample = mx.random.normal(
|
||||
(1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j
|
||||
)
|
||||
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||
self.assertEqual(sample.dtype, mx.complex64)
|
||||
|
||||
def test_broadcastable_scale_loc(self):
|
||||
b = mx.random.normal((10, 2))
|
||||
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
||||
mx.eval(sample)
|
||||
self.assertEqual(sample.shape, (2, 10, 2))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
b = mx.random.normal((10,))
|
||||
sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
|
||||
|
||||
b = mx.random.normal((3, 1, 2))
|
||||
sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
|
||||
mx.eval(sample)
|
||||
self.assertEqual(sample.shape, (3, 4, 2))
|
||||
self.assertEqual(sample.dtype, mx.float16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user