Add a test

This commit is contained in:
Angelos Katharopoulos 2025-05-12 23:59:24 -07:00
parent 3d93f799df
commit c0cac3755c

View File

@ -352,6 +352,19 @@ class TestRandom(mlx_tests.MLXTestCase):
x = mx.random.permutation(mx.array([[1]])) x = mx.random.permutation(mx.array([[1]]))
self.assertEqual(x.shape, (1, 1)) self.assertEqual(x.shape, (1, 1))
def test_complex_normal(self):
sample = mx.random.normal(tuple(), dtype=mx.complex64)
self.assertEqual(sample.shape, tuple())
self.assertEqual(sample.dtype, mx.complex64)
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64)
self.assertEqual(sample.shape, (1, 2, 3, 4))
self.assertEqual(sample.dtype, mx.complex64)
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0)
self.assertEqual(sample.shape, (1, 2, 3, 4))
self.assertEqual(sample.dtype, mx.complex64)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()