mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	Add random normal distribution for complex numbers (#2182)
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							0751263dec
						
					
				
				
					commit
					130df35e1b
				
			@@ -352,6 +352,41 @@ class TestRandom(mlx_tests.MLXTestCase):
 | 
			
		||||
        x = mx.random.permutation(mx.array([[1]]))
 | 
			
		||||
        self.assertEqual(x.shape, (1, 1))
 | 
			
		||||
 | 
			
		||||
    def test_complex_normal(self):
 | 
			
		||||
        sample = mx.random.normal(tuple(), dtype=mx.complex64)
 | 
			
		||||
        self.assertEqual(sample.shape, tuple())
 | 
			
		||||
        self.assertEqual(sample.dtype, mx.complex64)
 | 
			
		||||
 | 
			
		||||
        sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64)
 | 
			
		||||
        self.assertEqual(sample.shape, (1, 2, 3, 4))
 | 
			
		||||
        self.assertEqual(sample.dtype, mx.complex64)
 | 
			
		||||
 | 
			
		||||
        sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0)
 | 
			
		||||
        self.assertEqual(sample.shape, (1, 2, 3, 4))
 | 
			
		||||
        self.assertEqual(sample.dtype, mx.complex64)
 | 
			
		||||
 | 
			
		||||
        sample = mx.random.normal(
 | 
			
		||||
            (1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(sample.shape, (1, 2, 3, 4))
 | 
			
		||||
        self.assertEqual(sample.dtype, mx.complex64)
 | 
			
		||||
 | 
			
		||||
    def test_broadcastable_scale_loc(self):
 | 
			
		||||
        b = mx.random.normal((10, 2))
 | 
			
		||||
        sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
 | 
			
		||||
        mx.eval(sample)
 | 
			
		||||
        self.assertEqual(sample.shape, (2, 10, 2))
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            b = mx.random.normal((10,))
 | 
			
		||||
            sample = mx.random.normal((2, 10, 2), loc=b, scale=b)
 | 
			
		||||
 | 
			
		||||
        b = mx.random.normal((3, 1, 2))
 | 
			
		||||
        sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b)
 | 
			
		||||
        mx.eval(sample)
 | 
			
		||||
        self.assertEqual(sample.shape, (3, 4, 2))
 | 
			
		||||
        self.assertEqual(sample.dtype, mx.float16)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user