mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 16:58:08 +08:00
Add conjugate operator (#1100)
* cpu and gpu impl * add mx.conj and array.conj() --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -1245,6 +1245,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"log1p",
|
||||
"floor",
|
||||
"ceil",
|
||||
"conjugate",
|
||||
]
|
||||
|
||||
x = 0.5
|
||||
@@ -2258,6 +2259,19 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out_np = getattr(np, op)(a_np, b_np)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
|
||||
def test_conjugate(self):
|
||||
shape = (3, 5, 7)
|
||||
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
|
||||
a = a.astype(np.complex64)
|
||||
ops = ["conjugate", "conj"]
|
||||
for op in ops:
|
||||
out_mlx = getattr(mx, op)(mx.array(a))
|
||||
out_np = getattr(np, op)(a)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
out_mlx = mx.array(a).conj()
|
||||
out_np = a.conj()
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user