mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 10:41:14 +08:00
Add a test
This commit is contained in:
parent
3d93f799df
commit
c0cac3755c
@ -352,6 +352,19 @@ class TestRandom(mlx_tests.MLXTestCase):
|
|||||||
x = mx.random.permutation(mx.array([[1]]))
|
x = mx.random.permutation(mx.array([[1]]))
|
||||||
self.assertEqual(x.shape, (1, 1))
|
self.assertEqual(x.shape, (1, 1))
|
||||||
|
|
||||||
|
def test_complex_normal(self):
|
||||||
|
sample = mx.random.normal(tuple(), dtype=mx.complex64)
|
||||||
|
self.assertEqual(sample.shape, tuple())
|
||||||
|
self.assertEqual(sample.dtype, mx.complex64)
|
||||||
|
|
||||||
|
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64)
|
||||||
|
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||||
|
self.assertEqual(sample.dtype, mx.complex64)
|
||||||
|
|
||||||
|
sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0)
|
||||||
|
self.assertEqual(sample.shape, (1, 2, 3, 4))
|
||||||
|
self.assertEqual(sample.dtype, mx.complex64)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user