Metal backend: complex64 support for more unary ops.

Summary of what the hunks below do:

  * complex.h: add mixed float/complex64_t overloads of operator+ and
    operator- (both argument orders) and operator/ (float / complex64_t).
  * unary.metal: instantiate ArcCos, ArcSin, ArcTan, Square, Sqrt and Rsqrt
    for complex64.
  * unary_ops.h:
      - drop the `template <>` prefix from the type-specific operator()
        overloads, leaving them as plain overloads;
      - declare complex64_t overloads inside ArcCos, ArcSin and ArcTan and
        define them at the end of the header, after Log and Sqrt, which the
        definitions use:
            arccos(x) = -i * log(x + i * sqrt(1 - x * x))
            arcsin(x) = -i * log(i * x + sqrt(1 - x * x))
            arctan(x) = (1 / 2i) * log((1 + i * x) / (1 - i * x))
      - add complex Sqrt: returns 0 for a zero input, otherwise with
        r = |x| it returns {sqrt((r + re) / 2),
        copysign(sqrt((r - re) / 2), im)};
      - add complex Rsqrt as 1.0 / Sqrt{}(x).
  * test_ops.py: add test_complex_ops, which compares arccos, arcsin, arctan,
    square and sqrt against NumPy with np.allclose, and rsqrt against
    1.0 / np.sqrt(x) on a second input set that contains no zero.

NOTE(review): the line breaks, blank context lines and indentation in this
patch were reconstructed from the hunk line counts; confirm the whitespace
against the target tree before applying.
NOTE(review): `static_cast(...)` in the operator% context lines looks like it
lost a template argument; confirm against complex.h.

diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h
index fe8ec5c0f..c88002cb3 100644
--- a/mlx/backend/metal/kernels/complex.h
+++ b/mlx/backend/metal/kernels/complex.h
@@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
 constexpr complex64_t operator+(complex64_t a, complex64_t b) {
   return {a.real + b.real, a.imag + b.imag};
 }
+constexpr complex64_t operator+(float a, complex64_t b) {
+  return {a + b.real, b.imag};
+}
+constexpr complex64_t operator+(complex64_t a, float b) {
+  return {a.real + b, a.imag};
+}
 
 constexpr complex64_t operator-(complex64_t a, complex64_t b) {
   return {a.real - b.real, a.imag - b.imag};
 }
+constexpr complex64_t operator-(float a, complex64_t b) {
+  return {a - b.real, -b.imag};
+}
+constexpr complex64_t operator-(complex64_t a, float b) {
+  return {a.real - b, a.imag};
+}
 
 constexpr complex64_t operator*(complex64_t a, complex64_t b) {
   return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
@@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
   return {x / denom, y / denom};
 }
 
+constexpr complex64_t operator/(float a, complex64_t b) {
+  auto denom = b.real * b.real + b.imag * b.imag;
+  auto x = a * b.real;
+  auto y = -a * b.imag;
+  return {x / denom, y / denom};
+}
+
 constexpr complex64_t operator%(complex64_t a, complex64_t b) {
   auto real = a.real - (b.real * static_cast(a.real / b.real));
   auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag));
diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal
index 2209b0665..d34c5a7ec 100644
--- a/mlx/backend/metal/kernels/unary.metal
+++ b/mlx/backend/metal/kernels/unary.metal
@@ -69,6 +69,9 @@ instantiate_unary_float(Round)
 instantiate_unary_int(BitwiseInvert)
 
 instantiate_unary_all_same(Abs, complex64, complex64_t)
+instantiate_unary_all_same(ArcCos, complex64, complex64_t)
+instantiate_unary_all_same(ArcSin, complex64, complex64_t)
+instantiate_unary_all_same(ArcTan, complex64, complex64_t)
 instantiate_unary_all_same(Conjugate, complex64, complex64_t)
 instantiate_unary_all_same(Cos, complex64, complex64_t)
 instantiate_unary_all_same(Cosh, complex64, complex64_t)
@@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t)
 instantiate_unary_all_same(Sign, complex64, complex64_t)
 instantiate_unary_all_same(Sin, complex64, complex64_t)
 instantiate_unary_all_same(Sinh, complex64, complex64_t)
+instantiate_unary_all_same(Square, complex64, complex64_t)
+instantiate_unary_all_same(Sqrt, complex64, complex64_t)
+instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
 instantiate_unary_all_same(Tan, complex64, complex64_t)
 instantiate_unary_all_same(Tanh, complex64, complex64_t)
 instantiate_unary_all_same(Round, complex64, complex64_t)
diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h
index 52e126b40..09d9f6605 100644
--- a/mlx/backend/metal/kernels/unary_ops.h
+++ b/mlx/backend/metal/kernels/unary_ops.h
@@ -17,27 +17,21 @@ struct Abs {
   T operator()(T x) {
     return metal::abs(x);
   };
-  template <>
   uint8_t operator()(uint8_t x) {
     return x;
   };
-  template <>
   uint16_t operator()(uint16_t x) {
     return x;
   };
-  template <>
   uint32_t operator()(uint32_t x) {
     return x;
   };
-  template <>
   uint64_t operator()(uint64_t x) {
     return x;
   };
-  template <>
   bool operator()(bool x) {
     return x;
   };
-  template <>
   complex64_t operator()(complex64_t x) {
     return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
   };
@@ -48,6 +42,8 @@ struct ArcCos {
   T operator()(T x) {
     return metal::precise::acos(x);
   };
+
+  complex64_t operator()(complex64_t x);
 };
 
 struct ArcCosh {
@@ -62,6 +58,8 @@ struct ArcSin {
   T operator()(T x) {
     return metal::precise::asin(x);
   };
+
+  complex64_t operator()(complex64_t x);
 };
 
 struct ArcSinh {
@@ -76,6 +74,8 @@ struct ArcTan {
   T operator()(T x) {
     return metal::precise::atan(x);
   };
+
+  complex64_t operator()(complex64_t x);
 };
 
 struct ArcTanh {
@@ -97,39 +97,30 @@ struct Ceil {
   T operator()(T x) {
     return metal::ceil(x);
   };
-  template <>
   int8_t operator()(int8_t x) {
     return x;
   };
-  template <>
   int16_t operator()(int16_t x) {
     return x;
   };
-  template <>
   int32_t operator()(int32_t x) {
     return x;
   };
-  template <>
   int64_t operator()(int64_t x) {
     return x;
   };
-  template <>
   uint8_t operator()(uint8_t x) {
     return x;
   };
-  template <>
   uint16_t operator()(uint16_t x) {
     return x;
   };
-  template <>
   uint32_t operator()(uint32_t x) {
     return x;
   };
-  template <>
   uint64_t operator()(uint64_t x) {
     return x;
   };
-  template <>
   bool operator()(bool x) {
     return x;
   };
@@ -141,7 +132,6 @@ struct Cos {
     return metal::precise::cos(x);
   };
 
-  template <>
   complex64_t operator()(complex64_t x) {
     return {
         metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
@@ -155,7 +145,6 @@ struct Cosh {
     return metal::precise::cosh(x);
   };
 
-  template <>
   complex64_t operator()(complex64_t x) {
     return {
         metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
@@ -188,7 +177,6 @@ struct Exp {
   T operator()(T x) {
     return metal::precise::exp(x);
   };
-  template <>
   complex64_t operator()(complex64_t x) {
     auto m = metal::precise::exp(x.real);
     return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
@@ -207,39 +195,30 @@ struct Floor {
   T operator()(T x) {
     return metal::floor(x);
   };
-  template <>
   int8_t operator()(int8_t x) {
     return x;
   };
-  template <>
   int16_t operator()(int16_t x) {
     return x;
   };
-  template <>
   int32_t operator()(int32_t x) {
     return x;
   };
-  template <>
   int64_t operator()(int64_t x) {
     return x;
   };
-  template <>
   uint8_t operator()(uint8_t x) {
     return x;
   };
-  template <>
   uint16_t operator()(uint16_t x) {
     return x;
   };
-  template <>
   uint32_t operator()(uint32_t x) {
     return x;
   };
-  template <>
   uint64_t operator()(uint64_t x) {
     return x;
   };
-  template <>
   bool operator()(bool x) {
     return x;
   };
@@ -258,7 +237,6 @@ struct Log {
     return metal::precise::log(x);
   };
 
-  template <>
   complex64_t operator()(complex64_t x) {
     auto r = metal::precise::log(Abs{}(x).real);
     auto i = metal::precise::atan2(x.imag, x.real);
@@ -272,7 +250,6 @@ struct Log2 {
     return metal::precise::log2(x);
   };
 
-  template <>
   complex64_t operator()(complex64_t x) {
     auto y = Log{}(x);
     return {y.real / M_LN2_F, y.imag / M_LN2_F};
@@ -285,7 +262,6 @@ struct Log10 {
     return metal::precise::log10(x);
   };
 
-  template <>
   complex64_t operator()(complex64_t x) {
     auto y = Log{}(x);
     return {y.real / M_LN10_F, y.imag / M_LN10_F};
@@ -325,7 +301,6 @@ struct Round {
   T operator()(T x) {
     return metal::rint(x);
   };
-  template <>
   complex64_t operator()(complex64_t x) {
     return {metal::rint(x.real), metal::rint(x.imag)};
   };
@@ -344,11 +319,9 @@ struct Sign {
   T operator()(T x) {
     return (x > T(0)) - (x < T(0));
   };
-  template <>
   uint32_t operator()(uint32_t x) {
     return x != 0;
   };
-  template <>
   complex64_t operator()(complex64_t x) {
     if (x == complex64_t(0)) {
       return x;
@@ -364,7 +337,6 @@ struct Sin {
     return metal::precise::sin(x);
   };
 
-  template <>
   complex64_t operator()(complex64_t x) {
     return {
         metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
@@ -378,7 +350,6 @@ struct Sinh {
     return metal::precise::sinh(x);
   };
 
-  template <>
   complex64_t operator()(complex64_t x) {
     return {
         metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
@@ -398,6 +369,17 @@ struct Sqrt {
   T operator()(T x) {
     return metal::precise::sqrt(x);
   };
+
+  complex64_t operator()(complex64_t x) {
+    if (x.real == 0.0 && x.imag == 0.0) {
+      return {0.0, 0.0};
+    }
+    auto r = Abs{}(x).real;
+    auto a = metal::precise::sqrt((r + x.real) / 2.0);
+    auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
+    auto b = metal::copysign(b_abs, x.imag);
+    return {a, b};
+  }
 };
 
 struct Rsqrt {
@@ -405,6 +387,10 @@ struct Rsqrt {
   T operator()(T x) {
     return metal::precise::rsqrt(x);
   };
+
+  complex64_t operator()(complex64_t x) {
+    return 1.0 / Sqrt{}(x);
+  }
 };
 
 struct Tan {
@@ -413,7 +399,6 @@ struct Tan {
     return metal::precise::tan(x);
   };
 
-  template <>
   complex64_t operator()(complex64_t x) {
     float tan_a = metal::precise::tan(x.real);
     float tanh_b = metal::precise::tanh(x.imag);
@@ -429,7 +414,6 @@ struct Tanh {
     return metal::precise::tanh(x);
   };
 
-  template <>
   complex64_t operator()(complex64_t x) {
     float tanh_a = metal::precise::tanh(x.real);
     float tan_b = metal::precise::tan(x.imag);
@@ -438,3 +422,21 @@ struct Tanh {
     return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
   };
 };
+
+complex64_t ArcCos::operator()(complex64_t x) {
+  auto i = complex64_t{0.0, 1.0};
+  auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
+  return {y.imag, -y.real};
+};
+
+complex64_t ArcSin::operator()(complex64_t x) {
+  auto i = complex64_t{0.0, 1.0};
+  auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
+  return {y.imag, -y.real};
+};
+
+complex64_t ArcTan::operator()(complex64_t x) {
+  auto i = complex64_t{0.0, 1.0};
+  auto ix = i * x;
+  return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
+};
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 4fcb31f18..31ea79345 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2934,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase):
         out = a[::-1]
         self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
 
+    def test_complex_ops(self):
+        x = mx.array(
+            [
+                3.0 + 4.0j,
+                -5.0 + 12.0j,
+                -8.0 + 0.0j,
+                0.0 + 9.0j,
+                0.0 + 0.0j,
+            ]
+        )
+
+        ops = ["arccos", "arcsin", "arctan", "square", "sqrt"]
+        for op in ops:
+            with self.subTest(op=op):
+                np_op = getattr(np, op)
+                mx_op = getattr(mx, op)
+                self.assertTrue(np.allclose(mx_op(x), np_op(x)))
+
+        x = mx.array(
+            [
+                3.0 + 4.0j,
+                -5.0 + 12.0j,
+                -8.0 + 0.0j,
+                0.0 + 9.0j,
+                9.0 + 1.0j,
+            ]
+        )
+        self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
+
 
 if __name__ == "__main__":
     unittest.main()