diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h index fcf0ed503..058feddbc 100644 --- a/mlx/backend/common/ops.h +++ b/mlx/backend/common/ops.h @@ -373,6 +373,10 @@ struct Sign { uint64_t operator()(uint64_t x) { return x != 0; } + + complex64_t operator()(complex64_t x) { + return x == complex64_t(0) ? x : x / std::abs(x); + } }; struct Sin { diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index ea5c05177..01eaab512 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -64,6 +64,7 @@ instantiate_unary_all(Cos, complex64, complex64_t) instantiate_unary_all(Cosh, complex64, complex64_t) instantiate_unary_all(Exp, complex64, complex64_t) instantiate_unary_all(Negative, complex64, complex64_t) +instantiate_unary_all(Sign, complex64, complex64_t) instantiate_unary_all(Sin, complex64, complex64_t) instantiate_unary_all(Sinh, complex64, complex64_t) instantiate_unary_all(Tan, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index fb4c6dbe5..857185a9d 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -308,6 +308,14 @@ struct Sign { uint32_t operator()(uint32_t x) { return x != 0; }; + template <> + complex64_t operator()(complex64_t x) { + if (x == complex64_t(0)) { + return x; + } + return x / + (complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag); + }; }; struct Sin { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index c17190126..8abe74471 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -760,6 +760,15 @@ class TestOps(mlx_tests.MLXTestCase): expected = np.sign(a, dtype=np.float32) self.assertTrue(np.allclose(result, expected)) + a = mx.array([-1.0, 1.0, 0.0, -2.0, 3.0]) + b = mx.array([-4.0, -3.0, 1.0, 0.0, 3.0]) + c = a + b * 1j + result = mx.sign(c) + # np.sign differs in NumPy 1 and 2 so + # we manually implement the NumPy 2 version here. + expected = c / np.abs(c) + self.assertTrue(np.allclose(result, expected)) + def test_logical_not(self): a = mx.array([-1.0, 1.0, 0.0, 1.0, -2.0, 3.0]) result = mx.logical_not(a)