mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-27 03:11:16 +08:00
cpu: add complex log1p
This commit is contained in:
parent
b13f2aed16
commit
38f593026c
@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1)
|
|||||||
DEFAULT_UNARY(floor, std::floor)
|
DEFAULT_UNARY(floor, std::floor)
|
||||||
DEFAULT_UNARY(log, std::log)
|
DEFAULT_UNARY(log, std::log)
|
||||||
DEFAULT_UNARY(log10, std::log10)
|
DEFAULT_UNARY(log10, std::log10)
|
||||||
DEFAULT_UNARY(log1p, std::log1p)
|
|
||||||
DEFAULT_UNARY(sinh, std::sinh)
|
DEFAULT_UNARY(sinh, std::sinh)
|
||||||
DEFAULT_UNARY(sqrt, std::sqrt)
|
DEFAULT_UNARY(sqrt, std::sqrt)
|
||||||
DEFAULT_UNARY(tan, std::tan)
|
DEFAULT_UNARY(tan, std::tan)
|
||||||
DEFAULT_UNARY(tanh, std::tanh)
|
DEFAULT_UNARY(tanh, std::tanh)
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Simd<T, 1> log1p(Simd<T, 1> in) {
|
||||||
|
if constexpr (is_complex<T>) {
|
||||||
|
auto x = in.value.real();
|
||||||
|
auto y = in.value.imag();
|
||||||
|
auto zabs = std::abs(in.value);
|
||||||
|
auto theta = std::atan2(y, x + 1);
|
||||||
|
if (zabs < 0.5) {
|
||||||
|
auto r = x * (2 + x) + y * y;
|
||||||
|
if (r == 0) { // handle underflow
|
||||||
|
return Simd<T, 1>{T{x, theta}};
|
||||||
|
}
|
||||||
|
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
|
||||||
|
} else {
|
||||||
|
auto z0 = std::hypot(x + 1, y);
|
||||||
|
return Simd<T, 1>{T{std::log(z0), theta}};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Simd<T, 1>{std::log1p(in.value)};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Simd<T, 1> log2(Simd<T, 1> in) {
|
Simd<T, 1> log2(Simd<T, 1> in) {
|
||||||
if constexpr (is_complex<T>) {
|
if constexpr (is_complex<T>) {
|
||||||
|
Loading…
Reference in New Issue
Block a user