diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 2209b0665..e2df74408 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -74,6 +74,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t) instantiate_unary_all_same(Log, complex64, complex64_t) +instantiate_unary_all_same(Log1p, complex64, complex64_t) instantiate_unary_all_same(Log2, complex64, complex64_t) instantiate_unary_all_same(Log10, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index b31cd20d6..1170d5576 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -328,6 +328,23 @@ inline bfloat16_t log1p(bfloat16_t x) { return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); } +inline complex64_t log1p(complex64_t in) { + float x = in.real; + float y = in.imag; + float zabs = metal::precise::sqrt(x * x + y * y); + float theta = metal::atan2(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1p(r), theta}; + } else { + auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y); + return {metal::log(z0), theta}; + } +} + /////////////////////////////////////////////////////////////////////////////// // SIMD shuffle ops ///////////////////////////////////////////////////////////////////////////////