Fix complex power and print (#2286)

* fix complex power and print

* fix complex matmul shape
This commit is contained in:
Awni Hannun
2025-06-13 11:13:00 -07:00
committed by GitHub
parent fddb6933e1
commit 8402a2acf4
6 changed files with 63 additions and 26 deletions

View File

@@ -194,6 +194,13 @@ struct Power {
}
return res;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (base.y == 0 && base.x == 0) {
if (isnan(exp.x) || isnan(exp.y)) {
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
return make_cuFloatComplex(nan, nan);
}
return make_cuFloatComplex(0.0, 0.0);
}
auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);