Fix complex power and print (#2286)

* fix complex power and print

* fix complex matmul shape
This commit is contained in:
Awni Hannun
2025-06-13 11:13:00 -07:00
committed by GitHub
parent fddb6933e1
commit 8402a2acf4
6 changed files with 63 additions and 26 deletions

View File

@@ -235,6 +235,13 @@ struct Power {
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (x.real == 0 && x.imag == 0) {
if (metal::isnan(y.real) || metal::isnan(y.imag)) {
auto nan = metal::numeric_limits<float>::quiet_NaN();
return {nan, nan};
}
return {0.0, 0.0};
}
auto x_theta = metal::atan2(x.imag, x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);