mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal * fix log2
This commit is contained in:
@@ -73,6 +73,9 @@ instantiate_unary_all_same(Conjugate, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log2, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log10, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Sin, complex64, complex64_t)
|
||||
|
@@ -257,6 +257,13 @@ struct Log {
|
||||
T operator()(T x) {
|
||||
return metal::precise::log(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto r = metal::precise::log(Abs{}(x).real);
|
||||
auto i = metal::precise::atan2(x.imag, x.real);
|
||||
return {r, i};
|
||||
};
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
@@ -264,6 +271,12 @@ struct Log2 {
|
||||
T operator()(T x) {
|
||||
return metal::precise::log2(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto y = Log{}(x);
|
||||
return {y.real / M_LN2_F, y.imag / M_LN2_F};
|
||||
};
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
@@ -271,6 +284,12 @@ struct Log10 {
|
||||
T operator()(T x) {
|
||||
return metal::precise::log10(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto y = Log{}(x);
|
||||
return {y.real / M_LN10_F, y.imag / M_LN10_F};
|
||||
};
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
|
Reference in New Issue
Block a user