diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 9efbfb5f6..e4a68f05d 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -352,6 +352,19 @@ class TestRandom(mlx_tests.MLXTestCase): x = mx.random.permutation(mx.array([[1]])) self.assertEqual(x.shape, (1, 1)) + def test_complex_normal(self): + sample = mx.random.normal(tuple(), dtype=mx.complex64) + self.assertEqual(sample.shape, tuple()) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + if __name__ == "__main__": unittest.main()