mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Add more complex unary ops (#2101)
This commit is contained in:
parent
79b527f45f
commit
fdadc4f22c
@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
|
||||
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
||||
return {a.real + b.real, a.imag + b.imag};
|
||||
}
|
||||
constexpr complex64_t operator+(float a, complex64_t b) {
|
||||
return {a + b.real, b.imag};
|
||||
}
|
||||
constexpr complex64_t operator+(complex64_t a, float b) {
|
||||
return {a.real + b, a.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||
return {a.real - b.real, a.imag - b.imag};
|
||||
}
|
||||
constexpr complex64_t operator-(float a, complex64_t b) {
|
||||
return {a - b.real, -b.imag};
|
||||
}
|
||||
constexpr complex64_t operator-(complex64_t a, float b) {
|
||||
return {a.real - b, a.imag};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||
@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
|
||||
return {x / denom, y / denom};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator/(float a, complex64_t b) {
|
||||
auto denom = b.real * b.real + b.imag * b.imag;
|
||||
auto x = a * b.real;
|
||||
auto y = -a * b.imag;
|
||||
return {x / denom, y / denom};
|
||||
}
|
||||
|
||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
||||
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
||||
|
@ -69,6 +69,9 @@ instantiate_unary_float(Round)
|
||||
instantiate_unary_int(BitwiseInvert)
|
||||
|
||||
instantiate_unary_all_same(Abs, complex64, complex64_t)
|
||||
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
|
||||
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||
@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Square, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tan, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Tanh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Round, complex64, complex64_t)
|
||||
|
@ -17,27 +17,21 @@ struct Abs {
|
||||
T operator()(T x) {
|
||||
return metal::abs(x);
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
||||
};
|
||||
@ -48,6 +42,8 @@ struct ArcCos {
|
||||
T operator()(T x) {
|
||||
return metal::precise::acos(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x);
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
@ -62,6 +58,8 @@ struct ArcSin {
|
||||
T operator()(T x) {
|
||||
return metal::precise::asin(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x);
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
@ -76,6 +74,8 @@ struct ArcTan {
|
||||
T operator()(T x) {
|
||||
return metal::precise::atan(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x);
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
@ -97,39 +97,30 @@ struct Ceil {
|
||||
T operator()(T x) {
|
||||
return metal::ceil(x);
|
||||
};
|
||||
template <>
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
@ -141,7 +132,6 @@ struct Cos {
|
||||
return metal::precise::cos(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
||||
@ -155,7 +145,6 @@ struct Cosh {
|
||||
return metal::precise::cosh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
||||
@ -188,7 +177,6 @@ struct Exp {
|
||||
T operator()(T x) {
|
||||
return metal::precise::exp(x);
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto m = metal::precise::exp(x.real);
|
||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||
@ -207,39 +195,30 @@ struct Floor {
|
||||
T operator()(T x) {
|
||||
return metal::floor(x);
|
||||
};
|
||||
template <>
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
@ -258,7 +237,6 @@ struct Log {
|
||||
return metal::precise::log(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto r = metal::precise::log(Abs{}(x).real);
|
||||
auto i = metal::precise::atan2(x.imag, x.real);
|
||||
@ -272,7 +250,6 @@ struct Log2 {
|
||||
return metal::precise::log2(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto y = Log{}(x);
|
||||
return {y.real / M_LN2_F, y.imag / M_LN2_F};
|
||||
@ -285,7 +262,6 @@ struct Log10 {
|
||||
return metal::precise::log10(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto y = Log{}(x);
|
||||
return {y.real / M_LN10_F, y.imag / M_LN10_F};
|
||||
@ -325,7 +301,6 @@ struct Round {
|
||||
T operator()(T x) {
|
||||
return metal::rint(x);
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {metal::rint(x.real), metal::rint(x.imag)};
|
||||
};
|
||||
@ -344,11 +319,9 @@ struct Sign {
|
||||
T operator()(T x) {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
if (x == complex64_t(0)) {
|
||||
return x;
|
||||
@ -364,7 +337,6 @@ struct Sin {
|
||||
return metal::precise::sin(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
||||
@ -378,7 +350,6 @@ struct Sinh {
|
||||
return metal::precise::sinh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
||||
@ -398,6 +369,17 @@ struct Sqrt {
|
||||
T operator()(T x) {
|
||||
return metal::precise::sqrt(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
if (x.real == 0.0 && x.imag == 0.0) {
|
||||
return {0.0, 0.0};
|
||||
}
|
||||
auto r = Abs{}(x).real;
|
||||
auto a = metal::precise::sqrt((r + x.real) / 2.0);
|
||||
auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
|
||||
auto b = metal::copysign(b_abs, x.imag);
|
||||
return {a, b};
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
@ -405,6 +387,10 @@ struct Rsqrt {
|
||||
T operator()(T x) {
|
||||
return metal::precise::rsqrt(x);
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return 1.0 / Sqrt{}(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
@ -413,7 +399,6 @@ struct Tan {
|
||||
return metal::precise::tan(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tan_a = metal::precise::tan(x.real);
|
||||
float tanh_b = metal::precise::tanh(x.imag);
|
||||
@ -429,7 +414,6 @@ struct Tanh {
|
||||
return metal::precise::tanh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tanh_a = metal::precise::tanh(x.real);
|
||||
float tan_b = metal::precise::tan(x.imag);
|
||||
@ -438,3 +422,21 @@ struct Tanh {
|
||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||
};
|
||||
};
|
||||
|
||||
complex64_t ArcCos::operator()(complex64_t x) {
|
||||
auto i = complex64_t{0.0, 1.0};
|
||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||
return {y.imag, -y.real};
|
||||
};
|
||||
|
||||
complex64_t ArcSin::operator()(complex64_t x) {
|
||||
auto i = complex64_t{0.0, 1.0};
|
||||
auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
|
||||
return {y.imag, -y.real};
|
||||
};
|
||||
|
||||
complex64_t ArcTan::operator()(complex64_t x) {
|
||||
auto i = complex64_t{0.0, 1.0};
|
||||
auto ix = i * x;
|
||||
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
|
||||
};
|
||||
|
@ -2934,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = a[::-1]
|
||||
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
|
||||
|
||||
def test_complex_ops(self):
|
||||
x = mx.array(
|
||||
[
|
||||
3.0 + 4.0j,
|
||||
-5.0 + 12.0j,
|
||||
-8.0 + 0.0j,
|
||||
0.0 + 9.0j,
|
||||
0.0 + 0.0j,
|
||||
]
|
||||
)
|
||||
|
||||
ops = ["arccos", "arcsin", "arctan", "square", "sqrt"]
|
||||
for op in ops:
|
||||
with self.subTest(op=op):
|
||||
np_op = getattr(np, op)
|
||||
mx_op = getattr(mx, op)
|
||||
self.assertTrue(np.allclose(mx_op(x), np_op(x)))
|
||||
|
||||
x = mx.array(
|
||||
[
|
||||
3.0 + 4.0j,
|
||||
-5.0 + 12.0j,
|
||||
-8.0 + 0.0j,
|
||||
0.0 + 9.0j,
|
||||
9.0 + 1.0j,
|
||||
]
|
||||
)
|
||||
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user