mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-04 08:11:13 +08:00
metal: add complex logaddexp
This commit is contained in:
parent
0cdccd54ae
commit
c3c2bdb194
@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
|
|||||||
instantiate_binary_types_bool(LessEqual)
|
instantiate_binary_types_bool(LessEqual)
|
||||||
instantiate_binary_types_bool(NotEqual)
|
instantiate_binary_types_bool(NotEqual)
|
||||||
instantiate_binary_float(LogAddExp)
|
instantiate_binary_float(LogAddExp)
|
||||||
|
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
|
||||||
instantiate_binary_types(Maximum)
|
instantiate_binary_types(Maximum)
|
||||||
instantiate_binary_types(Minimum)
|
instantiate_binary_types(Minimum)
|
||||||
instantiate_binary_types(Multiply)
|
instantiate_binary_types(Multiply)
|
||||||
|
@ -130,6 +130,25 @@ struct LogAddExp {
|
|||||||
? maxval
|
? maxval
|
||||||
: (maxval + log1p(metal::exp(minval - maxval)));
|
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
|
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
|
||||||
|
metal::isnan(y.imag)) {
|
||||||
|
return metal::numeric_limits<float>::quiet_NaN();
|
||||||
|
}
|
||||||
|
constexpr float inf = metal::numeric_limits<float>::infinity();
|
||||||
|
complex64_t maxval = x > y ? x : y;
|
||||||
|
complex64_t minval = x < y ? x : y;
|
||||||
|
if (minval.real == -inf || maxval.real == inf)
|
||||||
|
return maxval;
|
||||||
|
float m = metal::exp(minval.real - maxval.real);
|
||||||
|
complex64_t dexp{
|
||||||
|
m * metal::cos(minval.imag - maxval.imag),
|
||||||
|
m * metal::sin(minval.imag - maxval.imag),
|
||||||
|
};
|
||||||
|
return maxval + log1p(dexp);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Maximum {
|
struct Maximum {
|
||||||
|
@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
|
|||||||
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
|
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
|
||||||
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
|
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
|
||||||
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
|
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
|
||||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
|
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
|
||||||
|
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on
|
||||||
|
Loading…
Reference in New Issue
Block a user