[CUDA] Use cuda::std::complex in place of cuComplex (#2372)

This commit is contained in:
Cheng
2025-07-15 16:36:13 +09:00
committed by GitHub
parent f0a0b077a0
commit cb349a291c
15 changed files with 169 additions and 460 deletions

View File

@@ -44,7 +44,7 @@ struct Remainder {
} else {
return x % y;
}
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
} else if constexpr (is_complex_v<T>) {
return x % y;
} else {
T r = fmod(x, y);
@@ -66,14 +66,12 @@ struct Equal {
struct NaNEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (std::is_same_v<T, cuComplex>) {
if constexpr (is_complex_v<T>) {
return x == y ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
cuCimagf(x) == cuCimagf(y));
(isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&
isnan(y.imag())) ||
(x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||
(isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());
} else {
return x == y || (isnan(x) && isnan(y));
}
@@ -111,17 +109,17 @@ struct LessEqual {
struct LogAddExp {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
if constexpr (is_complex_v<T>) {
if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) ||
isnan(y.imag())) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
auto min_real = cuCrealf(min);
auto max_real = cuCrealf(max);
auto max = x.real() > y.real() ? x : y;
auto min = x.real() < y.real() ? x : y;
auto min_real = min.real();
auto max_real = max.real();
if (!isfinite(min_real) && (min_real == max_real)) {
if (min_real < 0) {
return min;
@@ -150,8 +148,8 @@ struct Maximum {
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return max(x, y);
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
} else if constexpr (is_complex_v<T>) {
if (isnan(x.real()) || isnan(x.imag())) {
return x;
}
return x > y ? x : y;
@@ -169,8 +167,8 @@ struct Minimum {
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return min(x, y);
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
} else if constexpr (is_complex_v<T>) {
if (isnan(x.real()) || isnan(x.imag())) {
return x;
}
return x < y ? x : y;
@@ -193,8 +191,8 @@ struct Multiply {
struct NotEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (std::is_same_v<T, cuComplex>) {
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
if constexpr (is_complex_v<T>) {
return x.real() != y.real() || x.imag() != y.imag();
} else {
return x != y;
}
@@ -214,19 +212,8 @@ struct Power {
base *= base;
}
return res;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (base.y == 0 && base.x == 0) {
if (isnan(exp.x) || isnan(exp.y)) {
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
return make_cuFloatComplex(nan, nan);
}
return make_cuFloatComplex(0.0, 0.0);
}
auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
auto phase = exp.y * x_ln_r + exp.x * x_theta;
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
} else if constexpr (is_complex_v<T>) {
return pow(base, exp);
} else {
return powf(base, exp);
}