From c0cac3755c8d4bc5224f1290e6eb6e98085d22e7 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 12 May 2025 23:59:24 -0700 Subject: [PATCH] Add a test --- python/tests/test_random.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 9efbfb5f6..e4a68f05d 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -352,6 +352,19 @@ class TestRandom(mlx_tests.MLXTestCase): x = mx.random.permutation(mx.array([[1]])) self.assertEqual(x.shape, (1, 1)) + def test_complex_normal(self): + sample = mx.random.normal(tuple(), dtype=mx.complex64) + self.assertEqual(sample.shape, tuple()) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + if __name__ == "__main__": unittest.main()