mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-04 08:11:13 +08:00
metal: add complex logaddexp
This commit is contained in:
parent
0cdccd54ae
commit
c3c2bdb194
@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
|
@ -130,6 +130,25 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
|
||||
metal::isnan(y.imag)) {
|
||||
return metal::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
constexpr float inf = metal::numeric_limits<float>::infinity();
|
||||
complex64_t maxval = x > y ? x : y;
|
||||
complex64_t minval = x < y ? x : y;
|
||||
if (minval.real == -inf || maxval.real == inf)
|
||||
return maxval;
|
||||
float m = metal::exp(minval.real - maxval.real);
|
||||
complex64_t dexp{
|
||||
m * metal::cos(minval.imag - maxval.imag),
|
||||
m * metal::sin(minval.imag - maxval.imag),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
|
@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
|
||||
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
|
||||
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
|
||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on
|
||||
|
Loading…
Reference in New Issue
Block a user