mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Complex scan (#2094)
This commit is contained in:
@@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
|
@@ -130,6 +130,24 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
|
||||
metal::isnan(y.imag)) {
|
||||
return metal::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
constexpr float inf = metal::numeric_limits<float>::infinity();
|
||||
complex64_t maxval = x > y ? x : y;
|
||||
complex64_t minval = x < y ? x : y;
|
||||
if (minval.real == -inf || maxval.real == inf)
|
||||
return maxval;
|
||||
float m = metal::exp(minval.real - maxval.real);
|
||||
complex64_t dexp{
|
||||
m * metal::cos(minval.imag - maxval.imag),
|
||||
m * metal::sin(minval.imag - maxval.imag),
|
||||
};
|
||||
return maxval + log1p(dexp);
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
|
@@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
|
||||
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
|
||||
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
|
||||
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
|
||||
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on
|
||||
|
@@ -77,6 +77,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log1p, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log2, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Log10, complex64, complex64_t)
|
||||
instantiate_unary_all_same(Negative, complex64, complex64_t)
|
||||
|
@@ -328,6 +328,23 @@ inline bfloat16_t log1p(bfloat16_t x) {
|
||||
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
}
|
||||
|
||||
inline complex64_t log1p(complex64_t in) {
|
||||
float x = in.real;
|
||||
float y = in.imag;
|
||||
float zabs = metal::precise::sqrt(x * x + y * y);
|
||||
float theta = metal::atan2(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1p(r), theta};
|
||||
} else {
|
||||
auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
|
||||
return {metal::log(z0), theta};
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SIMD shuffle ops
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
Reference in New Issue
Block a user