Complex scan (#2094)

This commit is contained in:
Yury Popov
2025-04-23 04:56:28 +03:00
committed by GitHub
parent e8ac6bd2f5
commit 1d2c9d6a07
9 changed files with 146 additions and 4 deletions

View File

@@ -172,9 +172,12 @@ void binary_float(
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports non-complex floating point types.");
"[binary_float] Only supports floating point types.");
}
});
}

View File

@@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
scan_dispatch<complex64_t, complex64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
}
});

View File

@@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log1p(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto x = in.value.real();
auto y = in.value.imag();
auto zabs = std::abs(in.value);
auto theta = std::atan2(y, x + 1);
if (zabs < 0.5) {
auto r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return Simd<T, 1>{T{x, theta}};
}
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
} else {
auto z0 = std::hypot(x + 1, y);
return Simd<T, 1>{T{std::log(z0), theta}};
}
} else {
return Simd<T, 1>{std::log1p(in.value)};
}
}
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {

View File

@@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp)
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply)

View File

@@ -130,6 +130,24 @@ struct LogAddExp {
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) ||
metal::isnan(y.imag)) {
return metal::numeric_limits<float>::quiet_NaN();
}
constexpr float inf = metal::numeric_limits<float>::infinity();
complex64_t maxval = x > y ? x : y;
complex64_t minval = x < y ? x : y;
if (minval.real == -inf || maxval.real == inf)
return maxval;
float m = metal::exp(minval.real - maxval.real);
complex64_t dexp{
m * metal::cos(minval.imag - maxval.imag),
m * metal::sin(minval.imag - maxval.imag),
};
return maxval + log1p(dexp);
}
};
struct Maximum {

View File

@@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4)
instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on

View File

@@ -77,6 +77,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log1p, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)

View File

@@ -328,6 +328,23 @@ inline bfloat16_t log1p(bfloat16_t x) {
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
inline complex64_t log1p(complex64_t in) {
float x = in.real;
float y = in.imag;
float zabs = metal::precise::sqrt(x * x + y * y);
float theta = metal::atan2(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1p(r), theta};
} else {
auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y);
return {metal::log(z0), theta};
}
}
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////