diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index 3ef8e6269..1d555fefa 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
 instantiate_binary_types_bool(LessEqual)
 instantiate_binary_types_bool(NotEqual)
 instantiate_binary_float(LogAddExp)
+instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
 instantiate_binary_types(Maximum)
 instantiate_binary_types(Minimum)
 instantiate_binary_types(Multiply)
diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h
index 8f961c2cf..f007022cd 100644
--- a/mlx/backend/metal/kernels/binary_ops.h
+++ b/mlx/backend/metal/kernels/binary_ops.h
@@ -130,6 +130,36 @@ struct LogAddExp {
         ? maxval
         : (maxval + log1p(metal::exp(minval - maxval)));
   };
+
+  // complex64 overload of log(exp(x) + exp(y)).
+  //
+  // A quiet NaN is returned when any real or imaginary component of x or y
+  // is NaN. When the smaller operand has real part -inf, or the larger one has
+  // real part +inf, the larger operand is returned as-is. Otherwise the result
+  // is maxval + log1p(exp(minval - maxval)).
+  template <>
+  complex64_t operator()(complex64_t x, complex64_t y) {
+    if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
+        metal::isnan(y.imag)) {
+      return metal::numeric_limits<float>::quiet_NaN();
+    }
+    constexpr float inf = metal::numeric_limits<float>::infinity();
+    // Ordering relies on complex64_t's own > and < operators.
+    complex64_t maxval = x > y ? x : y;
+    complex64_t minval = x < y ? x : y;
+    // Handle infinite real parts up front: subtracting two equal infinities
+    // below would produce NaN.
+    if (minval.real == -inf || maxval.real == inf)
+      return maxval;
+    // exp(minval - maxval) via Euler's formula:
+    // exp(a + ib) = exp(a) * (cos(b) + i * sin(b)).
+    float m = metal::exp(minval.real - maxval.real);
+    complex64_t dexp{
+        m * metal::cos(minval.imag - maxval.imag),
+        m * metal::sin(minval.imag - maxval.imag),
+    };
+    return maxval + log1p(dexp);
+  }
 };
 
 struct Maximum {
diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal
index 8fcd7f61b..f38f8757e 100644
--- a/mlx/backend/metal/kernels/scan.metal
+++ b/mlx/backend/metal/kernels/scan.metal
@@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
 instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
 instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
 instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
-instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
+instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
+instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on