mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
Add more complex unary ops (#2101)
This commit is contained in:
@@ -2934,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = a[::-1]
|
||||
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
|
||||
|
||||
def test_complex_ops(self):
|
||||
x = mx.array(
|
||||
[
|
||||
3.0 + 4.0j,
|
||||
-5.0 + 12.0j,
|
||||
-8.0 + 0.0j,
|
||||
0.0 + 9.0j,
|
||||
0.0 + 0.0j,
|
||||
]
|
||||
)
|
||||
|
||||
ops = ["arccos", "arcsin", "arctan", "square", "sqrt"]
|
||||
for op in ops:
|
||||
with self.subTest(op=op):
|
||||
np_op = getattr(np, op)
|
||||
mx_op = getattr(mx, op)
|
||||
self.assertTrue(np.allclose(mx_op(x), np_op(x)))
|
||||
|
||||
x = mx.array(
|
||||
[
|
||||
3.0 + 4.0j,
|
||||
-5.0 + 12.0j,
|
||||
-8.0 + 0.0j,
|
||||
0.0 + 9.0j,
|
||||
9.0 + 1.0j,
|
||||
]
|
||||
)
|
||||
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user