metal: add complex logaddexp

This commit is contained in:
Yury Popov 2025-04-20 02:06:54 +03:00
parent 0cdccd54ae
commit c3c2bdb194
No known key found for this signature in database
GPG Key ID: 76DE18AD6634F257
3 changed files with 22 additions and 1 deletions

View File

@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp)
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply)

View File

@ -130,6 +130,25 @@ struct LogAddExp {
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
metal::isnan(y.imag)) {
return metal::numeric_limits<float>::quiet_NaN();
}
constexpr float inf = metal::numeric_limits<float>::infinity();
complex64_t maxval = x > y ? x : y;
complex64_t minval = x < y ? x : y;
if (minval.real == -inf || maxval.real == inf)
return maxval;
float m = metal::exp(minval.real - maxval.real);
complex64_t dexp{
m * metal::cos(minval.imag - maxval.imag),
m * metal::sin(minval.imag - maxval.imag),
};
return maxval + log1p(dexp);
}
};
struct Maximum {

View File

@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on