From 38f593026c8e022dba5f0d13058ac32e855211f4 Mon Sep 17 00:00:00 2001 From: Yury Popov Date: Sun, 20 Apr 2025 00:15:48 +0300 Subject: [PATCH] cpu: add complex log1p --- mlx/backend/cpu/simd/base_simd.h | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h index 7e82a4d56..17cd35b9a 100644 --- a/mlx/backend/cpu/simd/base_simd.h +++ b/mlx/backend/cpu/simd/base_simd.h @@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log10, std::log10) -DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tanh, std::tanh) +template +Simd log1p(Simd in) { + if constexpr (is_complex) { + auto x = in.value.real(); + auto y = in.value.imag(); + auto zabs = std::abs(in.value); + auto theta = std::atan2(y, x + 1); + if (zabs < 0.5) { + auto r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return Simd{T{x, theta}}; + } + return Simd{T{((typeof(x))(0.5)) * std::log1p(r), theta}}; + } else { + auto z0 = std::hypot(x + 1, y); + return Simd{T{std::log(z0), theta}}; + } + } else { + return Simd{std::log1p(in.value)}; + } +} + template Simd log2(Simd in) { if constexpr (is_complex) {