mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Use cuda::std::complex in place of cuComplex (#2372)
This commit is contained in:
@@ -44,7 +44,7 @@ struct Remainder {
|
||||
} else {
|
||||
return x % y;
|
||||
}
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
return x % y;
|
||||
} else {
|
||||
T r = fmod(x, y);
|
||||
@@ -66,14 +66,12 @@ struct Equal {
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
return x == y ||
|
||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
|
||||
isnan(cuCimagf(y))) ||
|
||||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
|
||||
isnan(cuCimagf(y))) ||
|
||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
|
||||
cuCimagf(x) == cuCimagf(y));
|
||||
(isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&
|
||||
isnan(y.imag())) ||
|
||||
(x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||
|
||||
(isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());
|
||||
} else {
|
||||
return x == y || (isnan(x) && isnan(y));
|
||||
}
|
||||
@@ -111,17 +109,17 @@ struct LessEqual {
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||
isnan(cuCimagf(y))) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) ||
|
||||
isnan(y.imag())) {
|
||||
return {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
|
||||
auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
|
||||
auto min_real = cuCrealf(min);
|
||||
auto max_real = cuCrealf(max);
|
||||
auto max = x.real() > y.real() ? x : y;
|
||||
auto min = x.real() < y.real() ? x : y;
|
||||
auto min_real = min.real();
|
||||
auto max_real = max.real();
|
||||
if (!isfinite(min_real) && (min_real == max_real)) {
|
||||
if (min_real < 0) {
|
||||
return min;
|
||||
@@ -150,8 +148,8 @@ struct Maximum {
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return max(x, y);
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
if (isnan(x.real()) || isnan(x.imag())) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
@@ -169,8 +167,8 @@ struct Minimum {
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return min(x, y);
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
if (isnan(x.real()) || isnan(x.imag())) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
@@ -193,8 +191,8 @@ struct Multiply {
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
||||
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
|
||||
if constexpr (is_complex_v<T>) {
|
||||
return x.real() != y.real() || x.imag() != y.imag();
|
||||
} else {
|
||||
return x != y;
|
||||
}
|
||||
@@ -214,19 +212,8 @@ struct Power {
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (base.y == 0 && base.x == 0) {
|
||||
if (isnan(exp.x) || isnan(exp.y)) {
|
||||
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
|
||||
return make_cuFloatComplex(nan, nan);
|
||||
}
|
||||
return make_cuFloatComplex(0.0, 0.0);
|
||||
}
|
||||
auto x_theta = atan2f(base.y, base.x);
|
||||
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
|
||||
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
|
||||
auto phase = exp.y * x_ln_r + exp.x * x_theta;
|
||||
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
return pow(base, exp);
|
||||
} else {
|
||||
return powf(base, exp);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user