Log for complex numbers in Metal (#2025)

* Log for complex numbers in Metal

* fix log2
This commit is contained in:
Awni Hannun
2025-03-30 17:04:38 -07:00
committed by GitHub
parent b2d2b37888
commit 28f39e9038
4 changed files with 48 additions and 1 deletions

View File

@@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)

View File

@@ -257,6 +257,13 @@ struct Log {
T operator()(T x) {
return metal::precise::log(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto r = metal::precise::log(Abs{}(x).real);
auto i = metal::precise::atan2(x.imag, x.real);
return {r, i};
};
};
struct Log2 {
@@ -264,6 +271,12 @@ struct Log2 {
T operator()(T x) {
return metal::precise::log2(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN2_F, y.imag / M_LN2_F};
};
};
struct Log10 {
@@ -271,6 +284,12 @@ struct Log10 {
T operator()(T x) {
return metal::precise::log10(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto y = Log{}(x);
return {y.real / M_LN10_F, y.imag / M_LN10_F};
};
};
struct Log1p {