Add more complex unary ops (#2101)

This commit is contained in:
Awni Hannun 2025-04-21 13:04:54 -07:00 committed by GitHub
parent 79b527f45f
commit fdadc4f22c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 93 additions and 37 deletions

View File

@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) { constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag}; return {a.real + b.real, a.imag + b.imag};
} }
constexpr complex64_t operator+(float a, complex64_t b) {
return {a + b.real, b.imag};
}
constexpr complex64_t operator+(complex64_t a, float b) {
return {a.real + b, a.imag};
}
constexpr complex64_t operator-(complex64_t a, complex64_t b) { constexpr complex64_t operator-(complex64_t a, complex64_t b) {
return {a.real - b.real, a.imag - b.imag}; return {a.real - b.real, a.imag - b.imag};
} }
constexpr complex64_t operator-(float a, complex64_t b) {
return {a - b.real, -b.imag};
}
constexpr complex64_t operator-(complex64_t a, float b) {
return {a.real - b, a.imag};
}
constexpr complex64_t operator*(complex64_t a, complex64_t b) { constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
return {x / denom, y / denom}; return {x / denom, y / denom};
} }
constexpr complex64_t operator/(float a, complex64_t b) {
auto denom = b.real * b.real + b.imag * b.imag;
auto x = a * b.real;
auto y = -a * b.imag;
return {x / denom, y / denom};
}
constexpr complex64_t operator%(complex64_t a, complex64_t b) { constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real)); auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag)); auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

View File

@ -69,6 +69,9 @@ instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert) instantiate_unary_int(BitwiseInvert)
instantiate_unary_all_same(Abs, complex64, complex64_t) instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t) instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t)
@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t) instantiate_unary_all_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t) instantiate_unary_all_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Square, complex64, complex64_t)
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t) instantiate_unary_all_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t) instantiate_unary_all_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t) instantiate_unary_all_same(Round, complex64, complex64_t)

View File

@ -17,27 +17,21 @@ struct Abs {
T operator()(T x) { T operator()(T x) {
return metal::abs(x); return metal::abs(x);
}; };
template <>
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; };
template <>
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; };
template <>
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; };
template <>
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; };
template <>
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
}; };
@ -48,6 +42,8 @@ struct ArcCos {
T operator()(T x) { T operator()(T x) {
return metal::precise::acos(x); return metal::precise::acos(x);
}; };
complex64_t operator()(complex64_t x);
}; };
struct ArcCosh { struct ArcCosh {
@ -62,6 +58,8 @@ struct ArcSin {
T operator()(T x) { T operator()(T x) {
return metal::precise::asin(x); return metal::precise::asin(x);
}; };
complex64_t operator()(complex64_t x);
}; };
struct ArcSinh { struct ArcSinh {
@ -76,6 +74,8 @@ struct ArcTan {
T operator()(T x) { T operator()(T x) {
return metal::precise::atan(x); return metal::precise::atan(x);
}; };
complex64_t operator()(complex64_t x);
}; };
struct ArcTanh { struct ArcTanh {
@ -97,39 +97,30 @@ struct Ceil {
T operator()(T x) { T operator()(T x) {
return metal::ceil(x); return metal::ceil(x);
}; };
template <>
int8_t operator()(int8_t x) { int8_t operator()(int8_t x) {
return x; return x;
}; };
template <>
int16_t operator()(int16_t x) { int16_t operator()(int16_t x) {
return x; return x;
}; };
template <>
int32_t operator()(int32_t x) { int32_t operator()(int32_t x) {
return x; return x;
}; };
template <>
int64_t operator()(int64_t x) { int64_t operator()(int64_t x) {
return x; return x;
}; };
template <>
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; };
template <>
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; };
template <>
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; };
template <>
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; };
template <>
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; };
@ -141,7 +132,6 @@ struct Cos {
return metal::precise::cos(x); return metal::precise::cos(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return { return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag), metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
@ -155,7 +145,6 @@ struct Cosh {
return metal::precise::cosh(x); return metal::precise::cosh(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return { return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag), metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
@ -188,7 +177,6 @@ struct Exp {
T operator()(T x) { T operator()(T x) {
return metal::precise::exp(x); return metal::precise::exp(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real); auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
@ -207,39 +195,30 @@ struct Floor {
T operator()(T x) { T operator()(T x) {
return metal::floor(x); return metal::floor(x);
}; };
template <>
int8_t operator()(int8_t x) { int8_t operator()(int8_t x) {
return x; return x;
}; };
template <>
int16_t operator()(int16_t x) { int16_t operator()(int16_t x) {
return x; return x;
}; };
template <>
int32_t operator()(int32_t x) { int32_t operator()(int32_t x) {
return x; return x;
}; };
template <>
int64_t operator()(int64_t x) { int64_t operator()(int64_t x) {
return x; return x;
}; };
template <>
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; };
template <>
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; };
template <>
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; };
template <>
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; };
template <>
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; };
@ -258,7 +237,6 @@ struct Log {
return metal::precise::log(x); return metal::precise::log(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real); auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real); auto i = metal::precise::atan2(x.imag, x.real);
@ -272,7 +250,6 @@ struct Log2 {
return metal::precise::log2(x); return metal::precise::log2(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
auto y = Log{}(x); auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F}; return {y.real / M_LN2_F, y.imag / M_LN2_F};
@ -285,7 +262,6 @@ struct Log10 {
return metal::precise::log10(x); return metal::precise::log10(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
auto y = Log{}(x); auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F}; return {y.real / M_LN10_F, y.imag / M_LN10_F};
@ -325,7 +301,6 @@ struct Round {
T operator()(T x) { T operator()(T x) {
return metal::rint(x); return metal::rint(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return {metal::rint(x.real), metal::rint(x.imag)}; return {metal::rint(x.real), metal::rint(x.imag)};
}; };
@ -344,11 +319,9 @@ struct Sign {
T operator()(T x) { T operator()(T x) {
return (x > T(0)) - (x < T(0)); return (x > T(0)) - (x < T(0));
}; };
template <>
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x != 0; return x != 0;
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
if (x == complex64_t(0)) { if (x == complex64_t(0)) {
return x; return x;
@ -364,7 +337,6 @@ struct Sin {
return metal::precise::sin(x); return metal::precise::sin(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return { return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag), metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
@ -378,7 +350,6 @@ struct Sinh {
return metal::precise::sinh(x); return metal::precise::sinh(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return { return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag), metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
@ -398,6 +369,17 @@ struct Sqrt {
T operator()(T x) { T operator()(T x) {
return metal::precise::sqrt(x); return metal::precise::sqrt(x);
}; };
complex64_t operator()(complex64_t x) {
if (x.real == 0.0 && x.imag == 0.0) {
return {0.0, 0.0};
}
auto r = Abs{}(x).real;
auto a = metal::precise::sqrt((r + x.real) / 2.0);
auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
auto b = metal::copysign(b_abs, x.imag);
return {a, b};
}
}; };
struct Rsqrt { struct Rsqrt {
@ -405,6 +387,10 @@ struct Rsqrt {
T operator()(T x) { T operator()(T x) {
return metal::precise::rsqrt(x); return metal::precise::rsqrt(x);
}; };
complex64_t operator()(complex64_t x) {
return 1.0 / Sqrt{}(x);
}
}; };
struct Tan { struct Tan {
@ -413,7 +399,6 @@ struct Tan {
return metal::precise::tan(x); return metal::precise::tan(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real); float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag); float tanh_b = metal::precise::tanh(x.imag);
@ -429,7 +414,6 @@ struct Tanh {
return metal::precise::tanh(x); return metal::precise::tanh(x);
}; };
template <>
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real); float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag); float tan_b = metal::precise::tan(x.imag);
@ -438,3 +422,21 @@ struct Tanh {
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
}; };
}; };
complex64_t ArcCos::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
return {y.imag, -y.real};
};
complex64_t ArcSin::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
return {y.imag, -y.real};
};
complex64_t ArcTan::operator()(complex64_t x) {
auto i = complex64_t{0.0, 1.0};
auto ix = i * x;
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
};

View File

@ -2934,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase):
out = a[::-1] out = a[::-1]
self.assertTrue(mx.array_equal(out[-1, :], a[0, :])) self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
def test_complex_ops(self):
x = mx.array(
[
3.0 + 4.0j,
-5.0 + 12.0j,
-8.0 + 0.0j,
0.0 + 9.0j,
0.0 + 0.0j,
]
)
ops = ["arccos", "arcsin", "arctan", "square", "sqrt"]
for op in ops:
with self.subTest(op=op):
np_op = getattr(np, op)
mx_op = getattr(mx, op)
self.assertTrue(np.allclose(mx_op(x), np_op(x)))
x = mx.array(
[
3.0 + 4.0j,
-5.0 + 12.0j,
-8.0 + 0.0j,
0.0 + 9.0j,
9.0 + 1.0j,
]
)
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()