mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	Add conjugate operator (#1100)
* cpu and gpu impl * add mx.conj and array.conj() --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
		@@ -1245,6 +1245,7 @@ class TestOps(mlx_tests.MLXTestCase):
 | 
			
		||||
            "log1p",
 | 
			
		||||
            "floor",
 | 
			
		||||
            "ceil",
 | 
			
		||||
            "conjugate",
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        x = 0.5
 | 
			
		||||
@@ -2258,6 +2259,19 @@ class TestOps(mlx_tests.MLXTestCase):
 | 
			
		||||
                out_np = getattr(np, op)(a_np, b_np)
 | 
			
		||||
                self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
 | 
			
		||||
 | 
			
		||||
    def test_conjugate(self):
 | 
			
		||||
        shape = (3, 5, 7)
 | 
			
		||||
        a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
 | 
			
		||||
        a = a.astype(np.complex64)
 | 
			
		||||
        ops = ["conjugate", "conj"]
 | 
			
		||||
        for op in ops:
 | 
			
		||||
            out_mlx = getattr(mx, op)(mx.array(a))
 | 
			
		||||
            out_np = getattr(np, op)(a)
 | 
			
		||||
            self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
 | 
			
		||||
        out_mlx = mx.array(a).conj()
 | 
			
		||||
        out_np = a.conj()
 | 
			
		||||
        self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user