mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add more complex unary ops (#2101)
This commit is contained in:
parent
79b527f45f
commit
fdadc4f22c
@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
|
|||||||
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
|
||||||
return {a.real + b.real, a.imag + b.imag};
|
return {a.real + b.real, a.imag + b.imag};
|
||||||
}
|
}
|
||||||
|
constexpr complex64_t operator+(float a, complex64_t b) {
|
||||||
|
return {a + b.real, b.imag};
|
||||||
|
}
|
||||||
|
constexpr complex64_t operator+(complex64_t a, float b) {
|
||||||
|
return {a.real + b, a.imag};
|
||||||
|
}
|
||||||
|
|
||||||
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
|
||||||
return {a.real - b.real, a.imag - b.imag};
|
return {a.real - b.real, a.imag - b.imag};
|
||||||
}
|
}
|
||||||
|
constexpr complex64_t operator-(float a, complex64_t b) {
|
||||||
|
return {a - b.real, -b.imag};
|
||||||
|
}
|
||||||
|
constexpr complex64_t operator-(complex64_t a, float b) {
|
||||||
|
return {a.real - b, a.imag};
|
||||||
|
}
|
||||||
|
|
||||||
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
|
||||||
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
|
||||||
@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
|
|||||||
return {x / denom, y / denom};
|
return {x / denom, y / denom};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constexpr complex64_t operator/(float a, complex64_t b) {
|
||||||
|
auto denom = b.real * b.real + b.imag * b.imag;
|
||||||
|
auto x = a * b.real;
|
||||||
|
auto y = -a * b.imag;
|
||||||
|
return {x / denom, y / denom};
|
||||||
|
}
|
||||||
|
|
||||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||||
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
||||||
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
||||||
|
@ -69,6 +69,9 @@ instantiate_unary_float(Round)
|
|||||||
instantiate_unary_int(BitwiseInvert)
|
instantiate_unary_int(BitwiseInvert)
|
||||||
|
|
||||||
instantiate_unary_all_same(Abs, complex64, complex64_t)
|
instantiate_unary_all_same(Abs, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(ArcCos, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(ArcSin, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(ArcTan, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||||
@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t)
|
|||||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Sinh, complex64, complex64_t)
|
instantiate_unary_all_same(Sinh, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(Square, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(Sqrt, complex64, complex64_t)
|
||||||
|
instantiate_unary_all_same(Rsqrt, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Tan, complex64, complex64_t)
|
instantiate_unary_all_same(Tan, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Tanh, complex64, complex64_t)
|
instantiate_unary_all_same(Tanh, complex64, complex64_t)
|
||||||
instantiate_unary_all_same(Round, complex64, complex64_t)
|
instantiate_unary_all_same(Round, complex64, complex64_t)
|
||||||
|
@ -17,27 +17,21 @@ struct Abs {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::abs(x);
|
return metal::abs(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint8_t operator()(uint8_t x) {
|
uint8_t operator()(uint8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint16_t operator()(uint16_t x) {
|
uint16_t operator()(uint16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint32_t operator()(uint32_t x) {
|
uint32_t operator()(uint32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint64_t operator()(uint64_t x) {
|
uint64_t operator()(uint64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
bool operator()(bool x) {
|
bool operator()(bool x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
||||||
};
|
};
|
||||||
@ -48,6 +42,8 @@ struct ArcCos {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::acos(x);
|
return metal::precise::acos(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcCosh {
|
struct ArcCosh {
|
||||||
@ -62,6 +58,8 @@ struct ArcSin {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::asin(x);
|
return metal::precise::asin(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcSinh {
|
struct ArcSinh {
|
||||||
@ -76,6 +74,8 @@ struct ArcTan {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::atan(x);
|
return metal::precise::atan(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcTanh {
|
struct ArcTanh {
|
||||||
@ -97,39 +97,30 @@ struct Ceil {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::ceil(x);
|
return metal::ceil(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int8_t operator()(int8_t x) {
|
int8_t operator()(int8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int16_t operator()(int16_t x) {
|
int16_t operator()(int16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int32_t operator()(int32_t x) {
|
int32_t operator()(int32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int64_t operator()(int64_t x) {
|
int64_t operator()(int64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint8_t operator()(uint8_t x) {
|
uint8_t operator()(uint8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint16_t operator()(uint16_t x) {
|
uint16_t operator()(uint16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint32_t operator()(uint32_t x) {
|
uint32_t operator()(uint32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint64_t operator()(uint64_t x) {
|
uint64_t operator()(uint64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
bool operator()(bool x) {
|
bool operator()(bool x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
@ -141,7 +132,6 @@ struct Cos {
|
|||||||
return metal::precise::cos(x);
|
return metal::precise::cos(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {
|
return {
|
||||||
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
||||||
@ -155,7 +145,6 @@ struct Cosh {
|
|||||||
return metal::precise::cosh(x);
|
return metal::precise::cosh(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {
|
return {
|
||||||
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
||||||
@ -188,7 +177,6 @@ struct Exp {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::exp(x);
|
return metal::precise::exp(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
auto m = metal::precise::exp(x.real);
|
auto m = metal::precise::exp(x.real);
|
||||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||||
@ -207,39 +195,30 @@ struct Floor {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::floor(x);
|
return metal::floor(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int8_t operator()(int8_t x) {
|
int8_t operator()(int8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int16_t operator()(int16_t x) {
|
int16_t operator()(int16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int32_t operator()(int32_t x) {
|
int32_t operator()(int32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
int64_t operator()(int64_t x) {
|
int64_t operator()(int64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint8_t operator()(uint8_t x) {
|
uint8_t operator()(uint8_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint16_t operator()(uint16_t x) {
|
uint16_t operator()(uint16_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint32_t operator()(uint32_t x) {
|
uint32_t operator()(uint32_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint64_t operator()(uint64_t x) {
|
uint64_t operator()(uint64_t x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
bool operator()(bool x) {
|
bool operator()(bool x) {
|
||||||
return x;
|
return x;
|
||||||
};
|
};
|
||||||
@ -258,7 +237,6 @@ struct Log {
|
|||||||
return metal::precise::log(x);
|
return metal::precise::log(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
auto r = metal::precise::log(Abs{}(x).real);
|
auto r = metal::precise::log(Abs{}(x).real);
|
||||||
auto i = metal::precise::atan2(x.imag, x.real);
|
auto i = metal::precise::atan2(x.imag, x.real);
|
||||||
@ -272,7 +250,6 @@ struct Log2 {
|
|||||||
return metal::precise::log2(x);
|
return metal::precise::log2(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
auto y = Log{}(x);
|
auto y = Log{}(x);
|
||||||
return {y.real / M_LN2_F, y.imag / M_LN2_F};
|
return {y.real / M_LN2_F, y.imag / M_LN2_F};
|
||||||
@ -285,7 +262,6 @@ struct Log10 {
|
|||||||
return metal::precise::log10(x);
|
return metal::precise::log10(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
auto y = Log{}(x);
|
auto y = Log{}(x);
|
||||||
return {y.real / M_LN10_F, y.imag / M_LN10_F};
|
return {y.real / M_LN10_F, y.imag / M_LN10_F};
|
||||||
@ -325,7 +301,6 @@ struct Round {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::rint(x);
|
return metal::rint(x);
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {metal::rint(x.real), metal::rint(x.imag)};
|
return {metal::rint(x.real), metal::rint(x.imag)};
|
||||||
};
|
};
|
||||||
@ -344,11 +319,9 @@ struct Sign {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return (x > T(0)) - (x < T(0));
|
return (x > T(0)) - (x < T(0));
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
uint32_t operator()(uint32_t x) {
|
uint32_t operator()(uint32_t x) {
|
||||||
return x != 0;
|
return x != 0;
|
||||||
};
|
};
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
if (x == complex64_t(0)) {
|
if (x == complex64_t(0)) {
|
||||||
return x;
|
return x;
|
||||||
@ -364,7 +337,6 @@ struct Sin {
|
|||||||
return metal::precise::sin(x);
|
return metal::precise::sin(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {
|
return {
|
||||||
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
||||||
@ -378,7 +350,6 @@ struct Sinh {
|
|||||||
return metal::precise::sinh(x);
|
return metal::precise::sinh(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {
|
return {
|
||||||
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
||||||
@ -398,6 +369,17 @@ struct Sqrt {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::sqrt(x);
|
return metal::precise::sqrt(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
if (x.real == 0.0 && x.imag == 0.0) {
|
||||||
|
return {0.0, 0.0};
|
||||||
|
}
|
||||||
|
auto r = Abs{}(x).real;
|
||||||
|
auto a = metal::precise::sqrt((r + x.real) / 2.0);
|
||||||
|
auto b_abs = metal::precise::sqrt((r - x.real) / 2.0);
|
||||||
|
auto b = metal::copysign(b_abs, x.imag);
|
||||||
|
return {a, b};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Rsqrt {
|
struct Rsqrt {
|
||||||
@ -405,6 +387,10 @@ struct Rsqrt {
|
|||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return metal::precise::rsqrt(x);
|
return metal::precise::rsqrt(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
return 1.0 / Sqrt{}(x);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Tan {
|
struct Tan {
|
||||||
@ -413,7 +399,6 @@ struct Tan {
|
|||||||
return metal::precise::tan(x);
|
return metal::precise::tan(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
float tan_a = metal::precise::tan(x.real);
|
float tan_a = metal::precise::tan(x.real);
|
||||||
float tanh_b = metal::precise::tanh(x.imag);
|
float tanh_b = metal::precise::tanh(x.imag);
|
||||||
@ -429,7 +414,6 @@ struct Tanh {
|
|||||||
return metal::precise::tanh(x);
|
return metal::precise::tanh(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
float tanh_a = metal::precise::tanh(x.real);
|
float tanh_a = metal::precise::tanh(x.real);
|
||||||
float tan_b = metal::precise::tan(x.imag);
|
float tan_b = metal::precise::tan(x.imag);
|
||||||
@ -438,3 +422,21 @@ struct Tanh {
|
|||||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
complex64_t ArcCos::operator()(complex64_t x) {
|
||||||
|
auto i = complex64_t{0.0, 1.0};
|
||||||
|
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||||
|
return {y.imag, -y.real};
|
||||||
|
};
|
||||||
|
|
||||||
|
complex64_t ArcSin::operator()(complex64_t x) {
|
||||||
|
auto i = complex64_t{0.0, 1.0};
|
||||||
|
auto y = Log{}(i * x + Sqrt{}(1.0 - x * x));
|
||||||
|
return {y.imag, -y.real};
|
||||||
|
};
|
||||||
|
|
||||||
|
complex64_t ArcTan::operator()(complex64_t x) {
|
||||||
|
auto i = complex64_t{0.0, 1.0};
|
||||||
|
auto ix = i * x;
|
||||||
|
return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix));
|
||||||
|
};
|
||||||
|
@ -2934,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
out = a[::-1]
|
out = a[::-1]
|
||||||
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
|
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
|
||||||
|
|
||||||
|
def test_complex_ops(self):
|
||||||
|
x = mx.array(
|
||||||
|
[
|
||||||
|
3.0 + 4.0j,
|
||||||
|
-5.0 + 12.0j,
|
||||||
|
-8.0 + 0.0j,
|
||||||
|
0.0 + 9.0j,
|
||||||
|
0.0 + 0.0j,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
ops = ["arccos", "arcsin", "arctan", "square", "sqrt"]
|
||||||
|
for op in ops:
|
||||||
|
with self.subTest(op=op):
|
||||||
|
np_op = getattr(np, op)
|
||||||
|
mx_op = getattr(mx, op)
|
||||||
|
self.assertTrue(np.allclose(mx_op(x), np_op(x)))
|
||||||
|
|
||||||
|
x = mx.array(
|
||||||
|
[
|
||||||
|
3.0 + 4.0j,
|
||||||
|
-5.0 + 12.0j,
|
||||||
|
-8.0 + 0.0j,
|
||||||
|
0.0 + 9.0j,
|
||||||
|
9.0 + 1.0j,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user