Add conjugate operator (#1100)

* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-05-10 07:22:20 -07:00
committed by GitHub
parent 8bd6bfa4b5
commit 2e158cf6d0
17 changed files with 143 additions and 11 deletions

View File

@@ -1245,6 +1245,7 @@ class TestOps(mlx_tests.MLXTestCase):
"log1p",
"floor",
"ceil",
"conjugate",
]
x = 0.5
@@ -2258,6 +2259,19 @@ class TestOps(mlx_tests.MLXTestCase):
out_np = getattr(np, op)(a_np, b_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
def test_conjugate(self):
shape = (3, 5, 7)
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
a = a.astype(np.complex64)
ops = ["conjugate", "conj"]
for op in ops:
out_mlx = getattr(mx, op)(mx.array(a))
out_np = getattr(np, op)(a)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
out_mlx = mx.array(a).conj()
out_np = a.conj()
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
if __name__ == "__main__":
unittest.main()