mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix complex power and print (#2286)
* fix complex power and print * fix complex matmul shape
This commit is contained in:
@@ -194,6 +194,13 @@ struct Power {
|
||||
}
|
||||
return res;
|
||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
if (base.y == 0 && base.x == 0) {
|
||||
if (isnan(exp.x) || isnan(exp.y)) {
|
||||
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
|
||||
return make_cuFloatComplex(nan, nan);
|
||||
}
|
||||
return make_cuFloatComplex(0.0, 0.0);
|
||||
}
|
||||
auto x_theta = atan2f(base.y, base.x);
|
||||
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
|
||||
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
|
||||
|
||||
Reference in New Issue
Block a user