Add more complex unary ops (#2101)

This commit is contained in:
Awni Hannun
2025-04-21 13:04:54 -07:00
committed by GitHub
parent 79b527f45f
commit fdadc4f22c
4 changed files with 93 additions and 37 deletions

View File

@@ -2934,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase):
out = a[::-1]
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
def test_complex_ops(self):
x = mx.array(
[
3.0 + 4.0j,
-5.0 + 12.0j,
-8.0 + 0.0j,
0.0 + 9.0j,
0.0 + 0.0j,
]
)
ops = ["arccos", "arcsin", "arctan", "square", "sqrt"]
for op in ops:
with self.subTest(op=op):
np_op = getattr(np, op)
mx_op = getattr(mx, op)
self.assertTrue(np.allclose(mx_op(x), np_op(x)))
x = mx.array(
[
3.0 + 4.0j,
-5.0 + 12.0j,
-8.0 + 0.0j,
0.0 + 9.0j,
9.0 + 1.0j,
]
)
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
if __name__ == "__main__":
unittest.main()